liuyizhang committed
Commit e12d135
1 Parent(s): 5ee6e09

update app.py

Files changed (2):
  1. api_client.py +27 -11
  2. app.py +87 -63
api_client.py CHANGED
@@ -52,18 +52,34 @@ def base64_to_PILImage(im_b64):
      pil_img = Image.open(io.BytesIO(im_bytes))
      return pil_img
  
+ def cleaner_img(image_file, remove_texts, mask_extend=20, disp_debug=True):
+     data = {'remove_texts': remove_texts,
+             'mask_extend': mask_extend,
+             'img': imgFile_to_base64(image_file),
+             }
+     ret = request_post(url, data, timeout=600, headers = None)
+     if ret['code'] == 0:
+         if disp_debug:
+             for img in ret['result']['imgs']:
+                 pilImage = base64_to_PILImage(img)
+                 plt.imshow(pilImage)
+                 plt.show()
+                 plt.clf()
+                 plt.close('all')
+         img_len = len(ret['result']['imgs'])
+         pilImage = base64_to_PILImage(ret['result']['imgs'][img_len-1])
+     else:
+         pilImage = None
+     return pilImage, ret
+ 
  image_file = 'dog.png'
- data = {'remove_texts': "小狗 . 椅子",
-         'extend': 20,
-         'img': imgFile_to_base64(image_file),
-         }
+ remove_texts = "小狗 . 椅子"
  
- ret = request_post(url, data, timeout=600, headers = None)
- print(len(ret['result']['imgs']))
+ mask_extend = 20
+ pil_image, ret = cleaner_img(image_file, remove_texts, mask_extend, disp_debug=False)
  
- for img in ret['result']['imgs']:
-     pilImage = base64_to_PILImage(img)
-     plt.imshow(pilImage)
-     plt.show()
-     plt.clf()
+ plt.imshow(pil_image)
+ plt.show()
+ plt.clf()
+ plt.close()
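The updated client script assumes that imgFile_to_base64, request_post, and a module-level url are defined earlier in api_client.py. A minimal sketch of what those helpers typically look like for a JSON-over-HTTP service is below; the endpoint URL and the use of the requests library are assumptions, not part of the commit.

# Hypothetical sketch of the helpers the client script relies on; only the
# function names and parameters appear in the commit, the bodies and the
# endpoint URL are illustrative.
import base64

import requests

url = 'http://127.0.0.1:8000/imgCleaner'  # illustrative endpoint, not from the commit

def imgFile_to_base64(image_file):
    # Read the image file and return its bytes as an ASCII base64 string.
    with open(image_file, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')

def request_post(url, data, timeout=600, headers=None):
    # POST the payload as JSON and return the decoded JSON response.
    resp = requests.post(url, json=data, timeout=timeout, headers=headers)
    return resp.json()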
 
app.py CHANGED
@@ -3,7 +3,17 @@ import warnings
  warnings.filterwarnings('ignore')
  
  import subprocess, io, os, sys, time
- os.system("pip install gradio==3.40.1")
+ 
+ run_gradio = False
+ if os.environ.get('IS_MY_DEBUG') is None:
+     run_gradio = True
+ else:
+     run_gradio = False
+     # run_gradio = True
+ 
+ if run_gradio:
+     os.system("pip install gradio==3.40.1")
+ 
  import gradio as gr
  
  from loguru import logger
@@ -35,7 +45,10 @@ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases
  
  import cv2
  import numpy as np
- import matplotlib.pyplot as plt
+ import matplotlib
+ matplotlib.use('AGG')
+ plt = matplotlib.pyplot
+ # import matplotlib.pyplot as plt
  
  groundingdino_enable = True
  sam_enable = True
@@ -332,60 +345,63 @@ def load_lama_cleaner_model(device):
      )
  
  def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
-     ori_image = image
-     if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
-         # rotate image
-         ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
-         image = ori_image
- 
-     original_shape = ori_image.shape
-     interpolation = cv2.INTER_CUBIC
- 
-     size_limit = cleaner_size_limit
-     if size_limit == -1:
-         size_limit = max(image.shape)
-     else:
-         size_limit = int(size_limit)
- 
-     config = lama_Config(
-         ldm_steps=25,
-         ldm_sampler='plms',
-         zits_wireframe=True,
-         hd_strategy='Original',
-         hd_strategy_crop_margin=196,
-         hd_strategy_crop_trigger_size=1280,
-         hd_strategy_resize_limit=2048,
-         prompt='',
-         use_croper=False,
-         croper_x=0,
-         croper_y=0,
-         croper_height=512,
-         croper_width=512,
-         sd_mask_blur=5,
-         sd_strength=0.75,
-         sd_steps=50,
-         sd_guidance_scale=7.5,
-         sd_sampler='ddim',
-         sd_seed=42,
-         cv2_flag='INPAINT_NS',
-         cv2_radius=5,
-     )
- 
-     if config.sd_seed == -1:
-         config.sd_seed = random.randint(1, 999999999)
- 
-     # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
-     image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
-     # logger.info(f"Resized image shape_1_: {image.shape}")
- 
-     # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
-     mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
-     # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
- 
-     res_np_img = lama_cleaner_model(image, mask, config)
-     torch.cuda.empty_cache()
- 
-     image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
+     try:
+         ori_image = image
+         if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
+             # rotate image
+             ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
+             image = ori_image
+ 
+         original_shape = ori_image.shape
+         interpolation = cv2.INTER_CUBIC
+ 
+         size_limit = cleaner_size_limit
+         if size_limit == -1:
+             size_limit = max(image.shape)
+         else:
+             size_limit = int(size_limit)
+ 
+         config = lama_Config(
+             ldm_steps=25,
+             ldm_sampler='plms',
+             zits_wireframe=True,
+             hd_strategy='Original',
+             hd_strategy_crop_margin=196,
+             hd_strategy_crop_trigger_size=1280,
+             hd_strategy_resize_limit=2048,
+             prompt='',
+             use_croper=False,
+             croper_x=0,
+             croper_y=0,
+             croper_height=512,
+             croper_width=512,
+             sd_mask_blur=5,
+             sd_strength=0.75,
+             sd_steps=50,
+             sd_guidance_scale=7.5,
+             sd_sampler='ddim',
+             sd_seed=42,
+             cv2_flag='INPAINT_NS',
+             cv2_radius=5,
+         )
+ 
+         if config.sd_seed == -1:
+             config.sd_seed = random.randint(1, 999999999)
+ 
+         # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
+         image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
+         # logger.info(f"Resized image shape_1_: {image.shape}")
+ 
+         # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
+         mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
+         # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
+ 
+         res_np_img = lama_cleaner_model(image, mask, config)
+         torch.cuda.empty_cache()
+ 
+         image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
+     except Exception as e:
+         image = None
      return image
  
  class Ram_Predictor(RamPredictor):
@@ -691,6 +707,8 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
      plt.axis('off')
      image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
      plt.savefig(image_path, bbox_inches="tight")
+     plt.clf()
+     plt.close('all')
      segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
      os.remove(image_path)
      output_images.append(Image.fromarray(segment_image_result))
@@ -757,6 +775,10 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
  
      logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
      image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
+     if image_inpainting is None:
+         logger.info(f'run_anything_task_failed_')
+         return None, None, None, None, None, None, None
+ 
      # output_images.append(image_inpainting)
      # run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
  
@@ -975,7 +997,10 @@ class API_Starter:
              request_data = request.data.decode('utf-8')
              data = json.loads(request_data)
              result = self.handle_data(data)
-             ret_json = {'code': 0, 'result':result}
+             if result is None:
+                 ret_json = {'code': -2, 'reason':'handle error'}
+             else:
+                 ret_json = {'code': 0, 'result':result}
              return jsonify(ret_json)
  
          self.app = app
@@ -996,15 +1021,18 @@ class API_Starter:
              inpaint_mode = "merge",
              mask_source_radio = "type what to detect below",
              remove_mode = "rectangle", # ["segment", "rectangle"]
-             remove_mask_extend = "10",
+             remove_mask_extend = f"{data['mask_extend']}",
              num_relation = 5,
              kosmos_input = None,
              cleaner_size_limit = -1,
          )
          output_images = results[0]
+         if output_images is None:
+             return None
          ret_json_images = []
          file_temp = int(time.time())
          count = 0
+         output_images = output_images[-1:]
          for image_pil in output_images:
              try:
                  img_format = image_pil.format.lower()
@@ -1086,16 +1114,12 @@ if __name__ == "__main__":
      # print(f'ram_model__{get_model_device(ram_model)}')
      # print(f'kosmos_model__{get_model_device(kosmos_model)}')
  
-     if os.environ.get('IS_MY_DEBUG') is None:
+     if run_gradio:
          # Provide gradio services
          main_gradio(args)
      else:
-         if 0 == 0:
-             # Provide API services
-             main_api(args)
-         else:
-             # Provide gradio services
-             main_gradio(args)
+         # Provide API services
+         main_api(args)
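The app.py changes move matplotlib onto the non-interactive AGG backend and add plt.clf() / plt.close('all') after plt.savefig(...), so figures rendered inside the service need no display and are not left accumulating in memory. A minimal standalone sketch of that pattern, with illustrative figure content and file name:

# Minimal sketch of the headless-plotting pattern used in app.py; the figure
# content and the output file name here are illustrative, not from the commit.
import matplotlib
matplotlib.use('AGG')            # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np

def save_plot(image_path):
    # Render an image to disk without a display, then release the figure.
    plt.imshow(np.random.rand(64, 64))
    plt.axis('off')
    plt.savefig(image_path, bbox_inches="tight")
    plt.clf()                    # clear the current figure
    plt.close('all')             # close every figure pyplot still tracks

save_plot('example_output.png')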
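On the server side, the API route now reports failures explicitly: handle_data returns None when run_anything_task produced no output, and the route answers with code -2 instead of code 0, which is what the client's ret['code'] == 0 check in cleaner_img keys off. A stripped-down sketch of that response contract, assuming a Flask route (the route path and placeholder handler are illustrative):

# Stripped-down sketch of the API response contract; only the JSON keys
# ('code', 'reason', 'result') come from the commit, the rest is illustrative.
import json
from flask import Flask, request, jsonify

app = Flask(__name__)

def handle_data(data):
    # Placeholder for the real pipeline; return None to signal a failure.
    return {'imgs': []} if data.get('img') else None

@app.route('/imgCleaner', methods=['POST'])
def img_cleaner():
    data = json.loads(request.data.decode('utf-8'))
    result = handle_data(data)
    if result is None:
        ret_json = {'code': -2, 'reason': 'handle error'}   # failure path
    else:
        ret_json = {'code': 0, 'result': result}            # success path
    return jsonify(ret_json)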