CreitinGameplays committed · Commit 7942c52 · 1 Parent(s): 63bf5db

Update app.py
app.py
CHANGED
@@ -25,16 +25,17 @@ if torch.cuda.is_available():
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 tokenizer.use_default_system_prompt = False
 
+system_prompt_text = "You are a helpful, respectful and honest AI assistant named ChatGPT. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t know the answer to a question, please don’t share false information."
 
 @spaces.GPU(duration=90)
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    system_prompt: str =
+    system_prompt: str = system_prompt_text,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
-    top_p: float = 0
-    top_k: int =
+    top_p: float = 1.0,
+    top_k: int = 0,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
@@ -50,7 +51,7 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=
+    streamer = TextIteratorStreamer(tokenizer, timeout=5.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
@@ -74,7 +75,7 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6, value=
+        gr.Textbox(label="System prompt", lines=6, value=system_prompt_text),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
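
The first two hunks configure sampling defaults and the token streamer. For context, below is a minimal sketch of the streaming pattern these lines typically plug into in the stock Hugging Face Spaces chat template; the name stream_completion, the do_sample=True flag, and the worker-thread setup are assumptions, since only the diff lines themselves are confirmed by this commit.

from threading import Thread

from transformers import TextIteratorStreamer


def stream_completion(model, tokenizer, input_ids):
    # The commit's streamer settings: skip_prompt=True drops the echoed
    # prompt from the stream, skip_special_tokens=True strips markers such
    # as </s>, and timeout=5.0 raises if no new text arrives for 5 seconds.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=5.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,  # assumed: sampling must be on for top_p/top_k to matter
        temperature=0.6,
        top_p=1.0,  # 1.0 keeps the full distribution: nucleus filtering is off
        top_k=0,    # 0 disables top-k filtering in transformers
        repetition_penalty=1.2,
    )
    # model.generate blocks until generation finishes, so it runs in a
    # worker thread while the streamer is consumed incrementally here.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

Note that the new defaults top_p=1.0 and top_k=0 together disable both nucleus and top-k filtering, leaving temperature=0.6 and repetition_penalty=1.2 as the only constraints on sampling.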
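For the third hunk, a note on the wiring: gr.ChatInterface passes each component in additional_inputs as an extra positional argument to fn after the message and history, so the Textbox seeded with system_prompt_text feeds the system_prompt parameter whose default this commit also changed. A minimal illustration, where the Slider bounds beyond minimum=1 are assumptions:

import gradio as gr

chat_interface = gr.ChatInterface(
    fn=generate,  # receives (message, chat_history, system_prompt, max_new_tokens, ...)
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6, value=system_prompt_text),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024),
    ],
)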
|