import datetime
import json
import os
import shutil
from typing import Optional
from typing import Tuple

import gradio as gr
import torch
from fastchat.serve.inference import compress_module
from fastchat.serve.inference import raise_warning_for_old_weights
from huggingface_hub import Repository
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
from peft import LoraConfig
from peft import get_peft_model
from peft import set_peft_model_state_dict
from transformers import AutoModelForCausalLM
from transformers import GenerationConfig
from transformers import LlamaTokenizer

print(datetime.datetime.now())

NUM_THREADS = 1

print(NUM_THREADS)
print("starting server ...")

BASE_MODEL = "decapoda-research/llama-13b-hf"
LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)

repo = None
LOCAL_DIR = "/home/user/data/"
PROMPT_LANG = "en"
assert PROMPT_LANG in ["ja", "en"]

# Clone the dataset repository used for logging prompts and completions, if configured.
if HF_TOKEN and DATASET_REPOSITORY:
    try:
        shutil.rmtree(LOCAL_DIR)
    except Exception:
        pass

    repo = Repository(
        local_dir=LOCAL_DIR,
        clone_from=DATASET_REPOSITORY,
        use_auth_token=HF_TOKEN,
        repo_type="dataset",
    )
    repo.git_pull()

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Pick the best available device: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except Exception:
    pass

# Download the LoRA adapter (config and weights) from the Hub.
resume_from_checkpoint = snapshot_download(
    repo_id=LORA_WEIGHTS, use_auth_token=HF_TOKEN
)
checkpoint_name = hf_hub_download(
    repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN
)

# Load the base model in 8-bit, then apply the LoRA adapter weights on top.
if device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
    )
elif device == "mps":
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map={"": device},
        load_in_8bit=True,
        torch_dtype=torch.float16,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map={"": device},
        load_in_8bit=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )

config = LoraConfig.from_pretrained(resume_from_checkpoint)
model = get_peft_model(model, config)
adapters_weights = torch.load(checkpoint_name)
set_peft_model_state_dict(model, adapters_weights)

raise_warning_for_old_weights(BASE_MODEL, model)
compress_module(model, device)
# if device == "cuda" or device == "mps":
#     model = model.to(device)


def generate_prompt(instruction: str, input: Optional[str] = None) -> str:
    """Build an Alpaca-style prompt in Japanese or English."""
    if input:
        if PROMPT_LANG == "ja":
            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### Response:\n"
        elif PROMPT_LANG == "en":
            return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
        else:
            raise ValueError(f"unsupported PROMPT_LANG: {PROMPT_LANG}")
    else:
        if PROMPT_LANG == "ja":
            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 返答:\n"
        elif PROMPT_LANG == "en":
            return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""
        else:
            raise ValueError(f"unsupported PROMPT_LANG: {PROMPT_LANG}")


if device != "cpu":
    model.half()
model.eval()
if torch.__version__ >= "2":
    model = torch.compile(model)


def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
    """Append the prompt, completion, and generation settings to an hourly JSONL file and push it to the dataset repo."""
    current_hour = now.strftime("%Y-%m-%d_%H")
    file_name = f"prompts_{LORA_WEIGHTS.split('/')[-1]}_{current_hour}.jsonl"

    if repo is not None:
        repo.git_pull(rebase=True)
        with open(os.path.join(LOCAL_DIR, file_name), "a", encoding="utf-8") as f:
            json.dump(
                {
                    "inputs": inputs,
                    "outputs": outputs,
                    "generate_kwargs": generate_kwargs,
                },
                f,
                ensure_ascii=False,
            )
            f.write("\n")
        repo.push_to_hub()


# We can't add type hints to the Gradio event handler yet:
# https://github.com/gradio-app/gradio/issues/3514
def evaluate(
    instruction,
    input=None,
    temperature=0.7,
    max_tokens=384,
    repetition_penalty=1.0,
):
    num_beams: int = 1
    top_p: float = 0.75
    top_k: int = 40
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Reject prompts that leave no room for generation within the token budget.
    if len(inputs["input_ids"][0]) > max_tokens + 10:
        if HF_TOKEN and DATASET_REPOSITORY:
            try:
                now = datetime.datetime.now()
                current_time = now.strftime("%Y-%m-%d %H:%M:%S")
                print(f"[{current_time}] Pushing prompt and completion to the Hub")
                save_inputs_and_outputs(
                    now,
                    prompt,
                    "",
                    {
                        "temperature": temperature,
                        "top_p": top_p,
                        "top_k": top_k,
                        "num_beams": num_beams,
                        "max_tokens": max_tokens,
                        "repetition_penalty": repetition_penalty,
                    },
                )
            except Exception as e:
                print(e)
        return (
            f"Please reduce the input length. Currently, {len(inputs['input_ids'][0])} tokens are used.",
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            # Remaining budget after the prompt; clamp to at least one new token.
            max_new_tokens=max(max_tokens - len(input_ids[0]), 1),
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)

    # Strip the echoed prompt, keeping only the generated response.
    if prompt.rstrip().endswith("Response:"):
        output = output.split("### Response:")[1].strip()
    elif prompt.rstrip().endswith("返答:"):
        output = output.split("### 返答:")[1].strip()
    else:
        raise ValueError(f"No valid prompt ending. {prompt}")

    if HF_TOKEN and DATASET_REPOSITORY:
        try:
            now = datetime.datetime.now()
            current_time = now.strftime("%Y-%m-%d %H:%M:%S")
            print(f"[{current_time}] Pushing prompt and completion to the Hub")
            save_inputs_and_outputs(
                now,
                prompt,
                output,
                {
                    "temperature": temperature,
                    "top_p": top_p,
                    "top_k": top_k,
                    "num_beams": num_beams,
                    "max_tokens": max_tokens,
                    "repetition_penalty": repetition_penalty,
                },
            )
        except Exception as e:
            print(e)
    return output, gr.update(interactive=True), gr.update(interactive=True)


def reset_textbox():
    return gr.update(value=""), gr.update(value=""), gr.update(value="")


def no_interactive() -> Tuple[gr.Request, gr.Request]:
    return gr.update(interactive=False), gr.update(interactive=False)


title = """