Spaces:
Runtime error
Runtime error
fixed global var wack stuff
Browse files
app.py
CHANGED
@@ -19,35 +19,19 @@ generation_config = GenerationConfig.from_pretrained('gpt2-medium')
|
|
19 |
generation_config.max_new_tokens = response_length
|
20 |
generation_config.pad_token_id = generation_config.eos_token_id
|
21 |
|
|
|
22 |
|
23 |
-
|
24 |
|
25 |
-
|
26 |
-
def generate_response(new_prompt):
|
27 |
-
print('a')
|
28 |
-
global outputs
|
29 |
-
print('b')
|
30 |
-
story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)])
|
31 |
-
print('c')
|
32 |
set_seed(random.randint(0, 4000000000))
|
33 |
inputs = tokenizer.encode(story_so_far + '\n' + new_prompt if story_so_far else new_prompt,
|
34 |
return_tensors='pt', truncation=True,
|
35 |
max_length=1024 - response_length)
|
36 |
-
|
37 |
output = model.generate(inputs, do_sample=True, generation_config=generation_config)
|
38 |
-
|
39 |
-
response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(story_so_far) + 1) if story_so_far else 0):])
|
40 |
-
print('f')
|
41 |
-
outputs.append(response)
|
42 |
-
print('g')
|
43 |
-
return ((story_so_far + '\n' if story_so_far else '') + response).replace('\n', '\n\n')
|
44 |
|
45 |
-
|
46 |
-
global outputs
|
47 |
-
print(outputs)
|
48 |
-
outputs = outputs[:-1]
|
49 |
-
print(outputs)
|
50 |
-
return "\n".join(outputs).replace('\n', '\n\n')
|
51 |
|
52 |
def clean_paragraph(entry):
|
53 |
paragraphs = entry.split('\n')
|
@@ -60,8 +44,8 @@ def clean_paragraph(entry):
|
|
60 |
return capitalize_first_char("\n".join(paragraphs))
|
61 |
|
62 |
def reset():
|
63 |
-
global outputs
|
64 |
-
outputs = []
|
65 |
return None
|
66 |
|
67 |
def capitalize_first_char(entry):
|
@@ -78,12 +62,12 @@ with gr.Blocks() as demo:
|
|
78 |
|
79 |
with gr.Row():
|
80 |
gen_button = gr.Button('Generate')
|
81 |
-
undo_button = gr.Button("Undo")
|
82 |
res_button = gr.Button("Reset")
|
83 |
|
84 |
prompt.submit(generate_response, prompt, story, scroll_to_output=True)
|
85 |
-
gen_button.click(generate_response, prompt, story, scroll_to_output=True)
|
86 |
-
undo_button.click(undo, [], story, scroll_to_output=True)
|
87 |
res_button.click(reset, [], story, scroll_to_output=True)
|
88 |
|
89 |
demo.launch(inbrowser=True)
|
|
|
19 |
generation_config.max_new_tokens = response_length
|
20 |
generation_config.pad_token_id = generation_config.eos_token_id
|
21 |
|
22 |
+
def generate_response(story_so_far, new_prompt):
|
23 |
|
24 |
+
truncated_story = story_so_far[:1024 - response_length]
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
set_seed(random.randint(0, 4000000000))
|
27 |
inputs = tokenizer.encode(story_so_far + '\n' + new_prompt if story_so_far else new_prompt,
|
28 |
return_tensors='pt', truncation=True,
|
29 |
max_length=1024 - response_length)
|
30 |
+
|
31 |
output = model.generate(inputs, do_sample=True, generation_config=generation_config)
|
32 |
+
response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(truncated_story) + 1) if truncated_story else 0):])
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
return ((story_so_far + '\n' if story_so_far else '') + response).replace('\n', '\n\n')
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def clean_paragraph(entry):
|
37 |
paragraphs = entry.split('\n')
|
|
|
44 |
return capitalize_first_char("\n".join(paragraphs))
|
45 |
|
46 |
def reset():
|
47 |
+
# global outputs
|
48 |
+
# outputs = []
|
49 |
return None
|
50 |
|
51 |
def capitalize_first_char(entry):
|
|
|
62 |
|
63 |
with gr.Row():
|
64 |
gen_button = gr.Button('Generate')
|
65 |
+
# undo_button = gr.Button("Undo")
|
66 |
res_button = gr.Button("Reset")
|
67 |
|
68 |
prompt.submit(generate_response, prompt, story, scroll_to_output=True)
|
69 |
+
gen_button.click(generate_response, [story, prompt], story, scroll_to_output=True)
|
70 |
+
# undo_button.click(undo, [], story, scroll_to_output=True)
|
71 |
res_button.click(reset, [], story, scroll_to_output=True)
|
72 |
|
73 |
demo.launch(inbrowser=True)
|