# Anthony G
# slight change
# 6bf9d73
# raw
# history blame
# 2.7 kB
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
import warnings
# Suppress every Python warning from this point on. The filter is installed
# after the imports above, so warnings raised while those modules were being
# imported are not affected by it.
warnings.filterwarnings("ignore")
# Identifier of the PEFT adapter that is loaded on top of its base model.
PEFT_MODEL = "givyboy/phi-2-finetuned-mental-health-conversational"

# Instruction preamble that opens every prompt sent to the model.
SYSTEM_PROMPT = """Answer the following question truthfully.
If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'."""


def USER_PROMPT(x: str) -> str:
    """Render an open turn: the user's text followed by an empty assistant slot."""
    return f"<HUMAN>: {x}\n<ASSISTANT>: "


def ADD_RESPONSE(x: str, y: str) -> str:
    """Render a completed turn: the user's text ``x`` and the assistant's answer ``y``."""
    return f"<HUMAN>: {x}\n<ASSISTANT>: {y}"
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_use_double_quant=True,
# bnb_4bit_compute_dtype=torch.float16,
# )
# Read the adapter's config to learn which base checkpoint it was built on.
config = PeftConfig.from_pretrained(PEFT_MODEL)

# Offload settings shared by the base-model loader and the adapter loader.
_OFFLOAD_OPTIONS = {"offload_folder": "offload/", "offload_state_dict": True}

# Base causal LM, loaded without quantization (the 4-bit config above is
# commented out, and so is the argument that would pass it in).
peft_base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    **_OFFLOAD_OPTIONS,
)

# Wrap the base model with the fine-tuned adapter weights.
peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL, **_OFFLOAD_OPTIONS)

# The tokenizer comes from the base checkpoint; EOS doubles as the pad token.
peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
peft_tokenizer.pad_token = peft_tokenizer.eos_token

# Text-generation pipeline used by the chat handler.
# NOTE(review): the model is passed in already instantiated, so it is unclear
# whether torch_dtype/device_map here still have any effect - confirm.
pipeline = transformers.pipeline(
    "text-generation",
    model=peft_model,
    tokenizer=peft_tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
def format_message(message: str, history: list[list[str]], memory_limit: int = 3) -> str:
    """Build the full prompt for the model from the chat so far.

    The prompt is the system prompt, then up to ``memory_limit`` of the most
    recent completed turns, then the new user message with an empty assistant
    slot, all separated by single newlines.

    Args:
        message: The user's newest message.
        history: Earlier turns, oldest first, each a ``(user_text, reply)``
            pair. An empty or ``None`` history yields a prompt with no past
            turns.
            NOTE(review): assumes the pair-style history format; a chat UI
            configured to send role/content dicts would need adapting.
        memory_limit: Maximum number of past turns to include. Zero or a
            negative value includes none.

    Returns:
        The prompt string to feed to the text-generation pipeline.
    """
    # Guard the slice explicitly: ``history[-0:]`` is the *whole* list, so a
    # limit of 0 would otherwise keep everything instead of nothing.
    recent_turns = history[-memory_limit:] if history and memory_limit > 0 else []

    sections = [SYSTEM_PROMPT]
    sections.extend(ADD_RESPONSE(user_text, reply) for user_text, reply in recent_turns)
    sections.append(USER_PROMPT(message))
    return "\n".join(sections)
def _extract_reply(generated_text: str, prompt: str) -> str:
    """Return only the assistant's answer to the newest user message.

    Args:
        generated_text: Full text returned by the pipeline (prompt + completion).
        prompt: The exact prompt that was sent to the pipeline.

    Returns:
        The completion with the prompt removed, cut off before any further
        role marker, and stripped of surrounding whitespace.
    """
    if generated_text.startswith(prompt):
        # Dropping the known prompt is robust even when the user's own text
        # happens to contain a role marker.
        completion = generated_text[len(prompt):]
    else:
        # Fallback if the prompt is not echoed back verbatim: take what
        # follows the last assistant marker.
        completion = generated_text.split("<ASSISTANT>:")[-1]

    # The model may keep going and invent extra turns; keep only the text
    # before the first role marker so those never reach the user.
    for marker in ("<HUMAN>:", "<ASSISTANT>:"):
        completion = completion.partition(marker)[0]
    return completion.strip()


def get_model_response(message: str, history: list[list[str]]) -> str:
    """Generate the assistant's reply for the chat UI.

    Args:
        message: The user's newest message.
        history: Earlier ``(user_text, reply)`` turns, oldest first.

    Returns:
        The assistant's reply text for ``message``.
    """
    formatted_message = format_message(message, history)
    result = pipeline(
        formatted_message,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=peft_tokenizer.eos_token_id,
        # Budget the reply itself. A total-length cap (max_length) also counts
        # the prompt, so a long history could leave no room for an answer.
        max_new_tokens=256,
    )[0]
    generated_text = result["generated_text"]
    # NOTE(review): this writes the whole conversation to stdout; consider
    # removing it if the logs are not private.
    print(generated_text)
    return _extract_reply(generated_text, formatted_message)
# Build the chat UI around the response handler, then start serving it.
demo = gr.ChatInterface(fn=get_model_response)
demo.launch()