import argparse
from threading import Thread

import spaces  # import before any CUDA-touching module, as required by ZeroGPU

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_LENGTH = 4096  # maximum number of prompt tokens kept; longer prompts are truncated from the left
DEFAULT_MAX_NEW_TOKENS = 1024  # default generation budget exposed in the "Max Tokens" slider


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, default=None, help="optional local path or hub id of the model to load")
    parser.add_argument("--n_gpus", type=int, default=1, help="number of GPUs to use")
    return parser.parse_args()

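# Streaming chat handler: Gradio calls this once per user turn and re-renders
# the chatbot with each partial string yielded below, so tokens appear in the
# UI as they are generated rather than after the full completion finishes.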
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream a chat completion for `message`, conditioned on the running history."""
    global model, tokenizer, device
    # Rebuild the full conversation in the model's chat-template format.
    messages = [{'role': 'system', 'content': system_prompt}]
    for human, assistant in history:
        messages.append({'role': 'user', 'content': human})
        messages.append({'role': 'assistant', 'content': assistant})
    messages.append({'role': 'user', 'content': message})
    problem = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
    # Resolve the stop tokens to ids so generate() can halt on either of them.
    stop_tokens = ["<|endoftext|>", "<|im_end|>"]
    stop_token_ids = tokenizer.convert_tokens_to_ids(stop_tokens)
    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer(problem, return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask

    # Keep only the most recent MAX_LENGTH tokens, truncating the attention
    # mask alongside the ids so their shapes stay in sync.
    if input_ids.shape[1] > MAX_LENGTH:
        input_ids = input_ids[:, -MAX_LENGTH:]
        attention_mask = attention_mask[:, -MAX_LENGTH:]

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    # Sampling requires a strictly positive temperature; clamp the slider value.
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        do_sample=True,
        top_p=0.95,
        temperature=temperature,         # from the "Temperature" slider
        max_new_tokens=int(max_tokens),  # from the "Max Tokens" slider
        use_cache=True,
        eos_token_id=stop_token_ids,     # ids of <|endoftext|> and <|im_end|>
    )
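    # model.generate blocks until completion, so run it on a worker thread and
    # consume the streamer here, yielding the text accumulated so far after
    # each new chunk arrives.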
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)



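# Entry point: load the tokenizer and model once at startup, then serve a
# Gradio ChatInterface that streams predict's output for every user turn.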
if __name__ == "__main__":
    args = parse_args()
    # Use --base_model when given; otherwise fall back to the hosted checkpoint.
    model_id = args.base_model or "stabilityai/stablelm-2-chat"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
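    # additional_inputs below map, in order, onto predict's system_prompt,
    # temperature, and max_tokens parameters.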
    gr.ChatInterface(
        predict,
        title="StableLM 2 Chat - Demo",
        description="StableLM 2 Chat - StabilityAI",
        theme="soft",
        chatbot=gr.Chatbot(label="Chat History"),
        textbox=gr.Textbox(placeholder="Type a message...", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        additional_inputs=[
            gr.Textbox("You are a helpful assistant.", label="System Prompt"),
            gr.Slider(minimum=0, maximum=1, value=0.5, label="Temperature"),
            gr.Slider(minimum=100, maximum=2048, value=DEFAULT_MAX_NEW_TOKENS, label="Max Tokens"),
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()