import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
torch.set_default_device("cuda")
# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained(
    "mlabonne/phixtral-4x2_8",
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "mlabonne/phixtral-4x2_8",
    torch_dtype="auto",
    load_in_8bit=True,
    trust_remote_code=True
)
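
# Note: load_in_8bit=True depends on the bitsandbytes package (and a CUDA GPU).
# A minimal fallback sketch, assuming bitsandbytes is unavailable, would load
# the unquantized weights in fp16 instead (untested here):
#
# model = AutoModelForCausalLM.from_pretrained(
#     "mlabonne/phixtral-4x2_8",
#     torch_dtype=torch.float16,
#     trust_remote_code=True,
# )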
# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [50256, 50295]  # IDs of tokens where the generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return True
        return False
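
# 50256 is the <|endoftext|> token inherited from the GPT-2 vocabulary; 50295
# is presumably the ID of <|im_end|> here. A sketch that derives the stop IDs
# from the tokenizer instead of hard-coding them (assuming <|im_end|> is
# registered in the vocabulary):
#
# stop_ids = [tokenizer.eos_token_id,
#             tokenizer.convert_tokens_to_ids("<|im_end|>")]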
# Function to generate model predictions.
def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()
    # Formatting the conversation into a ChatML-style prompt.
    system_prompt = "<|im_start|>system\nYou are Phixtral, a helpful AI assistant.<|im_end|>"
    messages = system_prompt
    for user_msg, assistant_msg in history_transformer_format:
        messages += "\n<|im_start|>user\n" + user_msg + "<|im_end|>\n<|im_start|>assistant\n" + assistant_msg
        if assistant_msg:  # Close completed assistant turns; the last (empty) turn stays open for generation.
            messages += "<|im_end|>"
    print(messages)
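
    # For a single user turn (e.g. "Hi"), the prompt built above expands to a
    # ChatML-style string like the following, with the final assistant turn
    # left open for the model to complete:
    #
    # <|im_start|>system
    # You are Phixtral, a helpful AI assistant.<|im_end|>
    # <|im_start|>user
    # Hi<|im_end|>
    # <|im_start|>assistant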
    input_ids = tokenizer([messages], return_tensors="pt").to('cuda')
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids,  # BatchEncoding is a mapping, so this unpacks input_ids and attention_mask as kwargs.
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '<|im_end|>' in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message
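
# A minimal smoke test (hypothetical, not part of the original Space): predict()
# is a generator, so its partial outputs can be consumed directly, e.g.:
#
# for partial in predict("What is 2 + 2?", []):
#     print(partial)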
# Setting up the Gradio chat interface.
gr.ChatInterface(
    predict,
    description="""
<center><img src="https://i.imgur.com/CJSeIGg.png" width="33%"></center>\n\n
Chat with [mlabonne/phixtral-4x2_8](https://huggingface.co/mlabonne/phixtral-4x2_8), the first Mixture of Experts made by merging four fine-tuned [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) models.
This small model (7.81B param) is good at various tasks, such as programming, dialogue, story writing, and more.\n\n
❤️ If you like this work, please follow me on [Hugging Face](https://huggingface.co/mlabonne) and [Twitter](https://twitter.com/maximelabonne).
""",
    examples=[
        'Can you solve the equation 2x + 3 = 11 for x?',
        'Write an epic poem about Ancient Rome.',
        'Who was the first person to walk on the Moon?',
        'Use a list comprehension to create a list of squares for numbers from 1 to 10.',
        'Recommend some popular science fiction books.',
        'Can you write a short story about a time-traveling detective?'
    ],
    theme=gr.themes.Soft(primary_hue="orange"),
).launch()
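
# On older Gradio versions, streaming generators needed an explicit queue; if
# streaming stalls there, a plausible fix is chaining .queue() before launch:
# gr.ChatInterface(predict, ...).queue().launch()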