coffeeee commited on
Commit
59fcbd6
1 Parent(s): 0f9549e

readded state

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -19,9 +19,12 @@ generation_config = GenerationConfig.from_pretrained('gpt2-medium')
19
  generation_config.max_new_tokens = response_length
20
  generation_config.pad_token_id = generation_config.eos_token_id
21
 
22
- def generate_response(story_so_far, new_prompt):
23
 
24
- truncated_story = story_so_far[:1024 - response_length]
 
 
 
 
25
 
26
  set_seed(random.randint(0, 4000000000))
27
  inputs = tokenizer.encode(story_so_far + '\n' + new_prompt if story_so_far else new_prompt,
@@ -29,9 +32,21 @@ def generate_response(story_so_far, new_prompt):
29
  max_length=1024 - response_length)
30
 
31
  output = model.generate(inputs, do_sample=True, generation_config=generation_config)
32
- response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(truncated_story) + 1) if truncated_story else 0):])
33
 
34
- return (story_so_far + '\n' if story_so_far else '') + response
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def clean_paragraph(entry):
37
  paragraphs = entry.split('\n')
@@ -44,9 +59,10 @@ def clean_paragraph(entry):
44
  return capitalize_first_char("\n".join(paragraphs))
45
 
46
  def reset():
47
- # global outputs
48
- # outputs = []
49
- return None
 
50
 
51
  def capitalize_first_char(entry):
52
  for i in range(len(entry)):
@@ -58,17 +74,19 @@ with gr.Blocks() as demo:
58
  story = gr.Textbox(interactive=False, lines=20)
59
  story.style(show_copy_button=True)
60
 
 
 
61
  prompt = gr.Textbox(placeholder="Continue the story here!", lines=3, max_lines=3)
62
 
63
  with gr.Row():
64
  gen_button = gr.Button('Generate')
65
- # undo_button = gr.Button("Undo")
66
  res_button = gr.Button("Reset")
67
 
68
- prompt.submit(generate_response, [story, prompt], story, scroll_to_output=True)
69
- gen_button.click(generate_response, [story, prompt], story, scroll_to_output=True)
70
- # undo_button.click(undo, [], story, scroll_to_output=True)
71
- res_button.click(reset, [], story, scroll_to_output=True)
72
 
73
  demo.launch(inbrowser=True)
74
 
 
19
  generation_config.max_new_tokens = response_length
20
  generation_config.pad_token_id = generation_config.eos_token_id
21
 
 
22
 
23
+
24
+
25
+ def generate_response(outputs, new_prompt):
26
+
27
+ story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)])
28
 
29
  set_seed(random.randint(0, 4000000000))
30
  inputs = tokenizer.encode(story_so_far + '\n' + new_prompt if story_so_far else new_prompt,
 
32
  max_length=1024 - response_length)
33
 
34
  output = model.generate(inputs, do_sample=True, generation_config=generation_config)
 
35
 
36
+ response = clean_paragraph(tokenizer.batch_decode(output)[0][((len(story_so_far) + 1) if story_so_far else 0):])
37
+ outputs.append(response)
38
+ return {
39
+ user_outputs: outputs,
40
+ story: ((story_so_far + '\n' if story_so_far else '') + response).replace('\n', '\n\n')
41
+ }
42
+
43
+ def undo(outputs):
44
+
45
+ outputs = outputs[:-1]
46
+ return {
47
+ user_outputs: outputs,
48
+ story: "\n".join(outputs)
49
+ }
50
 
51
  def clean_paragraph(entry):
52
  paragraphs = entry.split('\n')
 
59
  return capitalize_first_char("\n".join(paragraphs))
60
 
61
  def reset():
62
+ return {
63
+ user_outputs: None,
64
+ story: None
65
+ }
66
 
67
  def capitalize_first_char(entry):
68
  for i in range(len(entry)):
 
74
  story = gr.Textbox(interactive=False, lines=20)
75
  story.style(show_copy_button=True)
76
 
77
+ user_outputs = gr.State()
78
+
79
  prompt = gr.Textbox(placeholder="Continue the story here!", lines=3, max_lines=3)
80
 
81
  with gr.Row():
82
  gen_button = gr.Button('Generate')
83
+ undo_button = gr.Button("Undo")
84
  res_button = gr.Button("Reset")
85
 
86
+ prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
87
+ gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
88
+ undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
89
+ res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)
90
 
91
  demo.launch(inbrowser=True)
92