Update app.py
app.py CHANGED
@@ -30,14 +30,10 @@ class OrcaChatBot:
         self.system_message = system_message
 
     def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
-
-        prompt = f"system\n{self.system_message}\nuser\n{user_message}\nassistant"
-
-        # Encode the prompt
+        prompt = f"<|im_start|>system\n{self.system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant" if self.conversation_history is None else self.conversation_history + f"<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
         inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
         input_ids = inputs["input_ids"].to(self.model.device)
 
-        # Generate a response
         output_ids = self.model.generate(
             input_ids,
             max_length=input_ids.shape[1] + max_new_tokens,
@@ -45,10 +41,9 @@ class OrcaChatBot:
             top_p=top_p,
             repetition_penalty=repetition_penalty,
             pad_token_id=self.tokenizer.eos_token_id,
-            do_sample=True
+            do_sample=True
         )
 
-        # Decode the generated response
         response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
         return response
@@ -75,5 +70,4 @@ iface = gr.Interface(
     theme="ParityError/Anime"
 )
 
-# Launch the Gradio interface
 iface.launch()
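For context, the updated predict() path reads roughly as below once this change is applied. The __init__ scaffolding (tokenizer and model loading, and the conversation_history attribute) and the temperature=temperature argument, which sits in the unchanged line between the two hunks, are assumptions added for this sketch; the prompt construction, tokenization, generate() arguments, and decode step come from the diff above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class OrcaChatBot:
    def __init__(self, model_name, system_message="You are a helpful assistant."):
        # Assumed setup: the diff only shows self.system_message being stored.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto")
        self.system_message = system_message
        self.conversation_history = None  # referenced by the new prompt logic

    def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
        # First turn: build a fresh ChatML prompt; later turns: append the new user turn to the stored history.
        if self.conversation_history is None:
            prompt = (f"<|im_start|>system\n{self.system_message}<|im_end|>\n"
                      f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant")
        else:
            prompt = self.conversation_history + (
                f"<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant")

        # Encode the prompt without adding extra special tokens.
        inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inputs["input_ids"].to(self.model.device)

        # Sample a continuation of at most max_new_tokens tokens.
        output_ids = self.model.generate(
            input_ids,
            max_length=input_ids.shape[1] + max_new_tokens,
            temperature=temperature,  # assumed: this argument falls between the two hunks
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True
        )

        # Decode the full sequence (prompt plus completion) back to text.
        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return response

Note that the hunks shown here do not include any line that writes back to self.conversation_history, so any history update happens outside the visible range.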
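The last hunk only shows the tail of the Gradio wiring: theme="ParityError/Anime", the closing parenthesis, and iface.launch(). A minimal sketch of how that gr.Interface block plausibly fits together is below; the wrapper function, input and output components, and slider ranges are illustrative assumptions, and only the theme string and the launch call are taken from the diff. The slider defaults reuse the predict() defaults (0.4, 70, 0.99, 1.9).

import gradio as gr

# Hypothetical instance of the class sketched above; the real model id is not visible in the diff.
Orca_bot = OrcaChatBot("model-id-goes-here")

def gradio_predict(user_message, temperature, max_new_tokens, top_p, repetition_penalty):
    # Thin wrapper so Gradio passes the UI controls straight through to predict().
    return Orca_bot.predict(user_message, temperature, max_new_tokens, top_p, repetition_penalty)

iface = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="User message"),
        gr.Slider(0.1, 1.0, value=0.4, label="Temperature"),
        gr.Slider(10, 500, value=70, step=10, label="Max new tokens"),
        gr.Slider(0.1, 1.0, value=0.99, label="Top-p"),
        gr.Slider(1.0, 2.0, value=1.9, label="Repetition penalty"),
    ],
    outputs="text",
    theme="ParityError/Anime"
)

iface.launch()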