"Qwen/Qwen2-0.5B-Instruct"
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from models.base_model import Simulator


class Qwen2Simulator(Simulator):
    def __init__(self, model_name_or_path):
        """
        When device_map is passed, low_cpu_mem_usage is automatically set to True.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype="auto",
            device_map="auto",
        )
        self.model.eval()
        self.generation_kwargs = dict(
            do_sample=True,
            temperature=0.7,
            # repetition_penalty=
            max_length=500,  # ignored when max_new_tokens is set
            max_new_tokens=20,
        )
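
    # For reference: Qwen2 renders chats in the ChatML layout, roughly
    #   <|im_start|>system\n...<|im_end|>\n
    #   <|im_start|>user\n...<|im_end|>\n
    #   <|im_start|>assistant\n
    # generate_query below exploits this by appending "<|im_start|>user\n",
    # prompting the model to write the *user's* next turn instead of the
    # assistant's.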
    def generate_query(self, messages, stream=True):
        """
        Generate the next *user* turn: render the chat template without a
        generation prompt, then append "<|im_start|>user\n" so the model
        continues in the user role.
        :param messages: the last message must not be a user turn
        :return: a generator of decoded text chunks
        """
        assert messages[-1]["role"] != "user"
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        inputs = inputs + "<|im_start|>user\n"
        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            timeout=120.0,
            skip_special_tokens=True,
        )
        # dict(...).update(...) returns None, which would crash generate();
        # merge the kwargs instead.
        stream_generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            **self.generation_kwargs,
        )
        # run generation in a background thread and stream tokens as they arrive
        thread = Thread(target=self.model.generate, kwargs=stream_generation_kwargs)
        thread.start()
        for new_text in streamer:
            print(new_text)
            yield new_text
        # return self._generate(input_ids)
    def generate_response(self, messages, stream=True):
        assert messages[-1]["role"] == "user"
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            # skip_prompt=True,
            # timeout=120.0,
            # skip_special_tokens=True
        )
        # dict(...).update(...) returns None, which would crash generate();
        # merge the kwargs instead.
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            **self.generation_kwargs,
        )
        print(generation_kwargs)
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        for new_text in streamer:
            print(new_text)
            yield new_text
    def _generate(self, input_ids):
        """Non-streaming generation: decode only the newly generated tokens."""
        input_ids_length = input_ids.shape[-1]
        response = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
        return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)
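
# Illustrative self-chat driver (not part of the original class). It assumes
# the streamers are configured with skip_prompt=True and
# skip_special_tokens=True (as in generate_query) so that joining the chunks
# yields clean message text; the function name and turn count are assumptions.
def simulate_turns(bot, messages, turns=2):
    """Alternately append generated assistant responses and user queries."""
    for _ in range(turns):
        # messages must end with a "user" turn here (generate_response asserts it)
        response = "".join(bot.generate_response(messages))
        messages.append({"role": "assistant", "content": response})
        # the last turn is now "assistant", which generate_query requires
        query = "".join(bot.generate_query(messages))
        messages.append({"role": "user", "content": query})
    return messages
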
# bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct")
if __name__ == "__main__":
    bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "hi, what's your name"}
    ]
    streamer = bot.generate_response(messages)
    # print(output)
    # messages = [
    #     {"role": "system", "content": "you are a helpful assistant"},
    #     {"role": "user", "content": "hi, what's your name"},
    #     {"role": "assistant", "content": "My name is Jordan"}
    # ]
    # streamer = bot.generate_query(messages)
    print(list(streamer))
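    # Hypothetical follow-up using the simulate_turns sketch above (the turn
    # count is arbitrary):
    # messages = simulate_turns(bot, messages, turns=1)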