divimund95 committed on
Commit
6f3f66a
1 Parent(s): ba3e3be

disable tensorflow library

Browse files
Files changed (3) hide show
  1. app.py +39 -21
  2. requirements.txt +3 -2
  3. setup_local.sh +15 -5
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import numpy as np
3
  import torch
4
  from PIL import Image
5
- import io
6
  from omegaconf import OmegaConf
7
 
8
  import subprocess
@@ -14,6 +14,7 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama')
14
 
15
  from lama.saicinpainting.evaluation.refinement import refine_predict
16
  from lama.saicinpainting.training.trainers import load_checkpoint
 
17
 
18
 
19
  # Load the model
@@ -43,25 +44,20 @@ def get_inpaint_model():
43
  model.to(device)
44
  return model, predict_config
45
 
46
- def inpaint(input_dict):
 
47
  """
48
  Performs image inpainting on the input image using the provided mask.
49
  Args: input_dict containing 'background' (image) and 'layers' (mask)
50
  Returns: Tuple of (output_image, input_mask)
51
  """
52
- input_image = input_dict["background"].convert("RGB")
53
  input_mask = pil_to_binary_mask(input_dict['layers'][0])
54
 
55
- # TODO: check if this is correct; (C,H,W) or (H,W,C)
56
-
57
- # batch = dict(image=input_image, mask=input_mask[None, ...])
58
  np_input_image = np.transpose(np.array(input_image), (2, 0, 1))
59
  np_input_mask = np.array(input_mask)[None, :, :] # Add channel dimension for grayscale images
60
  batch = dict(image=np_input_image, mask=np_input_mask)
61
 
62
- print('lol', batch['image'].shape)
63
- print('lol', batch['mask'].shape)
64
-
65
  inpaint_model, predict_config = get_inpaint_model()
66
  device = torch.device(predict_config.device)
67
 
@@ -69,8 +65,20 @@ def inpaint(input_dict):
69
  batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], predict_config.dataset.pad_out_to_modulo))[None].to(device)
70
  batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], predict_config.dataset.pad_out_to_modulo))[None].float().to(device)
71
 
72
- cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
73
- cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
76
  output_image = Image.fromarray(cur_res)
@@ -88,7 +96,7 @@ def pad_img_to_modulo(img, mod):
88
  out_width = ceil_modulo(width, mod)
89
  return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')
90
 
91
- def pil_to_binary_mask(pil_image, threshold=0):
92
  """
93
  Converts a PIL image to a binary mask.
94
 
@@ -107,25 +115,35 @@ def pil_to_binary_mask(pil_image, threshold=0):
107
  for j in range(binary_mask.shape[1]):
108
  if binary_mask[i,j] == True :
109
  mask[i,j] = 1
110
- mask = (mask*255).astype(np.uint8)
111
  output_mask = Image.fromarray(mask)
112
  # Convert mask to grayscale
113
  return output_mask.convert("L")
114
 
 
 
115
  # Create Gradio interface
116
- with gr.Blocks() as demo:
117
  gr.Markdown("# Image Inpainting")
118
  gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
119
 
120
  with gr.Row():
121
- input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto")
122
- output_image = gr.Image(type="pil", label="Output Image")
123
- # with gr.Column():
124
- # masked_image = gr.Image(label="Masked image", type="pil")
 
 
 
 
 
125
 
126
- inpaint_button = gr.Button("Inpaint")
127
- inpaint_button.click(fn=inpaint, inputs=[input_image], outputs=[output_image])
 
 
 
128
 
129
  # Launch the interface
130
  if __name__ == "__main__":
131
- demo.launch()
 
2
  import numpy as np
3
  import torch
4
  from PIL import Image
5
+ import spaces
6
  from omegaconf import OmegaConf
7
 
8
  import subprocess
 
14
 
15
  from lama.saicinpainting.evaluation.refinement import refine_predict
16
  from lama.saicinpainting.training.trainers import load_checkpoint
17
+ from lama.saicinpainting.evaluation.utils import move_to_device
18
 
19
 
20
  # Load the model
 
44
  model.to(device)
45
  return model, predict_config
46
 
47
+ @spaces.GPU
48
+ def inpaint(input_dict, refinement_enabled=False):
49
  """
50
  Performs image inpainting on the input image using the provided mask.
51
  Args: input_dict containing 'background' (image) and 'layers' (mask)
52
  Returns: Tuple of (output_image, input_mask)
53
  """
54
+ input_image = np.array(input_dict["background"].convert("RGB")).astype('float32') / 255
55
  input_mask = pil_to_binary_mask(input_dict['layers'][0])
56
 
 
 
 
57
  np_input_image = np.transpose(np.array(input_image), (2, 0, 1))
58
  np_input_mask = np.array(input_mask)[None, :, :] # Add channel dimension for grayscale images
59
  batch = dict(image=np_input_image, mask=np_input_mask)
60
 
 
 
 
61
  inpaint_model, predict_config = get_inpaint_model()
62
  device = torch.device(predict_config.device)
63
 
 
65
  batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], predict_config.dataset.pad_out_to_modulo))[None].to(device)
66
  batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], predict_config.dataset.pad_out_to_modulo))[None].float().to(device)
67
 
68
+
69
+ if refinement_enabled is True:
70
+ cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
71
+ cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
72
+ else:
73
+ with torch.no_grad():
74
+ batch = move_to_device(batch, device)
75
+ batch['mask'] = (batch['mask'] > 0) * 1
76
+ batch = inpaint_model(batch)
77
+ cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
78
+ unpad_to_size = batch.get('unpad_to_size', None)
79
+ if unpad_to_size is not None:
80
+ orig_height, orig_width = unpad_to_size
81
+ cur_res = cur_res[:orig_height, :orig_width]
82
 
83
  cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
84
  output_image = Image.fromarray(cur_res)
 
96
  out_width = ceil_modulo(width, mod)
97
  return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')
98
 
99
+ def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
100
  """
101
  Converts a PIL image to a binary mask.
102
 
 
115
  for j in range(binary_mask.shape[1]):
116
  if binary_mask[i,j] == True :
117
  mask[i,j] = 1
118
+ mask = (mask*max_scale).astype(np.uint8)
119
  output_mask = Image.fromarray(mask)
120
  # Convert mask to grayscale
121
  return output_mask.convert("L")
122
 
123
+ css = ".output-image, .input-image, .image-preview {height: 600px !important}"
124
+
125
  # Create Gradio interface
126
+ with gr.Blocks(css=css) as demo:
127
  gr.Markdown("# Image Inpainting")
128
  gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
129
 
130
  with gr.Row():
131
+ input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto", brush=gr.Brush(colors=['#f2e2cd'], default_size=25))
132
+ output_image = gr.Image(type="pil", label="Output Image", height="auto", width="auto")
133
+
134
+ with gr.Row():
135
+ refine_checkbox = gr.Checkbox(label="Enable Refinement[SLOWER BUT BETTER]", value=False)
136
+ inpaint_button = gr.Button("Inpaint")
137
+
138
+ def inpaint_with_refinement(image, enable_refinement):
139
+ return inpaint(image, refinement_enabled=enable_refinement)
140
 
141
+ inpaint_button.click(
142
+ fn=inpaint_with_refinement,
143
+ inputs=[input_image, refine_checkbox],
144
+ outputs=[output_image]
145
+ )
146
 
147
  # Launch the interface
148
  if __name__ == "__main__":
149
+ demo.launch()
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  gradio
2
- numpy
3
  pillow
4
  pyyaml
5
  tqdm
@@ -7,7 +7,7 @@ easydict==1.9.0
7
  scikit-image
8
  scikit-learn
9
  opencv-python
10
- tensorflow
11
  joblib
12
  matplotlib
13
  pandas
@@ -21,3 +21,4 @@ packaging
21
  wldhx.yadisk-direct
22
  torch
23
  torchvision
 
 
1
  gradio
2
+ numpy==1.26.4
3
  pillow
4
  pyyaml
5
  tqdm
 
7
  scikit-image
8
  scikit-learn
9
  opencv-python
10
+ # tensorflow
11
  joblib
12
  matplotlib
13
  pandas
 
21
  wldhx.yadisk-direct
22
  torch
23
  torchvision
24
+ spaces
setup_local.sh CHANGED
@@ -6,9 +6,19 @@ conda install pytorch torchvision -c pytorch -y
6
  pip install -r requirements.txt
7
 
8
 
9
- # Clone dependency repos
10
- git clone https://github.com/advimman/lama.git
 
 
 
 
 
11
 
12
- # Download big-lama model
13
- curl -LJO https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
14
- unzip big-lama.zip
 
 
 
 
 
 
6
  pip install -r requirements.txt
7
 
8
 
9
+ # Check if lama directory exists
10
+ if [ ! -d "lama" ]; then
11
+ # Clone dependency repos
12
+ git clone https://github.com/advimman/lama.git
13
+ else
14
+ echo "lama directory already exists. Skipping clone."
15
+ fi
16
 
17
+ # Check if big-lama.zip or big-lama directory exists
18
+ if [ ! -f "big-lama.zip" ] && [ ! -d "big-lama" ]; then
19
+ # Download big-lama model
20
+ curl -LJO https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
21
+ unzip big-lama.zip
22
+ else
23
+ echo "big-lama model already exists. Skipping download and extraction."
24
+ fi