lopho committed
Commit
07b5d00
1 Parent(s): c122e26

saner defaults, more input sanitization, shorter queue
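The sanitization here swaps bare `int()` casts for clamped casts, so out-of-range form values are coerced into the supported range instead of being handed to the model. A minimal sketch of the pattern (the `clamp_int` helper is illustrative, not part of the diff):

```python
def clamp_int(value, lo: int, hi: int) -> int:
    # Coerce to int first (UI widgets may deliver floats or strings),
    # then clamp into the closed interval [lo, hi].
    return min(hi, max(lo, int(value)))

# Mirrors the bounds introduced in app.py's generate():
num_frames = clamp_int(24, 2, 24)              # -> 24
inference_steps = clamp_int(999, 2, 60)        # -> 60, silently capped
width = (clamp_int(600, 256, 576) // 64) * 64  # snapped to a multiple of 64 -> 576
```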

README.md CHANGED
@@ -19,5 +19,3 @@ models:
 tags:
 - jax-diffusers-event
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -33,7 +33,7 @@ if _model.failed != False:
 
 _examples = []
 _expath = 'examples'
-for x in os.listdir(_expath):
+for x in sorted(os.listdir(_expath)):
     with open(os.path.join(_expath, x, 'params.json'), 'r') as f:
         ex = json.load(f)
         ex['image_input'] = None
@@ -56,22 +56,23 @@ def generate(
     cfg = 15.0,
     cfg_image = 9.0,
     seed = 0,
-    fps = 24,
+    fps = 12,
     num_frames = 24,
     height = 512,
     width = 512,
-    scheduler_type = 'DPM',
-    output_format = 'webp'
+    scheduler_type = 'dpm',
+    output_format = 'gif'
 ) -> str:
-    num_frames = int(num_frames)
-    inference_steps = int(inference_steps)
-    height = int(height)
-    width = int(width)
+    num_frames = min(24, max(2, int(num_frames)))
+    inference_steps = min(60, max(2, int(inference_steps)))
+    height = min(576, max(256, int(height)))
+    width = min(576, max(256, int(width)))
     height = (height // 64) * 64
     width = (width // 64) * 64
     cfg = max(cfg, 1.0)
     cfg_image = max(cfg_image, 1.0)
-    seed = int(seed)
+    fps = min(1000, max(1, int(fps)))
+    seed = min(2**32-2, int(seed))
     if seed < 0:
         seed = -seed
     if hint_image is not None:
@@ -79,11 +80,12 @@ def generate(
         hint_image = hint_image.convert('RGB')
         if hint_image.size != (width, height):
             hint_image = ImageOps.fit(hint_image, (width, height), method = Image.Resampling.LANCZOS)
+    scheduler_type = scheduler_type.lower()
     if scheduler_type not in SCHEDULERS:
-        scheduler_type = 'DPM'
+        scheduler_type = 'dpm'
     output_format = output_format.lower()
     if output_format not in _output_formats:
-        output_format = 'webp'
+        output_format = 'gif'
     mask_image = None
     images = _model.generate(
         prompt = [prompt] * _model.device_count,
@@ -100,26 +102,24 @@ def generate(
         scheduler_type = scheduler_type
     )
     _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames))
-    buffer = BytesIO()
-    images[1].save(
-        buffer,
-        format = output_format,
-        save_all = True,
-        append_images = images[2:],
-        loop = 0,
-        duration = round(1000 / fps),
-        allow_mixed = True
-    )
-    data = f'data:image/{output_format};base64,' + base64.b64encode(buffer.getvalue()).decode()
-    buffer.close()
-    buffer = BytesIO()
-    images[-1].save(buffer, format ='png')
-    last_data = f'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()
-    buffer.close()
-    buffer = BytesIO()
-    images[0].save(buffer, format ='png')
-    first_data = f'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()
-    buffer.close()
+    with BytesIO() as buffer:
+        images[1].save(
+            buffer,
+            format = output_format,
+            save_all = True,
+            append_images = images[2:],
+            loop = 0,
+            duration = round(1000 / fps),
+            allow_mixed = True,
+            optimize = True
+        )
+        data = f'data:image/{output_format};base64,' + base64.b64encode(buffer.getvalue()).decode()
+    with BytesIO() as buffer:
+        images[-1].save(buffer, format = 'png', optimize = True)
+        last_data = f'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()
+    with BytesIO() as buffer:
+        images[0].save(buffer, format ='png', optimize = True)
+        first_data = f'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()
     return data, last_data, first_data
 
 def check_if_compiled(hint_image, inference_steps, height, width, num_frames, scheduler_type, message):
@@ -140,11 +140,11 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     intro1 = gr.Markdown("""
    # Make-A-Video Stable Diffusion JAX

-    We have extended a pretrained LDM inpainting image generation model with temporal convolutions and attention.
-    By taking advantage of the extra 5 input channels of the inpaint model, we guide the video generation with a hint image.
+    We have extended a pretrained latent-diffusion inpainting image generation model with **temporal convolutions and attention**.
+    We guide the video generation with a hint image by taking advantage of the extra 5 input channels of the inpainting model.
    In this demo the hint image can be given by the user, otherwise it is generated by an generative image model.

-    The temporal layers are a port of [Make-A-Video PyTorch](https://github.com/lucidrains/make-a-video-pytorch) to FLAX.
+    The temporal layers are a port of [Make-A-Video PyTorch](https://github.com/lucidrains/make-a-video-pytorch) to [JAX](https://github.com/google/jax) utilizing [FLAX](https://github.com/google/flax).
    The convolution is pseudo 3D and seperately convolves accross the spatial dimension in 2D and over the temporal dimension in 1D.
    Temporal attention is purely self attention and also separately attends to time.

@@ -160,7 +160,7 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
    **Please be patient. The model might have to compile with current parameters.**

    This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.
-    The compilation will be cached and consecutive runs with the same parameters
+    The compilation will be cached and later runs with the same parameters
    will be much faster.

    Changes to the following parameters require the model to compile
@@ -170,7 +170,9 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
    - Input image vs. no input image
    - Noise scheduler type

-    If you encounter any issues, please report them here: [Space discussions](https://huggingface.co/spaces/TempoFunk/makeavid-sd-jax/discussions)
+    If you encounter any issues, please report them here: [Space discussions](https://huggingface.co/spaces/TempoFunk/makeavid-sd-jax/discussions) (or DM [@lopho](https://twitter.com/lopho))
+
+    <small>Leave a ❤️ like if you like. Consider it a dopamine donation at no cost.</small>
    """)
 
     with gr.Row(variant = variant):
@@ -221,14 +223,14 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     inference_steps_input = gr.Slider(
         label = 'Steps',
         minimum = 2,
-        maximum = 100,
+        maximum = 60,
         value = 20,
         step = 1,
         interactive = True
     )
     num_frames_input = gr.Slider(
         label = 'Number of frames to generate',
-        minimum = 1,
+        minimum = 2,
         maximum = 24,
         step = 1,
         value = 24,
@@ -236,7 +238,7 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     )
     width_input = gr.Slider(
         label = 'Width',
-        minimum = 64,
+        minimum = 256,
         maximum = 576,
         step = 64,
         value = 512,
@@ -244,7 +246,7 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     )
     height_input = gr.Slider(
         label = 'Height',
-        minimum = 64,
+        minimum = 256,
         maximum = 576,
         step = 64,
         value = 512,
@@ -253,7 +255,7 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     scheduler_input = gr.Dropdown(
         label = 'Noise scheduler',
         choices = list(SCHEDULERS.keys()),
-        value = 'DPM',
+        value = 'dpm',
         interactive = True
     )
     with gr.Row():
@@ -279,32 +281,33 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
         value = 'example.gif',
         interactive = False
     )
-    tips = gr.Markdown('🤫 *Secret tip*: take the last frame as input for the next generation.')
+    tips = gr.Markdown('🤫 *Secret tip*: try using the last frame as input for the next generation.')
     with gr.Row():
         last_frame_output = gr.Image(
             label = 'Last frame',
             interactive = False
         )
         first_frame_output = gr.Image(
-            label = 'First frame',
+            label = 'Initial frame',
             interactive = False
         )
     examples_lst = []
     for x in _examples:
         examples_lst.append([
-            x['image_output'],
-            x['prompt'],
-            x['neg_prompt'],
-            x['image_input'],
-            x['cfg'],
-            x['cfg_image'],
-            x['seed'],
-            x['fps'],
-            x['num_frames'],
-            x['height'],
-            x['width'],
-            x['scheduler'],
-            x['format']
+            x['image_output'],
+            x['prompt'],
+            x['neg_prompt'],
+            x['image_input'],
+            x['cfg'],
+            x['cfg_image'],
+            x['seed'],
+            x['fps'],
+            x['steps'],
+            x['scheduler'],
+            x['num_frames'],
+            x['height'],
+            x['width'],
+            x['format']
         ])
     examples = gr.Examples(
         examples = examples_lst,
@@ -317,10 +320,11 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
         cfg_image_input,
         seed_input,
         fps_input,
+        inference_steps_input,
+        scheduler_input,
         num_frames_input,
         height_input,
         width_input,
-        scheduler_input,
         output_format
     ],
     postprocess = False
@@ -355,6 +359,6 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     )
     #cancel_button.click(fn = lambda: None, cancels = ev)
 
-demo.queue(concurrency_count = 1, max_size = 10)
+demo.queue(concurrency_count = 1, max_size = 8)
 demo.launch()
 
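Two details of the reworked app.py are worth calling out. The frame-encoding paths now use `BytesIO` as a context manager, which closes the buffer on exit, so `getvalue()` has to be read inside the `with` block, exactly as the new code does. The per-frame `duration` is `round(1000 / fps)` milliseconds, which the new `fps` clamp of [1, 1000] keeps at 1 ms or more. A self-contained sketch of the same data-URI pattern (the function name is illustrative, not from the diff):

```python
import base64
from io import BytesIO
from PIL import Image

def to_data_uri(image: Image.Image, fmt: str = 'png') -> str:
    # BytesIO.__exit__ closes the buffer and getvalue() raises
    # ValueError afterwards, so read the bytes inside the block.
    with BytesIO() as buffer:
        image.save(buffer, format = fmt)
        data = buffer.getvalue()
    return f'data:image/{fmt};base64,' + base64.b64encode(data).decode()

# Embeds a frame directly into HTML/Gradio output without touching disk:
# to_data_uri(frames[0]) -> 'data:image/png;base64,iVBORw0...'
```

The shorter queue (`max_size` 10 -> 8) simply bounds how many requests can pile up behind the single worker (`concurrency_count = 1`).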
example.webp DELETED

Git LFS Details

  • SHA256: ffd7cb93989a8e311395799f6d6e566e698ad7654f9f5a471196d8c781f46c1f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
examples/example_04_furry_moster/params.json CHANGED
@@ -8,7 +8,7 @@
   "width": 512,
   "height": 512,
   "scheduler": "dpm",
-  "fps": 20,
+  "fps": 12,
   "format": "gif",
   "num_frames": 24
 }
examples/example_06_sophie/params.json CHANGED
@@ -3,7 +3,7 @@
   "neg_prompt": "",
   "cfg": 15,
   "cfg_image": 9,
-  "seed": 1,
+  "seed": 0,
   "steps": 20,
   "width": 512,
   "height": 512,
makeavid_sd/inference.py CHANGED
@@ -45,8 +45,8 @@ SchedulerStateType = Union[
 ]
 
 SCHEDULERS: Dict[str, SchedulerType] = {
-    'DPM': FlaxDPMSolverMultistepScheduler, # husbando
-    'DDIM': FlaxDDIMScheduler,
+    'dpm': FlaxDPMSolverMultistepScheduler, # husbando
+    'ddim': FlaxDDIMScheduler,
     #'PLMS': FlaxPNDMScheduler, # its not correctly implemented in diffusers, output is bad, but at least it "works"
     #'LMS': FlaxLMSDiscreteScheduler, # borked
     # image_latents, image_scheduler_state = scheduler.step(
@@ -224,8 +224,8 @@ class InferenceUNetPseudo3D:
         return tokens, neg_tokens, hint, mask
 
     def generate(self,
-        prompt: Union[str, List[str]],
-        inference_steps: int,
+        prompt: Union[str, List[str]] = '',
+        inference_steps: int = 20,
         hint_image: Union[Image.Image, List[Image.Image], None] = None,
         mask_image: Union[Image.Image, List[Image.Image], None] = None,
         neg_prompt: Union[str, List[str]] = '',
@@ -235,7 +235,7 @@ class InferenceUNetPseudo3D:
         width: int = 512,
         height: int = 512,
         seed: int = 0,
-        scheduler_type: str = 'DDIM'
+        scheduler_type: str = 'dpm'
     ) -> List[List[Image.Image]]:
         assert inference_steps > 0, f'number of inference steps must be > 0 but is {inference_steps}'
         assert num_frames > 0, f'number of frames must be > 0 but is {num_frames}'
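With the registry keys lowercased, app.py can normalize arbitrary user input via `.lower()` and fall back to `'dpm'` for unknown names, making scheduler selection case-insensitive end to end. A minimal sketch of that lookup, with stub classes standing in for the diffusers schedulers imported by the real module:

```python
from typing import Dict

class DPMStub: ...   # stands in for FlaxDPMSolverMultistepScheduler
class DDIMStub: ...  # stands in for FlaxDDIMScheduler

SCHEDULERS: Dict[str, type] = {
    'dpm': DPMStub,
    'ddim': DDIMStub,
}

def resolve_scheduler(name: str, default: str = 'dpm') -> type:
    # Normalize case, then fall back to the default key
    # instead of raising on unknown input.
    key = name.lower()
    if key not in SCHEDULERS:
        key = default
    return SCHEDULERS[key]

assert resolve_scheduler('DPM') is DPMStub
assert resolve_scheduler('nonsense') is DPMStub
```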