WindVChen commited on
Commit
ef47508
β€’
1 Parent(s): 41536eb

Upload 3 files

Browse files
app.py CHANGED
@@ -198,7 +198,9 @@ with gr.Blocks() as app:
198
  form_split_res, form_split_num):
199
  log.log = io.BytesIO()
200
  if form_inference_mode == "Square Image":
201
- from efficient_inference_for_square_image import parse_args, main_process
 
 
202
  opt = parse_args()
203
  opt.transform_mean = [.5, .5, .5]
204
  opt.transform_var = [.5, .5, .5]
@@ -219,7 +221,9 @@ with gr.Blocks() as app:
219
  raise gr.Error("Patches too big. Try to reduce the `split_res`!")
220
 
221
  else:
222
- from inference_for_arbitrary_resolution_image import parse_args, main_process
 
 
223
  opt = parse_args()
224
  opt.transform_mean = [.5, .5, .5]
225
  opt.transform_var = [.5, .5, .5]
@@ -240,12 +244,20 @@ with gr.Blocks() as app:
240
  raise gr.Error("Patches too big. Try to increase the `split_num`!")
241
 
242
 
243
- form_start_btn.click(on_click_form_start_btn,
244
- inputs=[form_composite_image, form_mask_image, form_pretrained_dropdown, form_inference_mode,
245
- form_split_res, form_split_num], outputs=[form_harmonized_image])
 
246
 
247
 
248
- def on_click_form_reset_btn():
 
 
 
 
 
 
 
249
  log.log = io.BytesIO()
250
  return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(value=None,
251
  interactive=False), gr.update(
@@ -253,13 +265,27 @@ with gr.Blocks() as app:
253
 
254
 
255
  form_reset_btn.click(on_click_form_reset_btn,
256
- inputs=None, outputs=[form_log, form_composite_image, form_mask_image, form_start_btn])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
- def on_click_form_stop():
259
- gr.close_all()
260
 
261
  form_stop_btn.click(on_click_form_stop,
262
- inputs=None, outputs=None)
 
263
 
264
  gr.Markdown("""
265
  ## Quick Start
 
198
  form_split_res, form_split_num):
199
  log.log = io.BytesIO()
200
  if form_inference_mode == "Square Image":
201
+ from efficient_inference_for_square_image import parse_args, main_process, global_state
202
+ global_state[0] = 1
203
+
204
  opt = parse_args()
205
  opt.transform_mean = [.5, .5, .5]
206
  opt.transform_var = [.5, .5, .5]
 
221
  raise gr.Error("Patches too big. Try to reduce the `split_res`!")
222
 
223
  else:
224
+ from inference_for_arbitrary_resolution_image import parse_args, main_process, global_state
225
+ global_state[0] = 1
226
+
227
  opt = parse_args()
228
  opt.transform_mean = [.5, .5, .5]
229
  opt.transform_var = [.5, .5, .5]
 
244
  raise gr.Error("Patches too big. Try to increase the `split_num`!")
245
 
246
 
247
+ generate = form_start_btn.click(on_click_form_start_btn,
248
+ inputs=[form_composite_image, form_mask_image, form_pretrained_dropdown,
249
+ form_inference_mode,
250
+ form_split_res, form_split_num], outputs=[form_harmonized_image])
251
 
252
 
253
+ def on_click_form_reset_btn(form_inference_mode):
254
+ if form_inference_mode == "Square Image":
255
+ from efficient_inference_for_square_image import global_state
256
+ global_state[0] = 0
257
+ else:
258
+ from inference_for_arbitrary_resolution_image import global_state
259
+ global_state[0] = 0
260
+
261
  log.log = io.BytesIO()
262
  return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(value=None,
263
  interactive=False), gr.update(
 
265
 
266
 
267
  form_reset_btn.click(on_click_form_reset_btn,
268
+ inputs=[form_inference_mode],
269
+ outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
270
+
271
+
272
+ def on_click_form_stop(form_inference_mode):
273
+ if form_inference_mode == "Square Image":
274
+ from efficient_inference_for_square_image import global_state
275
+ global_state[0] = 0
276
+ else:
277
+ from inference_for_arbitrary_resolution_image import global_state
278
+ global_state[0] = 0
279
+
280
+ log.log = io.BytesIO()
281
+ return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(value=None,
282
+ interactive=False), gr.update(
283
+ interactive=False)
284
 
 
 
285
 
286
  form_stop_btn.click(on_click_form_stop,
287
+ inputs=[form_inference_mode],
288
+ outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
289
 
290
  gr.Markdown("""
291
  ## Quick Start
efficient_inference_for_square_image.py CHANGED
@@ -24,6 +24,7 @@ from utils.misc import normalize
24
 
25
  import math
26
 
 
27
 
28
  class single_image_dataset(torch.utils.data.Dataset):
29
  def __init__(self, opt, composite_image=None, mask=None):
@@ -273,6 +274,10 @@ def inference(model, opt, composite_image=None, mask=None):
273
  fg_INR_coordinates = coordinate_map[1:]
274
 
275
  try:
 
 
 
 
276
  if step == 0: # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow.
277
  fg_content_bg_appearance_construct, _, lut_transform_image = model(
278
  composite_image,
@@ -317,7 +322,8 @@ def inference(model, opt, composite_image=None, mask=None):
317
  init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
318
  start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
319
 
320
- print(f'Inference time: {time_all}')
 
321
  if opt.save_path is not None:
322
  os.makedirs(opt.save_path, exist_ok=True)
323
  cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
@@ -329,7 +335,7 @@ def main_process(opt, composite_image=None, mask=None):
329
 
330
  model = build_model(opt).to(opt.device)
331
 
332
- load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
333
  for k in load_dict.keys():
334
  if k not in model.state_dict().keys():
335
  print(f"Skip {k}")
 
24
 
25
  import math
26
 
27
+ global_state = [1] # For Gradio Stop Button.
28
 
29
  class single_image_dataset(torch.utils.data.Dataset):
30
  def __init__(self, opt, composite_image=None, mask=None):
 
274
  fg_INR_coordinates = coordinate_map[1:]
275
 
276
  try:
277
+ if global_state[0] == 0:
278
+ print("Stop Harmonizing...!")
279
+ break
280
+
281
  if step == 0: # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow.
282
  fg_content_bg_appearance_construct, _, lut_transform_image = model(
283
  composite_image,
 
322
  init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
323
  start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
324
 
325
+ if opt.device == "cuda":
326
+ print(f'Inference time: {time_all}')
327
  if opt.save_path is not None:
328
  os.makedirs(opt.save_path, exist_ok=True)
329
  cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
 
335
 
336
  model = build_model(opt).to(opt.device)
337
 
338
+ load_dict = torch.load(opt.pretrained)['model']
339
  for k in load_dict.keys():
340
  if k not in model.state_dict().keys():
341
  print(f"Skip {k}")
inference_for_arbitrary_resolution_image.py CHANGED
@@ -24,6 +24,7 @@ from utils.misc import normalize
24
 
25
  import math
26
 
 
27
 
28
  class single_image_dataset(torch.utils.data.Dataset):
29
  def __init__(self, opt, composite_image=None, mask=None):
@@ -265,6 +266,10 @@ def inference(model, opt, composite_image=None, mask=None):
265
  fg_INR_coordinates = coordinate_map[1:]
266
 
267
  try:
 
 
 
 
268
  if step == 0: # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow.
269
  fg_content_bg_appearance_construct, _, lut_transform_image = model(
270
  composite_image,
@@ -309,7 +314,8 @@ def inference(model, opt, composite_image=None, mask=None):
309
  init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
310
  start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
311
 
312
- print(f'Inference time: {time_all}')
 
313
  if opt.save_path is not None:
314
  os.makedirs(opt.save_path, exist_ok=True)
315
  cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
@@ -321,7 +327,7 @@ def main_process(opt, composite_image=None, mask=None):
321
 
322
  model = build_model(opt).to(opt.device)
323
 
324
- load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
325
  for k in load_dict.keys():
326
  if k not in model.state_dict().keys():
327
  print(f"Skip {k}")
 
24
 
25
  import math
26
 
27
+ global_state = [1] # For Gradio Stop Button.
28
 
29
  class single_image_dataset(torch.utils.data.Dataset):
30
  def __init__(self, opt, composite_image=None, mask=None):
 
266
  fg_INR_coordinates = coordinate_map[1:]
267
 
268
  try:
269
+ if global_state[0] == 0:
270
+ print("Stop Harmonizing...!")
271
+ break
272
+
273
  if step == 0: # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow.
274
  fg_content_bg_appearance_construct, _, lut_transform_image = model(
275
  composite_image,
 
314
  init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
315
  start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
316
 
317
+ if opt.device == "cuda":
318
+ print(f'Inference time: {time_all}')
319
  if opt.save_path is not None:
320
  os.makedirs(opt.save_path, exist_ok=True)
321
  cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
 
327
 
328
  model = build_model(opt).to(opt.device)
329
 
330
+ load_dict = torch.load(opt.pretrained)['model']
331
  for k in load_dict.keys():
332
  if k not in model.state_dict().keys():
333
  print(f"Skip {k}")