wmpscc commited on
Commit
45d104f
1 Parent(s): 7da9532

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +5 -5
generate.py CHANGED
@@ -86,10 +86,10 @@ class LmGeneration:
86
  total_len = args.seq_length
87
 
88
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
- tokens = torch.full((batch, total_len), self.tokenizer.pad_token_id).to(device).long()
90
  for idx, t in enumerate(prompt_tokens):
91
  tokens[idx, : len(t)] = torch.tensor(t).long()
92
- mask = tokens != self.tokenizer.pad_token_id
93
  start_pos = min_prompt_len
94
  prev_pos = 0
95
  continue_exsample = [i for i in range(batch)]
@@ -118,7 +118,7 @@ class LmGeneration:
118
  continue_exsample = []
119
  for i, t in enumerate(tokens.tolist()):
120
  try:
121
- t.index(self.tokenizer.eos_token_id)
122
  except ValueError:
123
  if cut_off is not None:
124
  if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
@@ -134,8 +134,8 @@ class LmGeneration:
134
  for i, t in enumerate(tokens.tolist()):
135
  t = t[: args.seq_length]
136
  try:
137
- t = t[: t.index(self.tokenizer.pad_token_id)]
138
- t = t[: t.index(self.tokenizer.eos_token_id)]
139
  except ValueError:
140
  pass
141
  decoder.append(self.tokenizer.decode(t))
 
86
  total_len = args.seq_length
87
 
88
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
+ tokens = torch.full((batch, total_len), self.tokenizer.pad_token).to(device).long()
90
  for idx, t in enumerate(prompt_tokens):
91
  tokens[idx, : len(t)] = torch.tensor(t).long()
92
+ mask = tokens != self.tokenizer.pad_token
93
  start_pos = min_prompt_len
94
  prev_pos = 0
95
  continue_exsample = [i for i in range(batch)]
 
118
  continue_exsample = []
119
  for i, t in enumerate(tokens.tolist()):
120
  try:
121
+ t.index(self.tokenizer.eos_token)
122
  except ValueError:
123
  if cut_off is not None:
124
  if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
 
134
  for i, t in enumerate(tokens.tolist()):
135
  t = t[: args.seq_length]
136
  try:
137
+ t = t[: t.index(self.tokenizer.pad_token)]
138
+ t = t[: t.index(self.tokenizer.eos_token)]
139
  except ValueError:
140
  pass
141
  decoder.append(self.tokenizer.decode(t))