File size: 3,724 Bytes
8988bbf
 
 
41bb1cf
8988bbf
 
 
 
 
 
 
 
 
 
 
 
2fa4e4c
 
 
 
 
 
 
 
 
8988bbf
 
 
 
 
e74047c
8988bbf
 
e74047c
8988bbf
 
 
 
 
 
 
 
 
 
 
 
 
e74047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8988bbf
 
 
 
 
 
 
 
e74047c
 
 
 
 
 
8988bbf
e74047c
8988bbf
 
 
e74047c
 
 
8988bbf
 
 
e74047c
8988bbf
 
e74047c
 
 
 
 
8988bbf
2fa4e4c
8988bbf
 
 
 
2fa4e4c
8988bbf
 
e74047c
8988bbf
e74047c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"Qwen/Qwen2-0.5B-Instruct"

from threading import Thread
from models.base_model import Simulator

from transformers import TextIteratorStreamer
from transformers import AutoModelForCausalLM, AutoTokenizer


class Qwen2Simulator(Simulator):
    """Use a Qwen2 chat model to simulate either side of a conversation.

    ``generate_query`` writes the next *user* turn and ``generate_response``
    writes the next *assistant* turn. Both stream decoded text chunks from a
    ``TextIteratorStreamer`` while ``model.generate`` runs in a background
    thread.

    The tokenizer and model are loaded lazily on first use, so constructing a
    simulator is cheap. Callers may also assign ``self.tokenizer`` and
    ``self.model`` themselves before generating; values already set are kept.
    """

    def __init__(self, model_name_or_path):
        """
        :param model_name_or_path: Hugging Face hub id or local directory passed to
            ``from_pretrained`` when the tokenizer/model are first needed.

        Note: when ``device_map`` is passed to ``from_pretrained``,
        ``low_cpu_mem_usage`` is automatically set to True.
        """
        self.model_name_or_path = model_name_or_path
        # Filled in on demand by ``_ensure_loaded``.
        self.tokenizer = None
        self.model = None
        self.generation_kwargs = dict(
            do_sample=True,
            temperature=0.7,
            # repetition_penalty=
            # transformers lets ``max_new_tokens`` override ``max_length`` when both
            # are set, so the effective limit here is 20 newly generated tokens.
            max_length=500,
            max_new_tokens=20,
        )

    def _ensure_loaded(self):
        """Load the tokenizer and model from ``model_name_or_path`` if they are not set yet."""
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        if self.model is None:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name_or_path,
                torch_dtype="auto",
                device_map="auto",
            )
            self.model.eval()

    def _stream(self, input_ids, streamer):
        """Run ``model.generate`` on ``input_ids`` in a thread and yield the streamer's text chunks."""
        # Build a fresh dict. ``dict.update`` returns None, so the former
        # ``dict(...).update(...)`` handed ``kwargs=None`` to the thread and
        # ``generate`` ran without ``input_ids`` or ``streamer``, meaning the
        # streamer never received any output.
        generation_kwargs = {
            **self.generation_kwargs,
            "input_ids": input_ids,
            "streamer": streamer,
        }
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        yield from streamer
        # The streamer is only exhausted once generation has signalled its end.
        thread.join()

    def generate_query(self, messages, stream=True):
        """Stream the next *user* turn for a conversation that ends with a non-user message.

        :param messages: chat history as a list of ``{"role": ..., "content": ...}``
            dicts; the last message must not have role ``"user"``.
        :param stream: currently unused; output is always streamed.
        :return: generator of decoded text chunks (prompt and special tokens excluded).
        """
        assert messages[-1]["role"] != "user", "last message must not be a user turn"
        self._ensure_loaded()
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Open a user turn by hand so the model writes the user's next message
        # rather than an assistant reply.
        prompt = prompt + "<|im_start|>user\n"
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)

        streamer = TextIteratorStreamer(tokenizer=self.tokenizer, skip_prompt=True, timeout=120.0,
                                        skip_special_tokens=True)
        yield from self._stream(input_ids, streamer)

    def generate_response(self, messages, stream=True):
        """Stream the next *assistant* turn for a conversation that ends with a user message.

        :param messages: chat history as a list of ``{"role": ..., "content": ...}``
            dicts; the last message must have role ``"user"``.
        :param stream: currently unused; output is always streamed.
        :return: generator of decoded text chunks (prompt and special tokens excluded).
        """
        assert messages[-1]["role"] == "user", "last message must be a user turn"
        self._ensure_loaded()
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)

        # Same settings as generate_query: emit only newly generated text, and
        # fail after 120 s without a chunk instead of blocking forever if the
        # generation thread dies.
        streamer = TextIteratorStreamer(tokenizer=self.tokenizer, skip_prompt=True, timeout=120.0,
                                        skip_special_tokens=True)
        yield from self._stream(input_ids, streamer)

    def _generate(self, input_ids):
        """Generate without streaming and return only the newly produced text."""
        self._ensure_loaded()
        input_ids_length = input_ids.shape[-1]
        response = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
        # Slice off the prompt tokens so only the continuation is decoded.
        return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)



# bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct")


if __name__ == "__main__":
    import sys

    # Model location: first CLI argument if given, otherwise the original local path.
    model_path = sys.argv[1] if len(sys.argv) > 1 else r"E:\data_model\Qwen2-0.5B-Instruct"
    bot = Qwen2Simulator(model_path)
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "hi, what your name"}
    ]
    streamer = bot.generate_response(messages)

    # To simulate the user side instead, end the history with an assistant turn:
    #     messages.append({"role": "assistant", "content": "My name is Jordan"})
    #     streamer = bot.generate_query(messages)
    print(list(streamer))