Spaces:
Running
Running
File size: 3,012 Bytes
54fe16b 31a1ff8 e2b5fc2 d34759f d6839dc 54fe16b 9e4ba23 af83917 54fe16b 87aa391 54fe16b 31a1ff8 54fe16b 0a1707e 5081c38 54fe16b 5081c38 54fe16b adf28c0 54fe16b 87aa391 54fe16b 0a1707e 54fe16b 5251b14 7ea603e 54fe16b 53b40bf 54fe16b 53b40bf 0a1707e 54fe16b 5081c38 54fe16b d6662ca 54fe16b 5081c38 54fe16b 53b40bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import argparse
import os
import spaces
import gradio as gr
import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MAX_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 1024
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_model", type=str) # model path
parser.add_argument("--n_gpus", type=int, default=1) # n_gpu
return parser.parse_args()
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
global model, tokenizer, device
messages = [{'role': 'system', 'content': system_prompt}]
for human, assistant in history:
messages.append({'role': 'user', 'content': human})
messages.append({'role': 'assistant', 'content': assistant})
messages.append({'role': 'user', 'content': message})
problem = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
enc = tokenizer(problem, return_tensors="pt", padding=True, truncation=True)
input_ids = enc.input_ids
attention_mask = enc.attention_mask
if input_ids.shape[1] > MAX_LENGTH:
input_ids = input_ids[:, -MAX_LENGTH:]
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
generate_kwargs = dict(
{"input_ids": input_ids, "attention_mask": attention_mask},
streamer=streamer,
do_sample=True,
top_p=0.95,
temperature=0.5,
max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
use_cache=True,
eos_token_id=100278 # <|im_end|>
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
if __name__ == "__main__":
args = parse_args()
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-2-chat", trust_remote_code=True, torch_dtype=torch.bfloat16)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
gr.ChatInterface(
predict,
title="StableLM 2 Chat - Demo",
description="StableLM 2 Chat - StabilityAI",
theme="soft",
chatbot=gr.Chatbot(label="Chat History",),
textbox=gr.Textbox(placeholder="input", container=False, scale=7),
retry_btn=None,
undo_btn="Delete Previous",
clear_btn="Clear",
additional_inputs=[
gr.Textbox("You are a helpful assistant.", label="System Prompt"),
gr.Slider(0, 1, 0.5, label="Temperature"),
gr.Slider(100, 2048, 1024, label="Max Tokens"),
],
additional_inputs_accordion_name="Parameters",
).queue().launch() |