Closed-character committed
Commit 1e48800
Parent(s): 8179efc

Update demo.py

Files changed (1):
demo.py (+4, -4)
demo.py CHANGED

@@ -17,7 +17,7 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True, device_map="auto")
 
 
-def generate(query, history, system, max_new_tokens=512, temperature=0.35, top_p=0.5, repetition_penalty=1.1):
+def generate(query, history, system, max_new_tokens=512, temperature=0.35, top_p=0.5, repetition_penalty=1.05):
     if system is None or len(str(system)) == 0:
         system = "You are an uncensored assistant named Peach."
     messages = [
@@ -28,12 +28,12 @@ def generate(query, history, system, max_new_tokens=512, temperature=0.35, top_p
         messages.append({"role": "assistant", "content": assistant})
     messages.append({"role": "user", "content": query})
 
-    input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True)
+    input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to("cuda")
-    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
@@ -102,7 +102,7 @@ chat_interface = gr.ChatInterface(
             minimum=1.0,
             maximum=2.0,
             step=0.05,
-            value=1.1,
+            value=1.05,
         ),
     ],
     stop_btn=None,
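
The substantive change here is the added return_tensors="pt" argument: without it, apply_chat_template returns a plain Python list of token ids, so the input_ids.shape[1] check and the tensor slicing on the following lines would fail at runtime. The other edits soften the sampling (repetition penalty lowered from 1.1 to 1.05, with the slider default kept in sync) and halve the streamer timeout. Below is a minimal, Gradio-free sketch of the same streaming pattern, assuming a standard transformers causal LM with a chat template; MODEL_ID, MAX_INPUT_TOKEN_LENGTH, and do_sample=True are placeholders or assumptions, not values taken from this repo.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "your-org/your-model"    # placeholder, not the repo's actual model
MAX_INPUT_TOKEN_LENGTH = 4096       # assumption; the demo defines its own constant

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, device_map="auto")

messages = [
    {"role": "system", "content": "You are an uncensored assistant named Peach."},
    {"role": "user", "content": "Hello!"},
]

# return_tensors="pt" makes apply_chat_template return a (1, seq_len) tensor
# instead of a list of ids, which is what the shape check and slicing require.
input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True,
                                          return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep the most recent tokens
input_ids = input_ids.to(model.device)

# The streamer exposes generate() output as an iterator of decoded text chunks;
# timeout bounds how long the consumer waits for the next chunk (50 s, as in the commit).
streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True,
                                skip_special_tokens=True)
generate_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=512,
    do_sample=True,            # assumed, so temperature/top_p actually take effect
    temperature=0.35,
    top_p=0.5,
    repetition_penalty=1.05,
)

# generate() blocks until completion, so it runs in a background thread while
# the main thread drains the streamer chunk by chunk.
Thread(target=model.generate, kwargs=generate_kwargs).start()
for chunk in streamer:
    print(chunk, end="", flush=True)

A repetition penalty of 1.05 is a milder correction than 1.1: it still discourages verbatim loops but distorts the model's token distribution less, which is presumably why both the function default and the UI slider were lowered together.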