zetavg committed
Commit 00263ef
1 Parent(s): 1a203ff

fix inference output

llama_lora/lib/inference.py CHANGED
@@ -4,7 +4,6 @@ import transformers
 from .get_device import get_device
 from .streaming_generation_utils import Iteratorize, Stream
 
-
 def generate(
     # model
     model,
@@ -30,18 +29,34 @@ def generate(
         "stopping_criteria": transformers.StoppingCriteriaList() + stopping_criteria
     }
 
+    skip_special_tokens = True
+
+    if '/dolly' in tokenizer.name_or_path:
+        # dolly has additional_special_tokens as ['### End', '### Instruction:', '### Response:'], skipping them will break the prompter's reply extraction.
+        skip_special_tokens = False
+        # Ensure generation stops once it generates "### End"
+        end_key_token_id = tokenizer.encode("### End")
+        end_key_token_id = end_key_token_id[0]  # 50277
+        if isinstance(generate_params['generation_config'].eos_token_id, str):
+            generate_params['generation_config'].eos_token_id = [generate_params['generation_config'].eos_token_id]
+        elif not generate_params['generation_config'].eos_token_id:
+            generate_params['generation_config'].eos_token_id = []
+        generate_params['generation_config'].eos_token_id.append(end_key_token_id)
+
     if stream_output:
         # Stream the reply 1 token at a time.
         # This is based on the trick of using 'stopping_criteria' to create an iterator,
         # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+        generation_output = None
 
         def generate_with_callback(callback=None, **kwargs):
+            nonlocal generation_output
            kwargs["stopping_criteria"].insert(
                 0,
                 Stream(callback_func=callback)
             )
             with torch.no_grad():
-                model.generate(**kwargs)
+                generation_output = model.generate(**kwargs)
 
         def generate_with_streaming(**kwargs):
             return Iteratorize(
@@ -50,16 +65,22 @@ def generate(
 
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
-                decoded_output = tokenizer.decode(output, skip_special_tokens=True)
+                decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
                 yield decoded_output, output
                 if output[-1] in [tokenizer.eos_token_id]:
                     break
+
+            if generation_output:
+                output = generation_output.sequences[0]
+                decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
+                yield decoded_output, output
+
         return # early return for stream_output
 
     # Without streaming
     with torch.no_grad():
         generation_output = model.generate(**generate_params)
     output = generation_output.sequences[0]
-    decoded_output = tokenizer.decode(output, skip_special_tokens=True)
+    decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
     yield decoded_output, output
     return
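In the hunk above, the dolly-specific branch keeps `skip_special_tokens=False` so the prompt markers survive decoding, and it appends the id of the "### End" marker to `generation_config.eos_token_id` so generation stops at that marker; the new `nonlocal generation_output` additionally lets the streaming branch yield the final, complete sequence once the iterator is exhausted. Below is a minimal, self-contained sketch of just the stop-token handling; the checkpoint name and `GenerationConfig` values are placeholders and are not taken from this repository, which instead mutates `generate_params['generation_config']` in place.

```python
from transformers import AutoTokenizer, GenerationConfig

# Placeholder checkpoint: any Dolly-style tokenizer that registers "### End"
# as an added token behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b")
generation_config = GenerationConfig(max_new_tokens=128)

# "### End" encodes to a single added-token id (50277 per the diff's comment).
end_key_token_id = tokenizer.encode("### End")[0]

# Normalize eos_token_id (None, a single id, or already a list) into a list,
# then append the extra stop token so generation also halts at "### End".
eos = generation_config.eos_token_id
if eos is None:
    eos = []
elif not isinstance(eos, list):
    eos = [eos]
if end_key_token_id not in eos:
    eos.append(end_key_token_id)
generation_config.eos_token_id = eos

# model.generate(..., generation_config=generation_config) will now stop on
# either the model's own EOS token or the "### End" marker.
```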
llama_lora/ui/inference_ui.py CHANGED
@@ -160,84 +160,6 @@ def do_inference(
                     None)
 
         return
-
-
-        if stream_output:
-            # Stream the reply 1 token at a time.
-            # This is based on the trick of using 'stopping_criteria' to create an iterator,
-            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
-
-            def generate_with_callback(callback=None, **kwargs):
-                kwargs.setdefault(
-                    "stopping_criteria", transformers.StoppingCriteriaList()
-                )
-                kwargs["stopping_criteria"].append(
-                    Stream(callback_func=callback)
-                )
-                with torch.no_grad():
-                    model.generate(**kwargs)
-
-            def generate_with_streaming(**kwargs):
-                return Iteratorize(
-                    generate_with_callback, kwargs, callback=None
-                )
-
-            with generate_with_streaming(**generate_params) as generator:
-                for output in generator:
-                    # new_tokens = len(output) - len(input_ids[0])
-                    decoded_output = tokenizer.decode(output)
-
-                    if output[-1] in [tokenizer.eos_token_id]:
-                        break
-
-                    raw_output = None
-                    if show_raw:
-                        raw_output = str(output)
-                    response = prompter.get_response(decoded_output)
-
-                    if Global.should_stop_generating:
-                        return
-
-                    yield (
-                        gr.Textbox.update(
-                            value=response, lines=inference_output_lines),
-                        raw_output)
-
-                    if Global.should_stop_generating:
-                        # If the user stops the generation, and then clicks the
-                        # generation button again, they may mysteriously landed
-                        # here, in the previous, should-be-stopped generation
-                        # function call, with the new generation function not be
-                        # called at all. To workaround this, we yield a message
-                        # and setting lines=1, and if the front-end JS detects
-                        # that lines has been set to 1 (rows="1" in HTML),
-                        # it will automatically click the generate button again
-                        # (gr.Textbox.update() does not support updating
-                        # elem_classes or elem_id).
-                        # [WORKAROUND-UI01]
-                        yield (
-                            gr.Textbox.update(
-                                value="Please retry", lines=1),
-                            None)
-            return  # early return for stream_output
-
-        # Without streaming
-        with torch.no_grad():
-            generation_output = model.generate(**generate_params)
-        s = generation_output.sequences[0]
-        output = tokenizer.decode(s)
-        raw_output = None
-        if show_raw:
-            raw_output = str(s)
-
-        response = prompter.get_response(output)
-        if Global.should_stop_generating:
-            return
-
-        yield (
-            gr.Textbox.update(value=response, lines=inference_output_lines),
-            raw_output)
-
     except Exception as e:
         raise gr.Error(e)
 
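This hunk only deletes the streaming path that now lives in `llama_lora/lib/inference.py`, leaving `do_inference` to consume the shared generator, which yields `(decoded_output, output)` pairs. A rough sketch of that consumer pattern is shown below; the function and parameter names are illustrative rather than the repository's actual API, and only `prompter.get_response`, `gr.Textbox.update`, and the shape of the yielded pair are taken from the diffs above.

```python
import gradio as gr


def stream_to_textbox(generate, prompter, generation_args,
                      show_raw=False, inference_output_lines=12):
    """Sketch of a UI-side consumer of the shared generate() generator."""
    for decoded_output, output in generate(**generation_args):
        # The generator already yields decoded text, so the UI only needs to
        # extract the reply portion and update the output textbox.
        raw_output = str(output) if show_raw else None
        response = prompter.get_response(decoded_output)
        yield (
            gr.Textbox.update(value=response, lines=inference_output_lines),
            raw_output,
        )
```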
 
llama_lora/utils/prompter.py CHANGED
@@ -131,8 +131,13 @@ class Prompter(object):
     def get_response(self, output: str) -> str:
         if self.template_name == "None":
             return output
+
+        splitted_output = output.split(self.template["response_split"])
+        # if len(splitted_output) <= 1:
+        #     return output.strip()
+
         return self.template["response_split"].join(
-            output.split(self.template["response_split"])[1:]
+            splitted_output[1:]
         ).strip()
 
     def get_variable_names(self) -> List[str]:
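The `get_response` change splits the decoded output on the template's `response_split` marker and rejoins everything after the first occurrence. The small illustration below (the template dict and strings are made up for the example; only the split/join logic mirrors the code above) shows why the marker has to survive decoding, which is also why the inference change keeps `skip_special_tokens=False` for dolly: if the tokenizer strips the marker as a special token, `split()` returns a single element and the extracted reply collapses to an empty string.

```python
template = {"response_split": "### Response:"}


def get_response(output: str) -> str:
    # Mirrors Prompter.get_response: keep everything after the first marker.
    splitted_output = output.split(template["response_split"])
    return template["response_split"].join(splitted_output[1:]).strip()


# Marker present: the reply is extracted as expected.
decoded = "### Instruction:\nSay hi\n### Response:\nHello!"
print(get_response(decoded))   # -> "Hello!"

# Marker stripped during decoding (e.g. skip_special_tokens=True with a
# tokenizer that treats "### Response:" as a special token): nothing follows
# a marker that no longer exists, so the extracted reply is empty.
stripped = "\nSay hi\n\nHello!"
print(get_response(stripped))  # -> ""
```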