ironjr commited on
Commit
eeac019
1 Parent(s): 49b7092

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -146,6 +146,7 @@ opt.colors = [
146
  # '#92C6EC',
147
  # '#FECAC0',
148
  ]
 
149
 
150
 
151
  ### Event handlers
@@ -294,7 +295,7 @@ def import_state(state, json_text):
294
  current_palette = state.current_palette
295
  # active_palettes = state.active_palettes
296
  state_dict = json.loads(json_text)
297
- for k in ('inpainting_mode', 'is_runing', 'active_palettes', 'current_palette'):
298
  if k in state_dict:
299
  del state_dict[k]
300
  state = argparse.Namespace(**state_dict)
@@ -362,15 +363,15 @@ def register(state, drawpad):
362
  # prompts, negative_prompts = preprocess_prompts(
363
  # prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
364
 
365
- model.update_background(
366
  background.convert('RGB'),
367
  prompt=None,
368
  negative_prompt=None,
369
  )
370
- state.prompts[0] = model.background.prompt
371
- state.neg_prompts[0] = model.background.negative_prompt
372
 
373
- model.update_layers(
374
  prompts=prompts,
375
  negative_prompts=negative_prompts,
376
  masks=masks.to(device),
@@ -384,10 +385,10 @@ def register(state, drawpad):
384
 
385
  @spaces.GPU(duration=120)
386
  def run(state, drawpad):
387
- model.device = torch.device('cuda')
388
- model.reset_seed(model.generator, opt.seed)
389
- model.reset_latent()
390
- model.prepare()
391
 
392
  state = register(state, drawpad)
393
  state.is_running = True
@@ -438,7 +439,7 @@ def draw(state, drawpad):
438
  # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
439
 
440
  for i in range(len(has_masks)):
441
- model.update_single_layer(
442
  idx=i,
443
  mask=masks[i],
444
  mask_strength=mask_strengths[i],
@@ -516,10 +517,6 @@ css = f"""
516
  width: 100%;
517
  aspect-ratio: {opt.width} / {opt.height};
518
  }}
519
-
520
- .layer-wrap {{
521
- display: none;
522
- }}
523
  """
524
 
525
  for i in range(opt.max_palettes + 1):
@@ -604,6 +601,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, head=head) as demo:
604
  state.model_id = opt.model
605
  state.style_name = '(None)'
606
  state.quality_name = 'Standard v3.1'
 
607
 
608
  # State variables (one-hot).
609
  state.active_palettes = 5
@@ -1142,7 +1140,7 @@ async () => {{
1142
  # api_name='quality_select',
1143
  # )
1144
 
1145
- iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
1146
  iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
1147
  state,
1148
  *iface.btn_semantics,
 
146
  # '#92C6EC',
147
  # '#FECAC0',
148
  ]
149
+ opt.excluded_keys = ['inpainting_mode', 'is_runing', 'active_palettes', 'current_palette', 'model']
150
 
151
 
152
  ### Event handlers
 
295
  current_palette = state.current_palette
296
  # active_palettes = state.active_palettes
297
  state_dict = json.loads(json_text)
298
+ for k in opt.excluded_keys:
299
  if k in state_dict:
300
  del state_dict[k]
301
  state = argparse.Namespace(**state_dict)
 
363
  # prompts, negative_prompts = preprocess_prompts(
364
  # prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
365
 
366
+ state.model.update_background(
367
  background.convert('RGB'),
368
  prompt=None,
369
  negative_prompt=None,
370
  )
371
+ state.prompts[0] = state.model.background.prompt
372
+ state.neg_prompts[0] = state.model.background.negative_prompt
373
 
374
+ state.model.update_layers(
375
  prompts=prompts,
376
  negative_prompts=negative_prompts,
377
  masks=masks.to(device),
 
385
 
386
  @spaces.GPU(duration=120)
387
  def run(state, drawpad):
388
+ state.model.device = torch.device('cuda')
389
+ state.model.reset_seed(state.model.generator, opt.seed)
390
+ state.model.reset_latent()
391
+ state.model.prepare()
392
 
393
  state = register(state, drawpad)
394
  state.is_running = True
 
439
  # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
440
 
441
  for i in range(len(has_masks)):
442
+ state.model.update_single_layer(
443
  idx=i,
444
  mask=masks[i],
445
  mask_strength=mask_strengths[i],
 
517
  width: 100%;
518
  aspect-ratio: {opt.width} / {opt.height};
519
  }}
 
 
 
 
520
  """
521
 
522
  for i in range(opt.max_palettes + 1):
 
601
  state.model_id = opt.model
602
  state.style_name = '(None)'
603
  state.quality_name = 'Standard v3.1'
604
+ state.model = model
605
 
606
  # State variables (one-hot).
607
  state.active_palettes = 5
 
1140
  # api_name='quality_select',
1141
  # )
1142
 
1143
+ iface.btn_export_state.click(lambda x: {k: v for k, v in vars(x).items() if k not in opt.excluded_keys}, state, iface.json_state_export)
1144
  iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
1145
  state,
1146
  *iface.btn_semantics,