onuralpszr committed
Commit 94d19e9
1 Parent(s): d552355

feat: ✨ Segmentation methods are added


Signed-off-by: Onuralp SEZER <thunderbirdtr@gmail.com>

.gitignore ADDED
@@ -0,0 +1,168 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
app.py CHANGED
@@ -6,7 +6,8 @@ import numpy as np
 from PIL import Image
 import gradio as gr
 import spaces
-from helpers.utils import create_directory, delete_directory, generate_unique_name
+from helpers.file_utils import create_directory, delete_directory, generate_unique_name
+from helpers.segment_utils import parse_segmentation, extract_objs
 import os
 
 BOX_ANNOTATOR = sv.BoxAnnotator()
@@ -14,10 +15,12 @@ LABEL_ANNOTATOR = sv.LabelAnnotator()
 MASK_ANNOTATOR = sv.MaskAnnotator()
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 VIDEO_TARGET_DIRECTORY = "tmp"
+VAE_MODEL = "vae-oid.npz"
 
+COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
 
 INTRO_TEXT = """
-## PaliGemma 2 Detection with Supervision - Demo
+## PaliGemma 2 Detection/Segmentation with Supervision - Demo
 
 <div style="display: flex; gap: 10px;">
 <a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md">
@@ -60,6 +63,14 @@ def parse_class_names(prompt):
     classes_text = prompt[7:].strip()
     return [cls.strip() for cls in classes_text.split(';') if cls.strip()]
 
+def parse_prompt_type(prompt):
+    """Determine if the prompt is for detection or segmentation."""
+    if prompt.lower().startswith('detect '):
+        return 'detection', prompt[7:].strip()
+    elif prompt.lower().startswith('segment '):
+        return 'segmentation', prompt[8:].strip()
+    return None, prompt
+
 @spaces.GPU
 def paligemma_detection(input_image, input_text, max_new_tokens):
     model_inputs = processor(text=input_text,
@@ -110,10 +121,60 @@ def annotate_image(result, resolution_wh, prompt, cv_image):
 
 def process_image(input_image, input_text, max_new_tokens):
     cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
-    result = paligemma_detection(input_image, input_text, max_new_tokens)
-    annotated_image = annotate_image(result,
-                                     (input_image.width, input_image.height),
-                                     input_text, cv_image)
+    prompt_type, cleaned_prompt = parse_prompt_type(input_text)
+
+    if prompt_type == 'detection':
+        # Existing detection logic
+        result = paligemma_detection(input_image, input_text, max_new_tokens)
+        class_names = [cls.strip() for cls in cleaned_prompt.split(';') if cls.strip()]
+
+        detections = sv.Detections.from_lmm(
+            sv.LMM.PALIGEMMA,
+            result,
+            resolution_wh=(input_image.width, input_image.height),
+            classes=class_names
+        )
+
+        annotated_image = BOX_ANNOTATOR.annotate(scene=cv_image.copy(), detections=detections)
+        annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
+        annotated_image = MASK_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
+
+    elif prompt_type == 'segmentation':
+        # New segmentation logic
+        result = paligemma_detection(input_image, input_text, max_new_tokens)
+        objs = extract_objs(result.lstrip("\n"), input_image.width, input_image.height, unique_labels=True)
+
+        # Create masks and annotations
+        annotated_image = cv_image.copy()
+        for obj in objs:
+            if 'mask' in obj and obj['mask'] is not None:
+                mask = obj['mask']
+                # Convert mask to uint8 for visualization
+                mask_vis = (mask * 255).astype(np.uint8)
+                # Create colored mask
+                colored_mask = np.zeros_like(cv_image)
+                color_idx = hash(obj['name']) % len(COLORS)
+                color = tuple(int(COLORS[color_idx].lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
+                colored_mask[mask > 0] = color
+
+                # Blend mask with image
+                alpha = 0.5
+                annotated_image = cv2.addWeighted(annotated_image, 1, colored_mask, alpha, 0)
+
+                # Add label
+                if 'xyxy' in obj:
+                    x1, y1, x2, y2 = obj['xyxy']
+                    cv2.putText(annotated_image, obj['name'], (x1, y1-10),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
+
+    else:
+        gr.Warning("Invalid prompt format. Please use 'detect' or 'segment' followed by class names")
+        return input_image, "Invalid prompt format"
+
+    # Convert back to RGB for display
+    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
+    annotated_image = Image.fromarray(annotated_image)
+
     return annotated_image, result
 
 
@@ -188,13 +249,13 @@ def process_video(input_video, input_text, max_new_tokens, progress=gr.Progress(
 with gr.Blocks() as app:
     gr.Markdown(INTRO_TEXT)
 
-    with gr.Tab("Image Detection"):
+    with gr.Tab("Image Detection/Segmentation"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                input_text = gr.Textbox(
                    lines=2,
-                   placeholder="Enter prompt in format like this: detect person;dog;building",
+                   placeholder="Enter prompt in format like this: detect person;dog;building or segment person;dog;building",
                    label="Enter detection prompt"
                )
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Set to larger for longer generation.")
@@ -213,7 +274,7 @@ with gr.Blocks() as app:
            input_video = gr.Video(label="Input Video")
            input_text = gr.Textbox(
                lines=2,
-               placeholder="Enter prompt in format like this: detect person;dog;building",
+               placeholder="Enter prompt in format like this: detect person;dog;building or segment person;dog;building",
               label="Enter detection prompt"
            )
            max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Set to larger for longer generation.")
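
Reviewer note: the routing added above is driven entirely by the prompt prefix. A minimal sketch of the intended call flow, not part of the commit; "example.jpg" is a hypothetical test image and app.py's model/processor globals are assumed to be already loaded:

    from PIL import Image

    prompt = "segment person;dog;building"
    prompt_type, cleaned = parse_prompt_type(prompt)   # -> ('segmentation', 'person;dog;building')

    image = Image.open("example.jpg")                  # hypothetical input
    annotated, raw_output = process_image(image, prompt, max_new_tokens=100)
    # 'detect ...' prompts take the supervision Detections path;
    # 'segment ...' prompts decode <seg> tokens into masks via helpers.segment_utils.extract_objs.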
helpers/{utils.py → file_utils.py} RENAMED
File without changes
helpers/segment_utils.py ADDED
@@ -0,0 +1,189 @@
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import re
+import numpy as np
+import functools
+from PIL import Image
+
+### Postprocessing Utils for Segmentation Tokens
+### Segmentation tokens are passed to another VAE which decodes them to a mask
+
+_MODEL_PATH = 'vae-oid.npz'
+
+_SEGMENT_DETECT_RE = re.compile(
+    r'(.*?)' +
+    r'<loc(\d{4})>' * 4 + r'\s*' +
+    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+    r'\s*([^;<>]+)? ?(?:; )?',
+)
+
+def parse_segmentation(input_image, input_text, inference_output):
+    out = infer(input_image, input_text, max_new_tokens=100)
+    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+    labels = set(obj.get('name') for obj in objs if obj.get('name'))
+    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+    annotated_img = (
+        input_image,
+        [
+            (
+                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                obj['name'] or '',
+            )
+            for obj in objs
+            if 'mask' in obj or 'xyxy' in obj
+        ],
+    )
+    has_annotations = bool(annotated_img[1])
+    return annotated_img
+
+
+def _get_params(checkpoint):
+    """Converts PyTorch checkpoint to Flax params."""
+
+    def transp(kernel):
+        return np.transpose(kernel, (2, 3, 1, 0))
+
+    def conv(name):
+        return {
+            'bias': checkpoint[name + '.bias'],
+            'kernel': transp(checkpoint[name + '.weight']),
+        }
+
+    def resblock(name):
+        return {
+            'Conv_0': conv(name + '.0'),
+            'Conv_1': conv(name + '.2'),
+            'Conv_2': conv(name + '.4'),
+        }
+
+    return {
+        '_embeddings': checkpoint['_vq_vae._embedding'],
+        'Conv_0': conv('decoder.0'),
+        'ResBlock_0': resblock('decoder.2.net'),
+        'ResBlock_1': resblock('decoder.3.net'),
+        'ConvTranspose_0': conv('decoder.4'),
+        'ConvTranspose_1': conv('decoder.6'),
+        'ConvTranspose_2': conv('decoder.8'),
+        'ConvTranspose_3': conv('decoder.10'),
+        'Conv_1': conv('decoder.12'),
+    }
+
+
+def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+    batch_size, num_tokens = codebook_indices.shape
+    assert num_tokens == 16, codebook_indices.shape
+    unused_num_embeddings, embedding_dim = embeddings.shape
+
+    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+    return encodings
+
+
+@functools.cache
+def _get_reconstruct_masks():
+    """Reconstructs masks from codebook indices.
+    Returns:
+      A function that expects indices shaped `[B, 16]` of dtype int32, each
+      ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
+      `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
+    """
+
+    class ResBlock(nn.Module):
+        features: int
+
+        @nn.compact
+        def __call__(self, x):
+            original_x = x
+            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+            x = nn.relu(x)
+            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+            x = nn.relu(x)
+            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+            return x + original_x
+
+    class Decoder(nn.Module):
+        """Upscales quantized vectors to mask."""
+
+        @nn.compact
+        def __call__(self, x):
+            num_res_blocks = 2
+            dim = 128
+            num_upsample_layers = 4
+
+            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+            x = nn.relu(x)
+
+            for _ in range(num_res_blocks):
+                x = ResBlock(features=dim)(x)
+
+            for _ in range(num_upsample_layers):
+                x = nn.ConvTranspose(
+                    features=dim,
+                    kernel_size=(4, 4),
+                    strides=(2, 2),
+                    padding=2,
+                    transpose_kernel=True,
+                )(x)
+                x = nn.relu(x)
+                dim //= 2
+
+            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+            return x
+
+    def reconstruct_masks(codebook_indices):
+        quantized = _quantized_values_from_codebook_indices(
+            codebook_indices, params['_embeddings']
+        )
+        return Decoder().apply({'params': params}, quantized)
+
+    with open(_MODEL_PATH, 'rb') as f:
+        params = _get_params(dict(np.load(f)))
+
+    return jax.jit(reconstruct_masks, backend='cpu')
+def extract_objs(text, width, height, unique_labels=False):
+    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+    objs = []
+    seen = set()
+    while text:
+        m = _SEGMENT_DETECT_RE.match(text)
+        if not m:
+            break
+        print("m", m)
+        gs = list(m.groups())
+        before = gs.pop(0)
+        name = gs.pop()
+        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+        seg_indices = gs[4:20]
+        if seg_indices[0] is None:
+            mask = None
+        else:
+            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+            m64 = Image.fromarray((m64 * 255).astype('uint8'))
+            mask = np.zeros([height, width])
+            if y2 > y1 and x2 > x1:
+                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+        content = m.group()
+        if before:
+            objs.append(dict(content=before))
+            content = content[len(before):]
+        while unique_labels and name in seen:
+            name = (name or '') + "'"
+        seen.add(name)
+        objs.append(dict(
+            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+        text = text[len(before) + len(content):]
+
+    if text:
+        objs.append(dict(content=text))
+
+    return objs
+
+#########
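
Reviewer note: a small usage sketch of the parser added above, not part of the commit. It assumes the vae-oid.npz checkpoint sits next to the module (decoding the <seg> tokens loads it), and the token string below is made up for illustration:

    from helpers.segment_utils import extract_objs

    # Illustrative PaliGemma-style output: 4 <loc> tokens (box), 16 <seg> tokens (mask codes), a label.
    text = "<loc0100><loc0200><loc0800><loc0900>" + "<seg063>" * 16 + " person"
    objs = extract_objs(text, width=640, height=480, unique_labels=True)
    for obj in objs:
        # Each entry may carry 'content', 'xyxy' (pixel box), 'mask' (height x width floats) and 'name'.
        print(obj.get("name"), obj.get("xyxy"), None if obj.get("mask") is None else obj["mask"].shape)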
requirements.txt CHANGED
@@ -3,4 +3,6 @@ transformers==4.47.0
 requests
 tqdm
 spaces
-torch
+torch
+jax
+flax
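
Reviewer note: jax and flax are pulled in because the mask decoder in helpers/segment_utils.py is a Flax module compiled with jax.jit on CPU. A quick smoke-test sketch, not part of the commit, assuming vae-oid.npz has been downloaded into the working directory:

    import numpy as np
    from helpers.segment_utils import _get_reconstruct_masks

    codes = np.zeros((1, 16), dtype=np.int32)   # 16 <seg> token indices, each in [0, 127]
    masks = _get_reconstruct_masks()(codes)     # float32 array shaped [1, 64, 64, 1], values in [-1, 1]
    print(masks.shape)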