import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Generation limits; the input-token cap can be overridden via the MAX_INPUT_TOKEN_LENGTH env var.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8000"))

DESCRIPTION = """\
# ConvAI 9b v2 Chat
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


# Load the model and tokenizer only when a GPU is available; the demo is GPU-only.
if torch.cuda.is_available():
    model_id = "CreitinGameplays/ConvAI-9b-v2"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    tokenizer.use_default_system_prompt = False

system_prompt_text = "You are a helpful, respectful and honest AI assistant named ChatGPT. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t know the answer to a question, please don’t share false information."

@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = system_prompt_text,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 1.0,
    top_k: int = 0,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
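    """Build the chat prompt, run generation on a background thread, and stream the partial text."""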
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits within the configured context budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed the conversation because it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream tokens back to the UI as they are produced; generation itself runs on a worker thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=5.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks, so run it on a worker thread while this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# The extra controls below map, in order, to generate()'s keyword arguments.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6, value=system_prompt_text),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=1.0,
        ),
        gr.Slider(
            label="Top-k",
            minimum=0,
            maximum=1000,
            step=1,
            value=0,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

# Assemble the page: description, optional duplicate button, and the chat interface.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    chat_interface.render()
    
if __name__ == "__main__":
    demo.queue(max_size=10).launch(show_error=True)
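
# Local usage (a sketch; assumes a CUDA GPU with enough VRAM for the fp16 weights and the
# `gradio`, `spaces`, `torch`, and `transformers` packages installed):
#   python app.py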