sudokien committed
Commit 42eec75
1 parent: 03f4e17

Only do model name translation for Llama-2 and CodeLlama


The function translate_llama2 appends the '-hf' suffix to every meta-llama model name, which is incorrect: only Llama-2 and CodeLlama should be translated to their '-hf' counterparts. This commit restricts the translation to those two model families.

Files changed (1)
  1. src/model_utils.py +4 -4
src/model_utils.py CHANGED
@@ -27,8 +27,8 @@ def extract_from_url(name: str):
     return path[1:]
 
 
-def translate_llama2(text):
-    "Translates llama-2 to its hf counterpart"
+def translate_llama(text):
+    "Translates Llama-2 and CodeLlama to its hf counterpart"
     if not text.endswith("-hf"):
         return text + "-hf"
     return text
@@ -36,8 +36,8 @@ def translate_llama2(text):
 
 def get_model(model_name: str, library: str, access_token: str):
     "Finds and grabs model from the Hub, and initializes on `meta`"
-    if "meta-llama" in model_name:
-        model_name = translate_llama2(model_name)
+    if "meta-llama/Llama-2-" in model_name or "meta-llama/CodeLlama-" in model_name:
+        model_name = translate_llama(model_name)
     if library == "auto":
         library = None
     model_name = extract_from_url(model_name)
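
For illustration, a minimal, self-contained sketch of the behavior this commit produces. The needs_translation helper and the example model names are not part of the repository; the helper only mirrors the guard added to get_model, and the names are illustrative, assuming (as the commit message does) that only Llama-2 and CodeLlama have separate '-hf' conversions on the Hub.

def translate_llama(text):
    "Translates Llama-2 and CodeLlama to its hf counterpart"
    if not text.endswith("-hf"):
        return text + "-hf"
    return text


def needs_translation(model_name):
    # Mirrors the guard this commit adds to get_model: translate only
    # Llama-2 and CodeLlama checkpoints, not every meta-llama model.
    return ("meta-llama/Llama-2-" in model_name
            or "meta-llama/CodeLlama-" in model_name)


for name in [
    "meta-llama/Llama-2-7b",     # translated -> meta-llama/Llama-2-7b-hf
    "meta-llama/CodeLlama-7b",   # translated -> meta-llama/CodeLlama-7b-hf
    "meta-llama/Llama-2-7b-hf",  # already suffixed, returned unchanged
    "meta-llama/LlamaGuard-7b",  # other meta-llama model; no longer gets a bogus "-hf"
]:
    print(translate_llama(name) if needs_translation(name) else name)

Before this commit, the last name above would also have been rewritten to "meta-llama/LlamaGuard-7b-hf", since the old guard matched any name containing "meta-llama".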