Spaces:
Starting
on
A10G
Starting
on
A10G
Update app.py
Browse files
app.py
CHANGED
@@ -389,11 +389,11 @@ def register(state, drawpad, model):
|
|
389 |
return state
|
390 |
|
391 |
|
392 |
-
@spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
393 |
def run(state, drawpad):
|
394 |
# ZeroGPU hack.
|
395 |
-
listener = Listener(opt.address, authkey=opt.authkey)
|
396 |
-
conn = listener.accept()
|
397 |
|
398 |
# Reset model.
|
399 |
model.device = torch.device('cuda')
|
@@ -407,16 +407,16 @@ def run(state, drawpad):
|
|
407 |
tic = time.time()
|
408 |
while True:
|
409 |
# Receive real-time mask inputs from the main process.
|
410 |
-
data = conn.recv()
|
411 |
-
if data is not None:
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
|
421 |
yield [state, model()]
|
422 |
toc = time.time()
|
@@ -441,7 +441,7 @@ def draw(state, drawpad):
|
|
441 |
# return
|
442 |
|
443 |
# ZeroGPU hack.
|
444 |
-
conn = Client(opt.address, authkey=opt.authkey)
|
445 |
|
446 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
447 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
@@ -466,20 +466,20 @@ def draw(state, drawpad):
|
|
466 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
467 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
468 |
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
data = dict(
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
)
|
481 |
-
conn.send(data)
|
482 |
-
conn.close()
|
483 |
|
484 |
### Load examples
|
485 |
|
|
|
389 |
return state
|
390 |
|
391 |
|
392 |
+
# @spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
393 |
def run(state, drawpad):
|
394 |
# ZeroGPU hack.
|
395 |
+
# listener = Listener(opt.address, authkey=opt.authkey)
|
396 |
+
# conn = listener.accept()
|
397 |
|
398 |
# Reset model.
|
399 |
model.device = torch.device('cuda')
|
|
|
407 |
tic = time.time()
|
408 |
while True:
|
409 |
# Receive real-time mask inputs from the main process.
|
410 |
+
# data = conn.recv()
|
411 |
+
# if data is not None:
|
412 |
+
# print('Received data!!!')
|
413 |
+
# for i in range(opt.max_palettes):
|
414 |
+
# model.update_single_layer(
|
415 |
+
# idx=i,
|
416 |
+
# mask=data['masks'][i],
|
417 |
+
# mask_strength=data['mask_strengths'][i],
|
418 |
+
# mask_std=data['mask_stds'][i],
|
419 |
+
# )
|
420 |
|
421 |
yield [state, model()]
|
422 |
toc = time.time()
|
|
|
441 |
# return
|
442 |
|
443 |
# ZeroGPU hack.
|
444 |
+
# conn = Client(opt.address, authkey=opt.authkey)
|
445 |
|
446 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
447 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
|
|
466 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
467 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
468 |
|
469 |
+
for i in range(len(has_masks)):
|
470 |
+
model.update_single_layer(
|
471 |
+
idx=i,
|
472 |
+
mask=masks[i],
|
473 |
+
mask_strength=mask_strengths[i],
|
474 |
+
mask_std=mask_stds[i],
|
475 |
+
)
|
476 |
+
# data = dict(
|
477 |
+
# masks=masks,
|
478 |
+
# mask_strengths=mask_strengths,
|
479 |
+
# mask_stds=mask_stds,
|
480 |
+
# )
|
481 |
+
# conn.send(data)
|
482 |
+
# conn.close()
|
483 |
|
484 |
### Load examples
|
485 |
|