mikegarts committed
Commit d523bde
1 Parent(s): 16ed949

Update app.py

Files changed (1):
  1. app.py (+77 -11)

app.py CHANGED
@@ -1,4 +1,6 @@
+import time
 import os
+import PIL
 import gradio as gr
 
 import torch
@@ -18,7 +20,7 @@ if has_cuda:
     pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=READ_TOKEN)
     device = "cuda"
 else:
-    pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", use_auth_token=READ_TOKEN)
+    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
     device = "cpu"
 
 pipe.to(device)
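A note on the fix in this hunk: `revision="fp16"` selects half-precision weights, which only pay off on a CUDA device; many CPU ops have no float16 kernels, so the CPU path now loads the default fp32 weights. A minimal sketch of the pattern (the `model_id` value and the `HF_TOKEN` environment variable are illustrative assumptions, not part of the commit):

```python
import os
import torch
from diffusers import StableDiffusionPipeline

# Illustrative assumptions: any SD 1.x checkpoint, token taken from the environment.
model_id = "CompVis/stable-diffusion-v1-4"
token = os.environ.get("HF_TOKEN")

if torch.cuda.is_available():
    # Half-precision weights halve memory use and speed up inference on GPU.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=token
    )
    device = "cuda"
else:
    # Many CPU ops lack float16 kernels, so load the default fp32 weights instead.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=token)
    device = "cpu"

pipe.to(device)
```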
@@ -32,6 +34,8 @@ tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
 
 summarizer = pipeline("summarization")
 
+#######################################################
+
 def break_until_dot(txt):
     return txt.rsplit('.', 1)[0] + '.'
 
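For reference, `break_until_dot` is the small helper both the story and the summary pass through: it trims sampled text back to its last complete sentence so the output never ends mid-thought. A quick illustration (the sample string is invented):

```python
def break_until_dot(txt):
    # Keep everything up to the last period, dropping a trailing partial sentence.
    # If there is no period at all, the whole text is kept and one is appended.
    return txt.rsplit('.', 1)[0] + '.'

print(break_until_dot("Frodo took the ring. He turned toward the"))
# -> Frodo took the ring.
```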
@@ -41,7 +45,8 @@ def generate(prompt):
 
     outputs = model.generate(
         input_ids=input_ids,
-        max_length=120,
+        max_length=120,
+        min_length=50,
         temperature=0.7,
         num_return_sequences=3,
         do_sample=True
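The new `min_length=50` matters because with `do_sample=True` the model can emit an end-of-sequence token early; `min_length` suppresses EOS until at least 50 tokens exist, so stories do not stop after a single phrase. A self-contained sketch of the call, using stock `distilgpt2` as a stand-in for the fine-tuned `SAVED_CHECKPOINT` (the prompt and `pad_token_id` line are illustrative additions):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

input_ids = tokenizer("In the land of Mordor", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids=input_ids,
    max_length=120,   # hard cap on total tokens, prompt included
    min_length=50,    # suppress EOS until at least 50 tokens, so text isn't cut short
    temperature=0.7,  # <1.0 sharpens the distribution for more coherent text
    num_return_sequences=3,
    do_sample=True,   # sample instead of greedy decoding; required for temperature to matter
    pad_token_id=tokenizer.eos_token_id,  # assumption: silences GPT-2's missing-pad warning
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```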
@@ -49,10 +54,6 @@ def generate(prompt):
     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return break_until_dot(decoded)
 
-def generate_image(prompt, inference_steps):
-    prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
-    img = pipe(prompt, height=512, width=512, num_inference_steps=inference_steps)
-    return img.images[0]
 
 def generate_story(prompt):
     story = generate(prompt=prompt)
@@ -60,7 +61,53 @@ def generate_story(prompt):
     summary = break_until_dot(summary)
     return story, summary, gr.update(visible=True)
 
+def on_change_event(app_state):
+    print(f'on_change_event {app_state}')
+    if app_state and app_state['running'] and app_state['img']:
+        img = app_state['img']
+        step = app_state['step']
+        print(f'Updating the image:! {app_state}')
+        app_state['dots'] += 1
+        app_state['dots'] = app_state['dots'] % 10
+        message = app_state['status_msg'] + ' *' * app_state['dots']
+        print(f'message={message}')
+        return gr.update(value=app_state['img_list'], label='intermediate steps'), gr.update(value=message)
+    else:
+        return gr.update(label='images list'), gr.update(value='')
+
 with gr.Blocks() as demo:
+
+    def generate_image(prompt, inference_steps, app_state):
+        app_state['running'] = True
+        app_state['img_list'] = []
+        app_state['status_msg'] = 'Starting'
+        def callback(step, ts, latents):
+            app_state['status_msg'] = f'Reconstructing an image from the latent state on step {step}'
+            latents = 1 / 0.18215 * latents
+            res = pipe.vae.decode(latents).sample
+            res = (res / 2 + 0.5).clamp(0, 1)
+            res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
+            res = pipe.numpy_to_pil(res)[0]
+            app_state['img'] = res
+            app_state['step'] = step
+            app_state['img_list'].append(res)
+            app_state['status_msg'] = f'Generating step ({step + 1})'
+
+        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
+        img = pipe(prompt, height=512, width=512, num_inference_steps=inference_steps, callback=callback, callback_steps=1)
+        app_state['running'] = False
+        app_state['img'] = None
+        app_state['status_msg'] = ''
+        app_state['dots'] = 0
+        return gr.update(value=img.images[0], label='Generated image')
+
+    app_state = gr.State({'img': None,
+                          'step': 0,
+                          'running': False,
+                          'status_msg': '',
+                          'img_list': [],
+                          'dots': 0
+                          })
     title = gr.Markdown('## Lord of the rings app')
     description = gr.Markdown(f'#### A Lord of the rings inspired app that combines text and image generation.'
                               f' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy.'
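The heart of this hunk is the `callback` passed to the pipeline: on every denoising step, diffusers hands it the current latents, and scaling by `1 / 0.18215` (the SD 1.x VAE's latent scaling factor) followed by `pipe.vae.decode` turns them into a viewable intermediate image. A sketch of that decode path factored into a standalone helper, assuming the `pipe` loaded earlier; the prompt and filenames are illustrative, and note that `callback`/`callback_steps` is the hook in diffusers releases of this era (later versions replaced it with `callback_on_step_end`):

```python
# Decode a latent batch into a PIL image, as the commit's callback does.
# Assumes the SD 1.x `pipe` from above; 0.18215 is that VAE's scaling factor.
def latents_to_pil(pipe, latents):
    latents = 1 / 0.18215 * latents                # undo the encoder-side scaling
    image = pipe.vae.decode(latents).sample        # VAE decoder -> tensor in [-1, 1]
    image = (image / 2 + 0.5).clamp(0, 1)          # map to [0, 1]
    image = image.cpu().permute(0, 2, 3, 1).detach().numpy()  # NCHW -> NHWC numpy
    return pipe.numpy_to_pil(image)[0]             # first image in the batch

def preview_callback(step, timestep, latents):
    # Illustrative side effect: save each intermediate step to disk.
    latents_to_pil(pipe, latents).save(f"step_{step:03d}.png")

out = pipe("a charcoal sketch of Rivendell", height=512, width=512,
           num_inference_steps=20, callback=preview_callback, callback_steps=1)
out.images[0].save("final.png")
```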
@@ -72,13 +119,32 @@ with gr.Blocks() as demo:
     bt_make_text = gr.Button("Generate text")
     bt_make_image = gr.Button(f"Generate an image (takes about 10-15 minutes on CPU).", visible=False)
 
-    image = gr.Image(label='Illustration for your story', shape=(512, 512))
-    inference_steps = gr.Slider(5, 30, value=15, step=1, label=f"Num inference steps (more steps makes a better image but takes more time)")
+    img_description = gr.Markdown('Image generation takes some time'
+                                  ' but here you can see what is generated from the latent state of the diffuser every few steps.'
+                                  ' Usually there is a significant improvement around step 12 that yields a much better image')
+    status_msg = gr.Markdown()
+
+    gallery = gr.Gallery()
+    image = gr.Image(label='Illustration for your story', show_label=True)
+
+    gallery.style(grid=[4])
+
+    inference_steps = gr.Slider(5, 30,
+                                value=20,
+                                step=1,
+                                visible=True,
+                                label=f"Num inference steps (more steps yields a better image but takes more time)")
+
 
     bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
-    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps], outputs=image)
+    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)
+
+    eventslider = gr.Slider(visible=False)
+    dep = demo.load(on_change_event, app_state, [gallery, status_msg], every=5)
+    eventslider.change(fn=on_change_event, inputs=[app_state], outputs=[gallery, status_msg], every=5, cancels=[dep])
+
 
 if READ_TOKEN:
-    demo.launch()
+    demo.queue().launch()
 else:
-    demo.launch(share=True, debug=True)
+    demo.queue().launch(share=True, debug=True)
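The wiring at the bottom is what makes the live preview work: `demo.load(on_change_event, ..., every=5)` re-runs the poller every five seconds, and timed events are scheduled through Gradio's queue, which is why both launch calls become `demo.queue().launch(...)`. A minimal self-contained sketch of that polling pattern (the clock function is a made-up stand-in for `on_change_event`):

```python
import time
import gradio as gr

with gr.Blocks() as demo:
    clock = gr.Markdown()

    def tick():
        # Re-run every 2 seconds by the timed load event below.
        return f"server time: {time.strftime('%H:%M:%S')}"

    # Same mechanism the commit uses to refresh the gallery and status message.
    demo.load(tick, inputs=None, outputs=clock, every=2)

# Timed (`every=`) events run through the queue, hence queue().launch().
demo.queue().launch()
```

The hidden `eventslider` in the commit appears to exist only to own a `.change` event that, via `cancels=[dep]`, can stop and restart the recurring load job.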