ironjr commited on
Commit
919599a
1 Parent(s): da56a9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -28
app.py CHANGED
@@ -389,11 +389,11 @@ def register(state, drawpad, model):
389
  return state
390
 
391
 
392
- @spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
393
  def run(state, drawpad):
394
  # ZeroGPU hack.
395
- listener = Listener(opt.address, authkey=opt.authkey)
396
- conn = listener.accept()
397
 
398
  # Reset model.
399
  model.device = torch.device('cuda')
@@ -407,16 +407,16 @@ def run(state, drawpad):
407
  tic = time.time()
408
  while True:
409
  # Receive real-time mask inputs from the main process.
410
- data = conn.recv()
411
- if data is not None:
412
- print('Received data!!!')
413
- for i in range(opt.max_palettes):
414
- model.update_single_layer(
415
- idx=i,
416
- mask=data['masks'][i],
417
- mask_strength=data['mask_strengths'][i],
418
- mask_std=data['mask_stds'][i],
419
- )
420
 
421
  yield [state, model()]
422
  toc = time.time()
@@ -441,7 +441,7 @@ def draw(state, drawpad):
441
  # return
442
 
443
  # ZeroGPU hack.
444
- conn = Client(opt.address, authkey=opt.authkey)
445
 
446
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
447
  foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
@@ -466,20 +466,20 @@ def draw(state, drawpad):
466
  # mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
467
  # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
468
 
469
- # for i in range(len(has_masks)):
470
- # model.update_single_layer(
471
- # idx=i,
472
- # mask=masks[i],
473
- # mask_strength=mask_strengths[i],
474
- # mask_std=mask_stds[i],
475
- # )
476
- data = dict(
477
- masks=masks,
478
- mask_strengths=mask_strengths,
479
- mask_stds=mask_stds,
480
- )
481
- conn.send(data)
482
- conn.close()
483
 
484
  ### Load examples
485
 
 
389
  return state
390
 
391
 
392
+ # @spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
393
  def run(state, drawpad):
394
  # ZeroGPU hack.
395
+ # listener = Listener(opt.address, authkey=opt.authkey)
396
+ # conn = listener.accept()
397
 
398
  # Reset model.
399
  model.device = torch.device('cuda')
 
407
  tic = time.time()
408
  while True:
409
  # Receive real-time mask inputs from the main process.
410
+ # data = conn.recv()
411
+ # if data is not None:
412
+ # print('Received data!!!')
413
+ # for i in range(opt.max_palettes):
414
+ # model.update_single_layer(
415
+ # idx=i,
416
+ # mask=data['masks'][i],
417
+ # mask_strength=data['mask_strengths'][i],
418
+ # mask_std=data['mask_stds'][i],
419
+ # )
420
 
421
  yield [state, model()]
422
  toc = time.time()
 
441
  # return
442
 
443
  # ZeroGPU hack.
444
+ # conn = Client(opt.address, authkey=opt.authkey)
445
 
446
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
447
  foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
 
466
  # mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
467
  # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
468
 
469
+ for i in range(len(has_masks)):
470
+ model.update_single_layer(
471
+ idx=i,
472
+ mask=masks[i],
473
+ mask_strength=mask_strengths[i],
474
+ mask_std=mask_stds[i],
475
+ )
476
+ # data = dict(
477
+ # masks=masks,
478
+ # mask_strengths=mask_strengths,
479
+ # mask_stds=mask_stds,
480
+ # )
481
+ # conn.send(data)
482
+ # conn.close()
483
 
484
  ### Load examples
485