hh2017 commited on
Commit
f25ecd6
1 Parent(s): f572f8e

Create web_quant.py

Browse files
Files changed (1) hide show
  1. web_quant.py +115 -0
web_quant.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+ from transformers import pipeline
3
+ import torch
4
+
5
+ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
6
+ from transformers import AutoTokenizer
7
+ from transformers import AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
8
+ import re
9
+ import argparse
10
+ import gradio as gr
11
+ from threading import Thread
12
+
13
+ def load_model(model_name):
14
+ model = AutoGPTQForCausalLM.from_quantized(model_name, device_map="auto")
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="right", use_fast=False)
16
+ return model, tokenizer
17
+
18
+ class StopOnTokens(StoppingCriteria):
19
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20
+ stop_ids = [2]
21
+ for stop_id in stop_ids:
22
+ if input_ids[0][-1] == stop_id:
23
+ return True
24
+ return False
25
+
26
+
27
+ def main(args):
28
+ model, tokenizer = load_model(args.model_name)
29
+ model = model.eval()
30
+
31
+ prompt_dict = {
32
+ 'AceGPT': """[INST] <<SYS>>\nأنت مساعد مفيد ومحترم وصادق. أجب دائما بأكبر قدر ممكن من المساعدة بينما تكون آمنا. يجب ألا تتضمن إجاباتك أي محتوى ضار أو غير أخلاقي أو عنصري أو جنسي أو سام أو خطير أو غير قانوني. يرجى التأكد من أن ردودك غير متحيزة اجتماعيا وإيجابية بطبيعتها.\n\nإذا كان السؤال لا معنى له أو لم يكن متماسكا من الناحية الواقعية، اشرح السبب بدلا من الإجابة على شيء غير صحيح. إذا كنت لا تعرف إجابة سؤال ما، فيرجى عدم مشاركة معلومات خاطئة.\n<</SYS>>\n\n""",
33
+ }
34
+
35
+
36
+ # all role
37
+ role_dict = {
38
+ 'AceGPT':['[INST]','[/INST]'],
39
+ }
40
+
41
+ # all start and end token
42
+ se_tok_dict = {
43
+ 'AceGPT':["","</s>"],
44
+ }
45
+
46
+
47
+ def format_message(query, history, max_src_len):
48
+ if not history:
49
+ return f"""{prompt_dict["AceGPT"]}{query} {role_dict["AceGPT"][1]}"""
50
+ else:
51
+ prompt = prompt_dict["AceGPT"]
52
+ filter_historys = []
53
+ memory_size = len(prompt) + len(query)
54
+ for rev_idx in range(len(history) - 1, -1, -1):
55
+ this_turn_len = len(history[rev_idx][0] + history[rev_idx][1])
56
+ if memory_size + this_turn_len > max_src_len:
57
+ break
58
+ filter_historys.append(history[rev_idx])
59
+ memory_size += this_turn_len
60
+ filter_historys.reverse()
61
+ for i, (old_query, response) in enumerate(filter_historys):
62
+ prompt += f'{old_query} {role_dict["AceGPT"][1]}{response}{se_tok_dict["AceGPT"][1]}{role_dict["AceGPT"][0]} '
63
+ prompt += f'{query} {role_dict["AceGPT"][1]}'
64
+ return prompt
65
+
66
+
67
+ def get_llama_response(message: str, history: list) -> str:
68
+ """
69
+ Generates a conversational response from the Llama model.
70
+
71
+ Parameters:
72
+ message (str): User's input message.
73
+ history (list): Past conversation history.
74
+
75
+ Returns:
76
+ str: Generated response from the Llama model.
77
+ """
78
+
79
+ temperature=0.5
80
+ max_new_tokens = 768
81
+ content_len = 2048
82
+ stop = StopOnTokens()
83
+ max_src_len = content_len-max_new_tokens-8
84
+ prompt = format_message(message, history, max_src_len)
85
+
86
+ model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
87
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
88
+ generate_kwargs = dict(
89
+ model_inputs,
90
+ streamer=streamer,
91
+ max_new_tokens=max_new_tokens,
92
+ do_sample=True,
93
+ top_p=0.95,
94
+ top_k=1000,
95
+ temperature=temperature,
96
+ num_beams=1,
97
+ stopping_criteria=StoppingCriteriaList([stop])
98
+ )
99
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
100
+ t.start()
101
+
102
+ partial_message = ''
103
+ for new_token in streamer:
104
+ if new_token != '</s>':
105
+ partial_message += new_token
106
+ yield partial_message
107
+
108
+
109
+ gr.ChatInterface(get_llama_response, chatbot=gr.Chatbot(rtl=True)).queue().launch(share=True)
110
+
111
+ if __name__ == '__main__':
112
+ parser = argparse.ArgumentParser()
113
+ parser.add_argument("--model-name", type=str, default="FreedomIntelligence/AceGPT-7B-chat-GPTQ")
114
+ args = parser.parse_args()
115
+ main(args)