imthanhlv commited on
Commit
3800c65
1 Parent(s): 4962857

added prompt debug

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -119,11 +119,11 @@ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
119
  with torch.no_grad():
120
  if embed is not None:
121
  generated = embed
122
- else:
123
- if tokens is None:
124
- tokens = torch.tensor(tokenizer.encode(prompt))
125
- tokens = tokens.unsqueeze(0).to(device)
126
- generated = model.gpt.transformer.wte(tokens)
127
  for i in range(entry_length):
128
  outputs = model.gpt(inputs_embeds=generated)
129
  logits = outputs.logits
 
119
  with torch.no_grad():
120
  if embed is not None:
121
  generated = embed
122
+ if prompt is not None:
123
+ tokens = torch.tensor(tokenizer.encode(prompt))
124
+ tokens = tokens.unsqueeze(0).to(device)
125
+ prompt_tokens = model.gpt.transformer.wte(tokens)
126
+ print(">>>>", generated.shape, prompt_tokens.shape)
127
  for i in range(entry_length):
128
  outputs = model.gpt(inputs_embeds=generated)
129
  logits = outputs.logits