Spaces:
Runtime error
Runtime error
fixed prompt tokens
Browse files
app.py
CHANGED
@@ -123,7 +123,7 @@ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
         tokens = torch.tensor(tokenizer.encode(prompt))
         tokens = tokens.unsqueeze(0).to(device)
         prompt_tokens = model.gpt.transformer.wte(tokens)
-
+        generated = torch.cat((generated, prompt_tokens), dim=1)

     for i in range(entry_length):
         outputs = model.gpt(inputs_embeds=generated)