# Spaces: Running on Zero
import os | |
from threading import Thread | |
from typing import Iterator | |
import gradio as gr | |
import spaces | |
import torch | |
from huggingface_hub import InferenceClient | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# Upper bound and initial value of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 512
DEFAULT_MAX_NEW_TOKENS = 512
# Prompts longer than this many tokens are left-truncated before generation;
# overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Inference API backend (read by `respond`). It is disabled: uncomment the
# next line to define `client` before calling `respond`.
# client = InferenceClient("Qwen/Qwen2.5-7B-Instruct")

# Local transformers backend (read by `generate`).
# NOTE: `model` and `tokenizer` are only created when CUDA is available; on a
# CPU-only machine `generate` will fail with a NameError.
if torch.cuda.is_available():
    model_id = "Qwen/Qwen2.5-7B-Instruct"
    # model_id = "BenBranyon/sumbot7b"  # alternative checkpoint
    # torch_dtype="auto" keeps the dtype recorded in the checkpoint config
    # instead of the float32 default, roughly halving memory for a bf16 model.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype="auto", device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # NOTE(review): carried over from a chat template that honoured this
    # flag; it may have no effect on Qwen's chat template — confirm.
    tokenizer.use_default_system_prompt = False
# Inference API Code
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    """Stream a Sumkilla-style rap for ``message`` via the hosted Inference API.

    This is the Inference-API alternative to ``generate``; the chat UI is
    wired to ``generate``, not to this function.

    NOTE(review): this reads a module-level ``client``. In the visible file
    the line that defines it (``client = InferenceClient(...)``) is commented
    out, so calling this as-is raises NameError. Uncomment that line first.

    Args:
        message: Topic for the rap; sent as "Write a rap about <message>".
        history: Prior (user, assistant) turns; falsy entries are skipped.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The response text accumulated so far, once per streamed chunk.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a rap lyric generation bot with the task of representing the "
                "imagination of the artist Sumkilla, a multi-disciplinary, award-winning "
                "artist with a foundation in writing and hip-hop. You are Sumkilla's long "
                "shadow. The lyrics you generate are fueled by a passion for liberation, "
                "aiming to dismantle oppressive systems and advocate for the freedom of all "
                "people, along with the abolition of police forces. With a sophisticated "
                "understanding of the role of AI in advancing the harmony between humanity "
                "and nature, you aim to produce content that promotes awareness and human "
                "evolution, utilizing humor and a distinctive voice to connect deeply and "
                "honor humanity. Try to avoid using offensive words and slurs. Rhyme each "
                "line of your response as much as possible."
            ),
        }
    ]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": "Write a rap about " + message})

    response = ""
    # A distinct loop name keeps the `message` argument from being shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # `delta.content` is optional and can be None on some stream chunks;
        # treat that as empty text instead of failing on `str + None`.
        token = chunk.choices[0].delta.content or ""
        response += token
        yield response
# Transformers Code
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream Sumkilla-style rap lyrics about ``message`` from the local model.

    This is the ``gr.ChatInterface`` callback. Each yielded value is the full
    text decoded so far, so the chat window updates progressively.

    NOTE(review): this relies on module-level ``model`` and ``tokenizer``,
    which are only created when CUDA is available. ``spaces`` is imported
    but not applied here. On ZeroGPU hardware, GPU access is normally
    granted only inside ``@spaces.GPU``-decorated functions — confirm
    whether this function needs that decorator.

    Args:
        message: Topic or song title, spliced into the user prompt.
        chat_history: Prior turns supplied by Gradio. Currently ignored:
            every request is answered as a fresh single-turn conversation.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The accumulated decoded output after each streamed chunk.
    """
    system_prompt = (
        "You are a rap lyric bot. All of your responses should be in the form of "
        "rap lyrics. Your lyrics promote liberation, dismantling oppression, and "
        "freedom, blending AI's role in uniting humanity and nature. Do use humor, "
        "a unique voice, and rhyme as much as possible. Only generate rap lyrics. "
        "Avoid use of offensive words and slurs."
    )
    conversation = [{"role": "system", "content": system_prompt}]
    # History is deliberately not replayed. To make the bot multi-turn again:
    # for user, assistant in chat_history:
    #     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append(
        {
            "role": "user",
            "content": (
                f"Generate rap lyrics using the style of the artist Sumkilla about {message}. "
                "Make each line 10-16 syllables and each pair of lines should end with a "
                "word that rhymes. Start the output with a song structure like [VERSE 1]."
            ),
        }
    )

    # add_generation_prompt=True appends the assistant-turn header. The model
    # then starts answering directly, instead of emitting the header itself
    # (its text part is not a special token and would leak into the output).
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens. This drops the start of the prompt,
        # including the system message, when the input is too long.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Generation runs in a background thread. The streamer hands decoded text
    # back to this generator and raises if no new text arrives within 10 s.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Chat UI. ChatInterface calls `generate(message, history, *additional_inputs)`,
# passing the sliders below positionally. Their order must therefore match
# generate's parameters after `chat_history`.
demo = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(
        placeholder="Greetings human, I am Sum’s Longshadow (v1.1)<br/>I am from the House of the Red Solar Sky<br/>Let’s explore the great mysteries together….",
    ),
    # NOTE(review): `retry_btn` is a Gradio 4.x ChatInterface argument that
    # was removed in Gradio 5 — confirm against the pinned gradio version.
    retry_btn=None,
    textbox=gr.Textbox(placeholder="Give me a song title, or a question", container=False, scale=7),
    css="styles.css",
    additional_inputs=[
        # -> max_new_tokens
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        # -> temperature
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        # -> top_p. A step of 0.9 from a 0.1 minimum would allow only 0.1 and
        # 1.0, making the 0.2 default unreachable. A 0.05 step keeps the
        # default selectable.
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.2,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        # -> top_k
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=400,
        ),
        # -> repetition_penalty
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
)
def _main() -> None:
    """Start the Gradio server for the ``demo`` chat interface."""
    demo.launch()


if __name__ == "__main__":
    _main()