ngrigg committed on
Commit 059c9d2
Parent: 784fe5f

Fix padding and truncation issues

Files changed (1)
  1. llama_models.py +2 -5
llama_models.py CHANGED
@@ -14,17 +14,14 @@ def load_model(model_name):
     if not tokenizer or not model:
         print("Loading model and tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token # Set pad_token to eos_token
         model = AutoModelForCausalLM.from_pretrained(model_name) # Ensure correct model class
         print("Model and tokenizer loaded successfully.")
     return tokenizer, model
 
 async def process_text_local(model_name, text):
-    print("Loading model and tokenizer...")
     tokenizer, model = load_model(model_name)
-    print("Encoding text...")
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-    print("Generating output...")
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) # Set max_length to 512
     outputs = model.generate(**inputs, max_length=512)
-    print("Decoding output...")
     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return result
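
For context, here is a minimal sketch of what llama_models.py looks like after this commit. The import line, the module-level tokenizer/model cache, and the `global` declaration are assumptions inferred from the hunk (the diff only shows lines 14 onward); everything else mirrors the patched code. The point of the fix: Llama-family tokenizers ship without a pad token, so the old call with padding=True would raise an error, and truncation=True had no explicit length bound.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer, model = None, None  # assumed module-level cache, implied by line 14

def load_model(model_name):
    global tokenizer, model  # assumed: the hunk implies module-level state
    if not tokenizer or not model:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Llama-style tokenizers define no pad token by default; reusing the
        # EOS token lets padding=True work instead of raising an error.
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_name)  # Ensure correct model class
        print("Model and tokenizer loaded successfully.")
    return tokenizer, model

async def process_text_local(model_name, text):
    tokenizer, model = load_model(model_name)
    # max_length=512 gives truncation an explicit bound rather than relying
    # on the model's default maximum input size.
    inputs = tokenizer(text, return_tensors="pt", padding=True,
                       truncation=True, max_length=512)
    outputs = model.generate(**inputs, max_length=512)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return result

One caveat worth noting: generate()'s max_length counts the prompt tokens as well, so an input truncated to the full 512 tokens leaves no room for newly generated text; bounding generation with max_new_tokens instead is the usual way to avoid that.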