Upload 3 files

- app.py: +36 -10
- efficient_inference_for_square_image.py: +8 -2
- inference_for_arbitrary_resolution_image.py: +8 -2
app.py
CHANGED

```diff
@@ -198,7 +198,9 @@ with gr.Blocks() as app:
                                  form_split_res, form_split_num):
         log.log = io.BytesIO()
         if form_inference_mode == "Square Image":
-            from efficient_inference_for_square_image import parse_args, main_process
+            from efficient_inference_for_square_image import parse_args, main_process, global_state
+            global_state[0] = 1
+
             opt = parse_args()
             opt.transform_mean = [.5, .5, .5]
             opt.transform_var = [.5, .5, .5]
@@ -219,7 +221,9 @@ with gr.Blocks() as app:
             raise gr.Error("Patches too big. Try to reduce the `split_res`!")
 
         else:
-            from inference_for_arbitrary_resolution_image import parse_args, main_process
+            from inference_for_arbitrary_resolution_image import parse_args, main_process, global_state
+            global_state[0] = 1
+
             opt = parse_args()
             opt.transform_mean = [.5, .5, .5]
             opt.transform_var = [.5, .5, .5]
@@ -240,12 +244,20 @@ with gr.Blocks() as app:
             raise gr.Error("Patches too big. Try to increase the `split_num`!")
 
 
-    form_start_btn.click(on_click_form_start_btn,
-                         inputs=[form_composite_image, form_mask_image, form_pretrained_dropdown, form_inference_mode,
-                                 form_split_res, form_split_num], outputs=[form_harmonized_image])
+    generate = form_start_btn.click(on_click_form_start_btn,
+                                    inputs=[form_composite_image, form_mask_image, form_pretrained_dropdown,
+                                            form_inference_mode,
+                                            form_split_res, form_split_num], outputs=[form_harmonized_image])
 
 
-    def on_click_form_reset_btn():
+    def on_click_form_reset_btn(form_inference_mode):
+        if form_inference_mode == "Square Image":
+            from efficient_inference_for_square_image import global_state
+            global_state[0] = 0
+        else:
+            from inference_for_arbitrary_resolution_image import global_state
+            global_state[0] = 0
+
         log.log = io.BytesIO()
         return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(value=None,
                                                                                          interactive=False), gr.update(
@@ -253,13 +265,27 @@ with gr.Blocks() as app:
 
 
     form_reset_btn.click(on_click_form_reset_btn,
-                         inputs=…
+                         inputs=[form_inference_mode],
+                         outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
+
+
+    def on_click_form_stop(form_inference_mode):
+        if form_inference_mode == "Square Image":
+            from efficient_inference_for_square_image import global_state
+            global_state[0] = 0
+        else:
+            from inference_for_arbitrary_resolution_image import global_state
+            global_state[0] = 0
+
+        log.log = io.BytesIO()
+        return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(value=None,
+                                                                                         interactive=False), gr.update(
+            interactive=False)
 
-    def on_click_form_stop():
-        gr.close_all()
 
     form_stop_btn.click(on_click_form_stop,
-                        …
+                        inputs=[form_inference_mode],
+                        outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
 
     gr.Markdown("""
     ## Quick Start
```
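The app.py changes wire up a two-part stop mechanism: `cancels=generate` tells Gradio to cancel the pending click event, while the shared `global_state` flag lets an inference loop that is already running on the backend notice the stop request and bail out cooperatively. Below is a minimal, self-contained sketch of that pattern; all component and function names here are hypothetical, not taken from this repo:

```python
import time
import gradio as gr

stop_flag = [1]  # one-element list so importers see in-place mutations

def slow_job(n_steps):
    stop_flag[0] = 1
    for i in range(int(n_steps)):
        if stop_flag[0] == 0:          # cooperative cancellation point
            return f"stopped at step {i}"
        time.sleep(0.5)                # stands in for one inference step
    return "done"

def request_stop():
    stop_flag[0] = 0                   # flip the flag the worker polls

with gr.Blocks() as demo:
    steps = gr.Number(value=20, label="steps")
    result = gr.Textbox(label="result")
    start_btn = gr.Button("Start")
    stop_btn = gr.Button("Stop")

    # Keep the event handle so another listener can cancel it.
    job = start_btn.click(slow_job, inputs=[steps], outputs=[result])
    # `cancels` stops the Gradio event; the flag stops the loop itself.
    stop_btn.click(request_stop, cancels=[job])

demo.queue().launch()                  # event cancellation requires the queue
```

Both halves are needed: `cancels` alone would leave the backend loop running to completion, and the flag alone would leave the frontend waiting on an event that never returns.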
efficient_inference_for_square_image.py
CHANGED

```diff
@@ -24,6 +24,7 @@ from utils.misc import normalize
 
 import math
 
+global_state = [1]  # For Gradio Stop Button.
 
 class single_image_dataset(torch.utils.data.Dataset):
     def __init__(self, opt, composite_image=None, mask=None):
@@ -273,6 +274,10 @@ def inference(model, opt, composite_image=None, mask=None):
         fg_INR_coordinates = coordinate_map[1:]
 
         try:
+            if global_state[0] == 0:
+                print("Stop Harmonizing...!")
+                break
+
             if step == 0:  # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow.
                 fg_content_bg_appearance_construct, _, lut_transform_image = model(
                     composite_image,
@@ -317,7 +322,8 @@ def inference(model, opt, composite_image=None, mask=None):
         init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
         start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
 
-    print(f'Inference time: {time_all}')
+    if opt.device == "cuda":
+        print(f'Inference time: {time_all}')
     if opt.save_path is not None:
         os.makedirs(opt.save_path, exist_ok=True)
         cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
@@ -329,7 +335,7 @@ def main_process(opt, composite_image=None, mask=None):
 
     model = build_model(opt).to(opt.device)
 
-    load_dict = torch.load(opt.pretrained)
+    load_dict = torch.load(opt.pretrained)['model']
     for k in load_dict.keys():
         if k not in model.state_dict().keys():
             print(f"Skip {k}")
```
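`global_state` is deliberately a one-element list rather than a plain integer: `from ... import global_state` binds the importer to the same list object, so `global_state[0] = 0` in app.py is immediately visible inside the inference loop, whereas rebinding an imported int would only change the importer's local name. A runnable toy illustration of the distinction (module and variable names hypothetical):

```python
import types

# Stand-in for a module like efficient_inference_for_square_image.
flags = types.ModuleType("flags")
flags.counter = 0   # plain int
flags.state = [1]   # one-element list, like global_state

# Simulate `from flags import counter, state` in another module:
counter, state = flags.counter, flags.state

counter = 99        # rebinds only this local name
state[0] = 0        # mutates the object shared with `flags`

print(flags.counter)   # 0 -> the rebinding did not propagate
print(flags.state[0])  # 0 -> the in-place mutation did
```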
inference_for_arbitrary_resolution_image.py
CHANGED

```diff
@@ -24,6 +24,7 @@ from utils.misc import normalize
 
 import math
 
+global_state = [1]  # For Gradio Stop Button.
 
 class single_image_dataset(torch.utils.data.Dataset):
     def __init__(self, opt, composite_image=None, mask=None):
@@ -265,6 +266,10 @@ def inference(model, opt, composite_image=None, mask=None):
         fg_INR_coordinates = coordinate_map[1:]
 
         try:
+            if global_state[0] == 0:
+                print("Stop Harmonizing...!")
+                break
+
             if step == 0:  # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow.
                 fg_content_bg_appearance_construct, _, lut_transform_image = model(
                     composite_image,
@@ -309,7 +314,8 @@ def inference(model, opt, composite_image=None, mask=None):
         init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
         start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
 
-    print(f'Inference time: {time_all}')
+    if opt.device == "cuda":
+        print(f'Inference time: {time_all}')
     if opt.save_path is not None:
         os.makedirs(opt.save_path, exist_ok=True)
         cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
@@ -321,7 +327,7 @@ def main_process(opt, composite_image=None, mask=None):
 
     model = build_model(opt).to(opt.device)
 
-    load_dict = torch.load(opt.pretrained)
+    load_dict = torch.load(opt.pretrained)['model']
    for k in load_dict.keys():
        if k not in model.state_dict().keys():
            print(f"Skip {k}")
```
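Both inference scripts receive the same checkpoint fix: the saved file nests the weights under a `'model'` key, so a bare `torch.load(...)` returned the wrapper dict and the subsequent key filtering skipped everything. A sketch of that loading pattern, assuming the nested layout shown in the diff (the helper name and `strict=False` fallback are illustrative, not from the repo):

```python
import torch

def load_pretrained(model, ckpt_path, device="cpu"):
    # Assumption: the checkpoint nests weights under a 'model' key,
    # matching the diff; adjust if your checkpoint layout differs.
    load_dict = torch.load(ckpt_path, map_location=device)["model"]
    own_state = model.state_dict()
    filtered = {}
    for k, v in load_dict.items():
        if k not in own_state:
            print(f"Skip {k}")  # mirrors the skip loop in main_process
            continue
        filtered[k] = v
    # strict=False tolerates model keys missing from the checkpoint.
    model.load_state_dict(filtered, strict=False)
    return model
```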