Darpan committed on
Commit
632b9b5
1 Parent(s): b5160f5

Add script for Chat demo

Files changed (1)
  1. app_chat.py +106 -0
app_chat.py ADDED
@@ -0,0 +1,106 @@
+ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+ from peft import PeftModel
+ import torch
+ import transformers
+ import gradio as gr
+ import time
+
+ MODEL = "decapoda-research/llama-7b-hf"
+ LORA_WEIGHTS = "tloen/alpaca-lora-7b"
+ device = "cpu"
+ print(f"Model device = {device}", flush=True)
+
+ def load_model():
+     # Load the base LLaMA model, then wrap it with the Alpaca LoRA adapter weights.
+     tokenizer = LlamaTokenizer.from_pretrained(MODEL)
+     model = LlamaForCausalLM.from_pretrained(MODEL, device_map={"": device}, low_cpu_mem_usage=True)
+     model = PeftModel.from_pretrained(model, LORA_WEIGHTS, device_map={"": device}, torch_dtype=torch.float16)
+     model.eval()
+
+     return model, tokenizer
+
+ def generate_prompt(input):
+     return f"""Below is a dialog, where User interacts with you - the AI.
+
+ ### Instruction: AI is helpful, kind, obedient, honest, and knows its own limits.
+
+ ### User: {input}
+
+ ### Response:
+ """
+
+ def eval_prompt(
+     model,
+     tokenizer,
+     input: str,
+     temperature = 0.7,
+     top_p = 0.75,
+     top_k = 40,
+     num_beams = 1,
+     max_new_tokens = 128,
+     **kwargs):
+
+     prompt = generate_prompt(input)
+     inputs = tokenizer(prompt, return_tensors = "pt")
+     input_ids = inputs["input_ids"]
+     generation_config = GenerationConfig(
+         temperature = temperature,
+         top_p = top_p,
+         top_k = top_k,
+         num_beams = num_beams,
+         repetition_penalty = 1.17,
+         **kwargs,
+     )
+
+     # with torch.inference_mode():
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids = input_ids,
+             generation_config = generation_config,
+             return_dict_in_generate = True,
+             output_scores = True,
+             max_new_tokens = max_new_tokens,
+         )
+     s = generation_output.sequences[0]
+     response = tokenizer.decode(s)
+     # Keep only the text that follows the "### Response:" marker in the prompt template.
+     bot_response = response.split("### Response:")[-1].strip()
+     print(f"Bot response: {bot_response}")
+     return bot_response
+
+ def run_app(model, tokenizer):
+     with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=True) as chat:
+         chatbot = gr.Chatbot(label = "Alpaca Demo")
+         msg = gr.Textbox(show_label = False, placeholder = "Enter your text here")
+         clear = gr.Button("Clear")
+         # Slider is displayed but not yet wired into eval_prompt.
+         temperature = gr.Slider(minimum=0, maximum=1, value=0.8, label="Temperature")
+
+         def user(user_msg, history):
+             # Append the new user message to the history; the bot reply is filled in later.
+             return "", history + [[user_msg, None]]
+
+         def bot(history):
+             print("Processing user input for Alpaca response...")
+             last_input = history[-1][0]
+             print(f"User input = {last_input}")
+
+             tick = time.time()
+             bot_response = eval_prompt(model, tokenizer, last_input)
+             print(f"Inference time = {time.time() - tick} seconds")
+
+             history[-1][1] = bot_response
+             print("Response generated and added to history.\n")
+             return history
+
+         msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+             bot, chatbot, chatbot
+         )
+
+         clear.click(lambda: None, None, chatbot, queue=False)
+
+     chat.queue()
+     chat.launch(share=True)
+
+
+ if __name__ == "__main__":
+     model, tokenizer = load_model()
+
+     # Run the actual gradio app
+     run_app(model, tokenizer)
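
A minimal sketch (not part of this commit) of how the script's helpers could be exercised without launching the Gradio UI, assuming app_chat.py is importable from the working directory:

from app_chat import load_model, eval_prompt

# Downloads/loads LLaMA-7B plus the Alpaca LoRA adapter on CPU, as in the script.
model, tokenizer = load_model()
reply = eval_prompt(model, tokenizer, "Explain what a LoRA adapter is in one sentence.")
print(reply)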