Crystalcareai committed
Commit c4e57ac
1 parent: cc66ab7

Update generate.py

Files changed (1)
generate.py: +17 -15
generate.py CHANGED
@@ -1,12 +1,7 @@
 import torch
-from transformers.generation.utils import (
-    GenerationMixin,
-    validate_stopping_criteria,
-    StoppingCriteriaList,
-)
+from transformers.generation.utils import GenerationMixin, validate_stopping_criteria, StoppingCriteriaList
 from transformers import TextStreamer
 
-
 def custom_generate(
     self,
     input_ids,
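
Note: the consolidated import keeps TextStreamer, which custom_generate later feeds one sampled id at a time via streamer.put. For reference, a typical TextStreamer hookup looks like the sketch below; the gpt2 checkpoint is only a placeholder, not something used by this repo.

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tok = AutoTokenizer.from_pretrained("gpt2")           # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
streamer = TextStreamer(tok, skip_special_tokens=True)
inputs = tok("Once upon a time", return_tensors="pt")
# generate() pushes ids into the streamer, which decodes and prints them incrementally
model.generate(**inputs, streamer=streamer, max_new_tokens=20)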
@@ -72,13 +67,14 @@ def custom_generate(
         last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
 
         new_ids_sampled = torch.multinomial(
-            torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
+            torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1
+        )
 
         # Assign the new id to the last token
         if last_token_idx + 1 >= len(base_answer_ids):
             # Add padding everywhere
             new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                     device=device)
+                                    device=device)
             input_ids = torch.cat([input_ids, new_padding], dim=-1)
             if attention_mask is not None:
                 attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
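
Note: the reflowed multinomial call above is standard temperature sampling; logits are divided by the temperature before the softmax, and one token id is drawn from the resulting distribution. A minimal self-contained sketch of the same pattern (the vocabulary size and tensor names are illustrative, not taken from this repo):

import torch

def sample_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # temperature < 1 sharpens the distribution, > 1 flattens it
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    # Draw a single token id from the categorical distribution
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(32000)  # fake vocabulary-sized logits
next_id = sample_with_temperature(logits, temperature=0.7)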
@@ -94,15 +90,20 @@ def custom_generate(
             # Check if the end token is generated
             if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
                 finished_generating[answer_idx] = 1
-
+
         if finished_generating.all():
             break
 
         if streamer is not None:
             streamer.put(new_ids_sampled)
 
-    return generated_token_ids
+    # Check if dynamic_temperature argument is present
+    if 'dynamic_temperature' in kwargs and kwargs['dynamic_temperature'] is not None:
+        # Convert generated token IDs to strings and return them
+        generated_text = self.tokenizer.batch_decode(generated_token_ids, skip_special_tokens=True)
+        return generated_text
 
+    return generated_token_ids
 
 
 def generate(
     self,
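
Note: the new dynamic_temperature branch returns decoded strings instead of raw token ids. batch_decode is the standard transformers tokenizer API for that; a small round-trip sketch (the gpt2 checkpoint is a placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
ids = tok("Hello world", return_tensors="pt").input_ids
# skip_special_tokens=True drops markers such as </s> from the output
print(tok.batch_decode(ids, skip_special_tokens=True))  # ['Hello world']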
@@ -153,10 +154,9 @@ def generate(
     torch_dtype=torch.bfloat16,
     **model_kwargs,
 ):
-
     if max_new_tokens is None:
-        max_new_tokens = 128
-
+        max_new_tokens = 128
+
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
@@ -190,7 +190,7 @@ def generate(
 
     generated_token_ids = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
         max_new_tokens=max_new_tokens,
         min_length=min_length,
@@ -225,4 +225,6 @@ def generate(
         **model_kwargs,
     )
 
-    return generated_token_ids
+    # Convert generated token IDs to strings and return them
+    generated_text = self.tokenizer.batch_decode(generated_token_ids, skip_special_tokens=True)
+    return generated_text
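
Note: after the final hunk, generate() always returns decoded text rather than token ids, so callers no longer need their own decode step. A hedged usage sketch; the setup line is an assumption for illustration, not an API from this commit:

# model, tokenizer = <load a model patched with the custom generate() above>  # assumption
enc = tokenizer("2 + 2 =", return_tensors="pt")
outputs = model.generate(
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    max_new_tokens=128,  # also the default now applied when None is passed
)
print(outputs[0])  # a decoded string, since generate() now ends in batch_decode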