# Load the model. # Note: It can take a while to download LLaMA and add the adapter modules. # You can also use the 13B model by loading in 4bits. import torch from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer model_name = "decapoda-research/llama-7b-hf" adapters_name = 'timdettmers/guanaco-7b' print(f"Starting to load the model {model_name} into memory") m = AutoModelForCausalLM.from_pretrained( model_name, #load_in_4bit=True, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="." ) m = PeftModel.from_pretrained(m, adapters_name, offload_dir="") m = m.merge_and_unload() tok = LlamaTokenizer.from_pretrained(model_name) tok.bos_token_id = 1 stop_token_ids = [0] print(f"Successfully loaded the model {model_name} into memory") # Setup the gradio Demo. import datetime import os from threading import Event, Thread from uuid import uuid4 import gradio as gr import requests max_new_tokens = 1536 start_message = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: for stop_id in stop_token_ids: if input_ids[0][-1] == stop_id: return True return False def convert_history_to_text(history): text = start_message + "".join( [ "".join( [ f"### Human: {item[0]}\n", f"### Assistant: {item[1]}\n", ] ) for item in history[:-1] ] ) text += "".join( [ "".join( [ f"### Human: {history[-1][0]}\n", f"### Assistant: {history[-1][1]}\n", ] ) ] ) return text def log_conversation(conversation_id, history, messages, generate_kwargs): logging_url = os.getenv("LOGGING_URL", None) if logging_url is None: return timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") data = { "conversation_id": conversation_id, "timestamp": timestamp, "history": history, "messages": messages, "generate_kwargs": generate_kwargs, } try: requests.post(logging_url, json=data) except requests.exceptions.RequestException as e: print(f"Error logging conversation: {e}") def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id): print(f"history: {history}") # Initialize a StopOnTokens object stop = StopOnTokens() # Construct the input message string for the model by concatenating the current system message and conversation history messages = convert_history_to_text(history) # Tokenize the messages string input_ids = tok(messages, return_tensors="pt").input_ids input_ids = input_ids.to(m.device) streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=temperature > 0.0, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, streamer=streamer, stopping_criteria=StoppingCriteriaList([stop]), ) stream_complete = Event() def generate_and_signal_complete(): m.generate(**generate_kwargs) stream_complete.set() def log_after_stream_complete(): stream_complete.wait() log_conversation( conversation_id, history, messages, { "top_k": top_k, "top_p": top_p, "temperature": temperature, "repetition_penalty": repetition_penalty, }, ) t1 = Thread(target=generate_and_signal_complete) t1.start() t2 = Thread(target=log_after_stream_complete) t2.start() # Initialize an empty string to store the generated text partial_text = "" for new_text in streamer: partial_text += new_text history[-1][1] = partial_text yield history def get_uuid(): return str(uuid4()) with gr.Blocks( theme=gr.themes.Soft(), css=".disclaimer {font-variant-caps: all-small-caps;}", ) as demo: conversation_id = gr.State(get_uuid) gr.Markdown( """