import os import gradio as gr from text_generation import Client, InferenceAPIClient def get_client(model: str): if model == "Rallio67/joi2_20B_instruct_alpha": return Client(os.getenv("API_URL")) return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None)) def get_usernames(model: str): if model == "Rallio67/joi2_20B_instruct_alpha": return "User: ", "Joi: " return "User: ", "Assistant: " def predict( model: str, inputs: str, top_p: float, temperature: float, top_k: int, repetition_penalty: float, watermark: bool, chatbot, history, ): client = get_client(model) user_name, assistant_name = get_usernames(model) history.append(inputs) past = [] for data in chatbot: user_data, model_data = data if not user_data.startswith(user_name): user_data = user_name + user_data if not model_data.startswith("\n\n" + assistant_name): model_data = "\n\n" + assistant_name + model_data past.append(user_data + model_data + "\n\n") if not inputs.startswith(user_name): inputs = user_name + inputs total_inputs = "".join(past) + inputs + "\n\n" + assistant_name # truncate total_inputs total_inputs = total_inputs[-1000:] partial_words = "" for i, response in enumerate(client.generate_stream( total_inputs, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, watermark=watermark, temperature=temperature, max_new_tokens=500, stop_sequences=[user_name.rstrip(), assistant_name.rstrip()], )): if response.token.special: continue partial_words = partial_words + response.token.text if partial_words.endswith(user_name.rstrip()): partial_words = partial_words.rstrip(user_name.rstrip()) if partial_words.endswith(assistant_name.rstrip()): partial_words = partial_words.rstrip(assistant_name.rstrip()) if i == 0: history.append(" " + partial_words) else: history[-1] = partial_words chat = [ (history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2) ] yield chat, history def reset_textbox(): return gr.update(value="") title = """