Tonic committed
Commit 24f6499 (1 parent: d755450)

Update app.py

Files changed (1): app.py (+2, -8)
app.py CHANGED
@@ -30,14 +30,10 @@ class OrcaChatBot:
         self.system_message = system_message
 
     def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
-        # Prepare the prompt
-        prompt = f"system\n{self.system_message}\nuser\n{user_message}\nassistant"
-
-        # Encode the prompt
+        prompt = f"<|im_start|>system\n{self.system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant" if self.conversation_history is None else self.conversation_history + f"<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
         inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
         input_ids = inputs["input_ids"].to(self.model.device)
 
-        # Generate a response
         output_ids = self.model.generate(
             input_ids,
             max_length=input_ids.shape[1] + max_new_tokens,
@@ -45,10 +41,9 @@ class OrcaChatBot:
             top_p=top_p,
             repetition_penalty=repetition_penalty,
             pad_token_id=self.tokenizer.eos_token_id,
-            do_sample=True # Enable sampling-based generation
+            do_sample=True
         )
 
-        # Decode the generated response
         response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
         return response
@@ -75,5 +70,4 @@ iface = gr.Interface(
     theme="ParityError/Anime"
 )
 
-# Launch the Gradio interface
 iface.launch()
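
Note on the change above: the added one-liner packs two cases into a single conditional expression. The sketch below is not part of the commit; it unpacks the same ChatML-style prompt construction into a hypothetical build_prompt helper, assuming conversation_history holds the transcript accumulated from earlier turns (None on the first turn) and using placeholder strings in the usage line.

# Not part of the commit: a minimal sketch of the new prompt logic.
# "build_prompt" is a hypothetical helper name introduced here for readability.
def build_prompt(system_message, user_message, conversation_history=None):
    if conversation_history is None:
        # First turn: start a fresh ChatML-style transcript with the system message.
        return (
            f"<|im_start|>system\n{system_message}<|im_end|>\n"
            f"<|im_start|>user\n{user_message}<|im_end|>\n"
            "<|im_start|>assistant"
        )
    # Follow-up turn: append the new user message to the accumulated history.
    return (
        conversation_history
        + f"<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
    )

print(build_prompt("<system message>", "<user message>"))

Separately, the commit only strips the inline comment from do_sample=True; sampling stays enabled, which is what lets temperature, top_p, and repetition_penalty take effect in model.generate (with do_sample=False, generation would fall back to greedy decoding).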