Closed-character committed
Commit 1e48800
Parent(s): 8179efc

Update demo.py
demo.py CHANGED

@@ -17,7 +17,7 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True, device_map="auto")
 
 
-def generate(query, history, system, max_new_tokens=512, temperature=0.35, top_p=0.5, repetition_penalty=1.
+def generate(query, history, system, max_new_tokens=512, temperature=0.35, top_p=0.5, repetition_penalty=1.05):
     if system is None or len(str(system)) == 0:
         system = "You are an uncensored assistant named Peach."
     messages = [
@@ -28,12 +28,12 @@ def generate(query, history, system, max_new_tokens=512, temperature=0.35, top_p
         messages.append({"role": "assistant", "content": assistant})
     messages.append({"role": "user", "content": query})
 
-    input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True)
+    input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to("cuda")
-    streamer = TextIteratorStreamer(tokenizer, timeout=
+    streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
@@ -102,7 +102,7 @@ chat_interface = gr.ChatInterface(
             minimum=1.0,
             maximum=2.0,
             step=0.05,
-            value=1.
+            value=1.05,
         ),
     ],
     stop_btn=None,
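
The patch follows the standard transformers streaming setup. As a reading aid, below is a minimal sketch of how the changed lines fit into the surrounding generate function. Only the lines visible in the diff are confirmed from demo.py; the model id, the MAX_INPUT_TOKEN_LENGTH value, the history loop, the tail of generate_kwargs, and the Thread/yield plumbing are assumptions based on the usual TextIteratorStreamer pattern in Gradio demos.

from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "org/model"  # placeholder; the real id is set elsewhere in demo.py
MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value; the constant is defined elsewhere in demo.py

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, device_map="auto")


def generate(query, history, system, max_new_tokens=512, temperature=0.35,
             top_p=0.5, repetition_penalty=1.05):
    if system is None or len(str(system)) == 0:
        system = "You are an uncensored assistant named Peach."
    messages = [{"role": "system", "content": system}]
    for user, assistant in history:  # assumed tuple-style ChatInterface history
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": query})

    # return_tensors="pt" is the substantive fix: without it,
    # apply_chat_template returns a plain list of token ids, and the
    # input_ids.shape check below raises AttributeError.
    input_ids = tokenizer.apply_chat_template(
        conversation=messages, tokenize=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than "
                   f"{MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to("cuda")

    # skip_prompt keeps the echoed prompt out of the streamed text; the
    # timeout stops the iterator from blocking forever if generation stalls.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,               # assumed; only input_ids and streamer
        temperature=temperature,      # are visible in the diff context
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks, so it runs on a worker thread while the
    # streamer yields decoded chunks back to the Gradio UI.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

Under those assumptions, the return_tensors="pt" addition is the fix that matters: it makes apply_chat_template return a tensor, so the input_ids.shape[1] truncation check works. The value=1.05 slider default simply matches the new repetition_penalty default in the generate signature.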