zetavg committed on
Commit 4870204
1 Parent(s): a5d7977

extract inference

llama_lora/lib/inference.py ADDED
@@ -0,0 +1,77 @@
+import torch
+import transformers
+
+from .streaming_generation_utils import Iteratorize, Stream
+
+
+def generate(
+    # model
+    model,
+    tokenizer,
+    # input
+    prompt,
+    generation_config,
+    max_new_tokens,
+    stopping_criteria=[],
+    # output options
+    stream_output=False
+):
+    device = get_device()
+
+    inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"].to(device)
+    generate_params = {
+        "input_ids": input_ids,
+        "generation_config": generation_config,
+        "return_dict_in_generate": True,
+        "output_scores": True,
+        "max_new_tokens": max_new_tokens,
+        "stopping_criteria": transformers.StoppingCriteriaList() + stopping_criteria
+    }
+
+    if stream_output:
+        # Stream the reply 1 token at a time.
+        # This is based on the trick of using 'stopping_criteria' to create an iterator,
+        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+        def generate_with_callback(callback=None, **kwargs):
+            kwargs["stopping_criteria"].insert(
+                0,
+                Stream(callback_func=callback)
+            )
+            with torch.no_grad():
+                model.generate(**kwargs)
+
+        def generate_with_streaming(**kwargs):
+            return Iteratorize(
+                generate_with_callback, kwargs, callback=None
+            )
+
+        with generate_with_streaming(**generate_params) as generator:
+            for output in generator:
+                decoded_output = tokenizer.decode(output, skip_special_tokens=True)
+                yield decoded_output, output
+                if output[-1] in [tokenizer.eos_token_id]:
+                    break
+        return  # early return for stream_output
+
+    # Without streaming
+    with torch.no_grad():
+        generation_output = model.generate(**generate_params)
+    output = generation_output.sequences[0]
+    decoded_output = tokenizer.decode(output, skip_special_tokens=True)
+    yield decoded_output, output
+    return
+
+
+def get_device():
+    if torch.cuda.is_available():
+        return "cuda"
+
+    try:
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:  # noqa: E722
+        pass
+
+    return "cpu"
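
Not part of the commit: a minimal sketch of how the extracted generate() helper can be driven outside the Gradio UI. The model name "gpt2" and the prompt are placeholders chosen only to make the example self-contained; the project itself loads its models through llama_lora.models.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from llama_lora.lib.inference import generate, get_device

# Placeholder model; any small causal LM is enough to exercise the API.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(get_device())

generation_config = GenerationConfig(do_sample=True, temperature=0.7, top_p=0.75)

# With stream_output=True, generate() yields a (decoded_output, output) pair
# after every new token, driven by the Iteratorize/Stream helpers.
for decoded_output, output in generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Hello, my name is",
    generation_config=generation_config,
    max_new_tokens=16,
    stream_output=True,
):
    print(decoded_output)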
llama_lora/{utils/callbacks.py → lib/streaming_generation_utils.py} RENAMED
File without changes
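
The renamed module's contents are not shown in the diff, but lib/inference.py depends on its Stream and Iteratorize helpers. For readers unfamiliar with the "stopping_criteria as iterator" trick it references, here is an illustrative sketch of the interface those helpers need to provide; it is a reimplementation written for this note, not the actual contents of streaming_generation_utils.py.

import queue
import threading

import transformers


class Stream(transformers.StoppingCriteria):
    # Never asks generation to stop; it only forwards each partial sequence
    # to a callback so another thread can observe tokens as they appear.
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    # Turns a callback-style function into an iterator: the wrapped function
    # runs in a background thread, and every value it passes to the callback
    # is pushed onto a queue that __next__ drains. Also usable as a context
    # manager, matching how lib/inference.py consumes it.
    def __init__(self, func, kwargs=None, callback=None):
        self.q = queue.Queue()
        self.sentinel = object()
        self.stop_now = False

        def _callback(value):
            if self.stop_now:
                # Raising here unwinds model.generate() in the worker thread.
                raise ValueError("Iteratorize: generation stopped")
            self.q.put(value)

        def _run():
            try:
                func(callback=_callback, **(kwargs or {}))
            except ValueError:
                pass
            finally:
                self.q.put(self.sentinel)

        self.thread = threading.Thread(target=_run, daemon=True)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        value = self.q.get()
        if value is self.sentinel:
            raise StopIteration
        return value

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop_now = True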
llama_lora/models.py CHANGED
@@ -60,9 +60,10 @@ def get_new_base_model(base_model_name):
         base_model_name, device_map={"": device}, low_cpu_mem_usage=True
     )
 
-    model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
-    model.config.bos_token_id = 1
-    model.config.eos_token_id = 2
+    tokenizer = get_tokenizer(base_model_name)
+    model.config.pad_token_id = tokenizer.pad_token_id = 0
+    model.config.bos_token_id = tokenizer.bos_token_id = 1
+    model.config.eos_token_id = tokenizer.eos_token_id = 2
 
     return model
 
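
The effect of the models.py change is that the special token IDs are now written to the tokenizer as well as to the model config, so code such as the EOS check in the new lib/inference.py (output[-1] in [tokenizer.eos_token_id]) sees consistent values. A hypothetical sanity check, using the function names from models.py and a placeholder base model name; it also assumes get_tokenizer returns a shared/cached tokenizer instance, which the code above already relies on.

from llama_lora.models import get_new_base_model, get_tokenizer

base_model_name = "decapoda-research/llama-7b-hf"  # placeholder, not from this commit

model = get_new_base_model(base_model_name)  # mutates the shared tokenizer
tokenizer = get_tokenizer(base_model_name)

# After this commit, the model config and the tokenizer agree on the special tokens.
assert model.config.pad_token_id == tokenizer.pad_token_id == 0
assert model.config.bos_token_id == tokenizer.bos_token_id == 1
assert model.config.eos_token_id == tokenizer.eos_token_id == 2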
llama_lora/ui/inference_ui.py CHANGED
@@ -8,12 +8,12 @@ from transformers import GenerationConfig
 
 from ..globals import Global
 from ..models import get_model, get_tokenizer, get_device
+from ..lib.inference import generate
 from ..utils.data import (
     get_available_template_names,
     get_available_lora_model_names,
     get_info_of_available_lora_model)
 from ..utils.prompter import Prompter
-from ..utils.callbacks import Iteratorize, Stream
 
 device = get_device()
 
@@ -103,8 +103,6 @@ def do_inference(
     tokenizer = get_tokenizer(base_model_name)
     model = get_model(base_model_name, lora_model_name)
 
-    inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids = inputs["input_ids"].to(device)
     generation_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,
@@ -113,26 +111,56 @@ def do_inference(
         num_beams=num_beams,
     )
 
-    generate_params = {
-        "input_ids": input_ids,
-        "generation_config": generation_config,
-        "return_dict_in_generate": True,
-        "output_scores": True,
-        "max_new_tokens": max_new_tokens,
-    }
-
     def ui_generation_stopping_criteria(input_ids, score, **kwargs):
         if Global.should_stop_generating:
             return True
         return False
 
     Global.should_stop_generating = False
-    generate_params.setdefault(
-        "stopping_criteria", transformers.StoppingCriteriaList()
-    )
-    generate_params["stopping_criteria"].append(
-        ui_generation_stopping_criteria
-    )
+
+    generation_args = {
+        'model': model,
+        'tokenizer': tokenizer,
+        'prompt': prompt,
+        'generation_config': generation_config,
+        'max_new_tokens': max_new_tokens,
+        'stopping_criteria': [ui_generation_stopping_criteria],
+        'stream_output': stream_output
+    }
+
+    for (decoded_output, output) in generate(**generation_args):
+        raw_output_str = None
+        if show_raw:
+            raw_output_str = str(output)
+        response = prompter.get_response(decoded_output)
+
+        if Global.should_stop_generating:
+            return
+
+        yield (
+            gr.Textbox.update(
+                value=response, lines=inference_output_lines),
+            raw_output_str)
+
+        if Global.should_stop_generating:
+            # If the user stops the generation and then clicks the
+            # generate button again, they may mysteriously land
+            # here, in the previous, should-be-stopped generation
+            # function call, without the new generation function
+            # being called at all. To work around this, we yield a
+            # message and set lines=1; if the front-end JS detects
+            # that lines has been set to 1 (rows="1" in HTML),
+            # it will automatically click the generate button again
+            # (gr.Textbox.update() does not support updating
+            # elem_classes or elem_id).
+            # [WORKAROUND-UI01]
+            yield (
+                gr.Textbox.update(
+                    value="Please retry", lines=1),
+                None)
+
+            return
+
 
     if stream_output:
         # Stream the reply 1 token at a time.