Spaces:
Runtime error
Runtime error
https://huggingface.co/spaces/microsoft/Promptist/discussions/1/files#d2h-120906
Browse files
app.py
CHANGED
@@ -14,9 +14,10 @@ prompter_model, prompter_tokenizer = load_prompter()
|
|
14 |
def generate(plain_text):
|
15 |
input_ids = prompter_tokenizer(plain_text.strip()+" Rephrase:", return_tensors="pt").input_ids
|
16 |
eos_id = prompter_tokenizer.eos_token_id
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
20 |
return res
|
21 |
|
22 |
txt = grad.Textbox(lines=1, label="Initial Text", placeholder="Input Prompt")
|
|
|
14 |
def generate(plain_text):
|
15 |
input_ids = prompter_tokenizer(plain_text.strip()+" Rephrase:", return_tensors="pt").input_ids
|
16 |
eos_id = prompter_tokenizer.eos_token_id
|
17 |
+
# Just use 1 beam and get 1 output, this is much, much, much faster than 8 beams and 8 outputs and we're only using the first.
|
18 |
+
outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
|
19 |
+
# Use [input_ids.shape[-1]:] because the decoded tokenised version of plain_text may have a different number of characters to the original
|
20 |
+
res = tokenizer.decode(outputs[0][input_ids.shape[-1]:])
|
21 |
return res
|
22 |
|
23 |
txt = grad.Textbox(lines=1, label="Initial Text", placeholder="Input Prompt")
|