dwb2023 committed
Commit a2e1737
1 Parent(s): 25b653f

Update app.py

Files changed (1)
  1. app.py +219 -221
app.py CHANGED
@@ -1,41 +1,40 @@
 import functools
 import re
-
 import PIL.Image
 import gradio as gr
 import jax
 import jax.numpy as jnp
 import numpy as np
-
 import flax.linen as nn
 from inference import PaliGemmaModel

- pali_gemma_model = PaliGemmaModel() # Create an instance of the model
-
 COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

 ##### Parse segmentation output tokens into masks
 ##### Also returns bounding boxes with their labels

 def parse_segmentation(input_image, input_text, max_new_tokens=100):
- out = pali_gemma_model.infer(image=input_image, text=input_text, max_new_tokens=max_new_tokens)
- objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
- labels = set(obj.get('name') for obj in objs if obj.get('name'))
- color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
- highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
- annotated_img = (
- input_image,
- [
- (
- obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
- obj['name'] or '',
- )
- for obj in objs
- if 'mask' in obj or 'xyxy' in obj
- ],
- )
- has_annotations = bool(annotated_img[1])
- return annotated_img

 INTRO_TEXT="🔬🧠 CellVision AI -- Intelligent Cell Imaging Analysis 🤖🧫"
 IMAGE_PROMPT="""
@@ -44,77 +43,79 @@ Assess the biological context to identify signs of cancer and the presence of an
 """

 with gr.Blocks(css="style.css") as demo:
- gr.Markdown(INTRO_TEXT)
- with gr.Tab("Segment/Detect"):
- with gr.Row():
- with gr.Column():
- image = gr.Image(type="pil")
- seg_input = gr.Text(label="Entities to Segment/Detect")
-
 with gr.Column():
- annotated_image = gr.AnnotatedImage(label="Output")
-
- seg_btn = gr.Button("Submit")
- examples = [["./examples/cart1.jpg", "segment cells"],
- ["./examples/cart1.jpg", "detect cells"],
- ["./examples/cart2.jpg", "segment cells"],
- ["./examples/cart2.jpg", "detect cells"],
- ["./examples/cart3.jpg", "segment cells"],
- ["./examples/cart3.jpg", "detect cells"]]
- gr.Examples(
- examples=examples,
- inputs=[image, seg_input],
- )
- seg_inputs = [
- image,
- seg_input,
 ]
- seg_outputs = [
- annotated_image
- ]
- seg_btn.click(
- fn=parse_segmentation,
- inputs=seg_inputs,
- outputs=seg_outputs,
- )
- with gr.Tab("Text Generation"):
- with gr.Column():
- image = gr.Image(type="pil")
- text_input = gr.Text(label="Input Text")
-
- text_output = gr.Text(label="Text Output")
- chat_btn = gr.Button()
- tokens = gr.Slider(
- label="Max New Tokens",
- info="Set to larger for longer generation.",
- minimum=10,
- maximum=100,
- value=50,
- step=10,
 )

- chat_inputs = [
- image,
- text_input,
- tokens
 ]
- chat_outputs = [
- text_output
- ]
- chat_btn.click(
- fn=pali_gemma_model.infer,
- inputs=chat_inputs,
- outputs=chat_outputs,
- )
-
- examples = [["./examples/cart1.jpg", IMAGE_PROMPT],
- ["./examples/cart2.jpg", IMAGE_PROMPT],
- ["./examples/cart3.jpg", IMAGE_PROMPT]]
- gr.Examples(
- examples=examples,
- inputs=chat_inputs,
- )
-

 ### Postprocessing Utils for Segmentation Tokens
 ### Segmentation tokens are passed to another VAE which decodes them to a mask
@@ -128,156 +129,153 @@ _SEGMENT_DETECT_RE = re.compile(
 r'\s*([^;<>]+)? ?(?:; )?',
 )

-
 def _get_params(checkpoint):
- """Converts PyTorch checkpoint to Flax params."""

- def transp(kernel):
- return np.transpose(kernel, (2, 3, 1, 0))

- def conv(name):
- return {
- 'bias': checkpoint[name + '.bias'],
- 'kernel': transp(checkpoint[name + '.weight']),
- }

- def resblock(name):
 return {
- 'Conv_0': conv(name + '.0'),
- 'Conv_1': conv(name + '.2'),
- 'Conv_2': conv(name + '.4'),
 }

- return {
- '_embeddings': checkpoint['_vq_vae._embedding'],
- 'Conv_0': conv('decoder.0'),
- 'ResBlock_0': resblock('decoder.2.net'),
- 'ResBlock_1': resblock('decoder.3.net'),
- 'ConvTranspose_0': conv('decoder.4'),
- 'ConvTranspose_1': conv('decoder.6'),
- 'ConvTranspose_2': conv('decoder.8'),
- 'ConvTranspose_3': conv('decoder.10'),
- 'Conv_1': conv('decoder.12'),
- }
-
-
 def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
- batch_size, num_tokens = codebook_indices.shape
- assert num_tokens == 16, codebook_indices.shape
- unused_num_embeddings, embedding_dim = embeddings.shape
-
- encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
- encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
- return encodings

 @functools.cache
 def _get_reconstruct_masks():
- """Reconstructs masks from codebook indices.
- Returns:
- A function that expects indices shaped `[B, 16]` of dtype int32, each
- ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
- `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
- """
-
- class ResBlock(nn.Module):
- features: int
-
- @nn.compact
- def __call__(self, x):
- original_x = x
- x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
- x = nn.relu(x)
- x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
- x = nn.relu(x)
- x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
- return x + original_x
-
- class Decoder(nn.Module):
- """Upscales quantized vectors to mask."""
-
- @nn.compact
- def __call__(self, x):
- num_res_blocks = 2
- dim = 128
- num_upsample_layers = 4
-
- x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
- x = nn.relu(x)
-
- for _ in range(num_res_blocks):
- x = ResBlock(features=dim)(x)
-
- for _ in range(num_upsample_layers):
- x = nn.ConvTranspose(
- features=dim,
- kernel_size=(4, 4),
- strides=(2, 2),
- padding=2,
- transpose_kernel=True,
- )(x)
- x = nn.relu(x)
- dim //= 2
-
- x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
-
- return x
-
- def reconstruct_masks(codebook_indices):
- quantized = _quantized_values_from_codebook_indices(
- codebook_indices, params['_embeddings']
- )
- return Decoder().apply({'params': params}, quantized)

- with open(_MODEL_PATH, 'rb') as f:
- params = _get_params(dict(np.load(f)))

- return jax.jit(reconstruct_masks, backend='cpu')

 def extract_objs(text, width, height, unique_labels=False):
- """Returns objs for a string with "<loc>" and "<seg>" tokens."""
- objs = []
- seen = set()
- while text:
- m = _SEGMENT_DETECT_RE.match(text)
- if not m:
- break
- print("m", m)
- gs = list(m.groups())
- before = gs.pop(0)
- name = gs.pop()
- y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
-
- y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
- seg_indices = gs[4:20]
- if seg_indices[0] is None:
- mask = None
- else:
- seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
- m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
- m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
- m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
- mask = np.zeros([height, width])
- if y2 > y1 and x2 > x1:
- mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
-
- content = m.group()
- if before:
- objs.append(dict(content=before))
- content = content[len(before):]
- while unique_labels and name in seen:
- name = (name or '') + "'"
- seen.add(name)
- objs.append(dict(
- content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
- text = text[len(before) + len(content):]
-
- if text:
- objs.append(dict(content=text))
-
- return objs

 #########

 if __name__ == "__main__":
- demo.queue(max_size=10).launch(debug=True)
 
Updated app.py (added lines marked with +):

 import functools
 import re
 import PIL.Image
 import gradio as gr
 import jax
 import jax.numpy as jnp
 import numpy as np
 import flax.linen as nn
 from inference import PaliGemmaModel

 COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

+ # Instantiate the model
+ pali_gemma_model = PaliGemmaModel()
+
 ##### Parse segmentation output tokens into masks
 ##### Also returns bounding boxes with their labels

 def parse_segmentation(input_image, input_text, max_new_tokens=100):
+ out = pali_gemma_model.infer(image=input_image, text=input_text, max_new_tokens=max_new_tokens)
+ objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+ labels = set(obj.get('name') for obj in objs if obj.get('name'))
+ color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+ highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+ annotated_img = (
+ input_image,
+ [
+ (
+ obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+ obj['name'] or '',
+ )
+ for obj in objs
+ if 'mask' in obj or 'xyxy' in obj
+ ],
+ )
+ has_annotations = bool(annotated_img[1])
+ return annotated_img

 INTRO_TEXT="🔬🧠 CellVision AI -- Intelligent Cell Imaging Analysis 🤖🧫"
 IMAGE_PROMPT="""
 """

 with gr.Blocks(css="style.css") as demo:
+ gr.Markdown(INTRO_TEXT)
+ with gr.Tab("Segment/Detect"):
+ with gr.Row():
+ with gr.Column():
+ image = gr.Image(type="pil")
+ seg_input = gr.Text(label="Entities to Segment/Detect")
+
+ with gr.Column():
+ annotated_image = gr.AnnotatedImage(label="Output")
+
+ seg_btn = gr.Button("Submit")
+ examples = [
+ ["./examples/cart1.jpg", "segment cells"],
+ ["./examples/cart1.jpg", "detect cells"],
+ ["./examples/cart2.jpg", "segment cells"],
+ ["./examples/cart2.jpg", "detect cells"],
+ ["./examples/cart3.jpg", "segment cells"],
+ ["./examples/cart3.jpg", "detect cells"]
+ ]
+ gr.Examples(
+ examples=examples,
+ inputs=[image, seg_input],
+ )
+ seg_inputs = [
+ image,
+ seg_input,
+ ]
+ seg_outputs = [
+ annotated_image
+ ]
+ seg_btn.click(
+ fn=parse_segmentation,
+ inputs=seg_inputs,
+ outputs=seg_outputs,
+ )
+ with gr.Tab("Text Generation"):
 with gr.Column():
+ image = gr.Image(type="pil")
+ text_input = gr.Text(label="Input Text")
+ text_output = gr.Text(label="Text Output")
+ chat_btn = gr.Button()
+ tokens = gr.Slider(
+ label="Max New Tokens",
+ info="Set to larger for longer generation.",
+ minimum=10,
+ maximum=100,
+ value=50,
+ step=10,
+ )
+
+ chat_inputs = [
+ image,
+ text_input,
+ tokens
 ]
+ chat_outputs = [
+ text_output
+ ]
+ chat_btn.click(
+ fn=pali_gemma_model.infer,
+ inputs=chat_inputs,
+ outputs=chat_outputs,
 )

+ examples = [
+ ["./examples/cart1.jpg", IMAGE_PROMPT],
+ ["./examples/cart2.jpg", IMAGE_PROMPT],
+ ["./examples/cart3.jpg", IMAGE_PROMPT]
 ]
+ gr.Examples(
+ examples=examples,
+ inputs=chat_inputs,
+ )

 ### Postprocessing Utils for Segmentation Tokens
 ### Segmentation tokens are passed to another VAE which decodes them to a mask
 r'\s*([^;<>]+)? ?(?:; )?',
 )

 def _get_params(checkpoint):
+ """Converts PyTorch checkpoint to Flax params."""

+ def transp(kernel):
+ return np.transpose(kernel, (2, 3, 1, 0))

+ def conv(name):
+ return {
+ 'bias': checkpoint[name + '.bias'],
+ 'kernel': transp(checkpoint[name + '.weight']),
+ }
+
+ def resblock(name):
+ return {
+ 'Conv_0': conv(name + '.0'),
+ 'Conv_1': conv(name + '.2'),
+ 'Conv_2': conv(name + '.4'),
+ }

 return {
+ '_embeddings': checkpoint['_vq_vae._embedding'],
+ 'Conv_0': conv('decoder.0'),
+ 'ResBlock_0': resblock('decoder.2.net'),
+ 'ResBlock_1': resblock('decoder.3.net'),
+ 'ConvTranspose_0': conv('decoder.4'),
+ 'ConvTranspose_1': conv('decoder.6'),
+ 'ConvTranspose_2': conv('decoder.8'),
+ 'ConvTranspose_3': conv('decoder.10'),
+ 'Conv_1': conv('decoder.12'),
 }

 def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+ batch_size, num_tokens = codebook_indices.shape
+ assert num_tokens == 16, codebook_indices.shape
+ unused_num_embeddings, embedding_dim = embeddings.shape

+ encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+ encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+ return encodings

 @functools.cache
 def _get_reconstruct_masks():
+ """Reconstructs masks from codebook indices.
+ Returns:
+ A function that expects indices shaped `[B, 16]` of dtype int32, each
+ ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
+ `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
+ """
+
+ class ResBlock(nn.Module):
+ features: int
+
+ @nn.compact
+ def __call__(self, x):
+ original_x = x
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+ x = nn.relu(x)
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+ x = nn.relu(x)
+ x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+ return x + original_x
+
+ class Decoder(nn.Module):
+ """Upscales quantized vectors to mask."""
+
+ @nn.compact
+ def __call__(self, x):
+ num_res_blocks = 2
+ dim = 128
+ num_upsample_layers = 4
+
+ x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+ x = nn.relu(x)
+
+ for _ in range(num_res_blocks):
+ x = ResBlock(features=dim)(x)
+
+ for _ in range(num_upsample_layers):
+ x = nn.ConvTranspose(
+ features=dim,
+ kernel_size=(4, 4),
+ strides=(2, 2),
+ padding=2,
+ transpose_kernel=True,
+ )(x)
+ x = nn.relu(x)
+ dim //= 2
+
+ x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+ return x
+
+ def reconstruct_masks(codebook_indices):
+ quantized = _quantized_values_from_codebook_indices(
+ codebook_indices, params['_embeddings']
+ )
+ return Decoder().apply({'params': params}, quantized)

+ with open(_MODEL_PATH, 'rb') as f:
+ params = _get_params(dict(np.load(f)))

+ return jax.jit(reconstruct_masks, backend='cpu')

 def extract_objs(text, width, height, unique_labels=False):
+ """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+ objs = []
+ seen = set()
+ while text:
+ m = _SEGMENT_DETECT_RE.match(text)
+ if not m:
+ break
+ print("m", m)
+ gs = list(m.groups())
+ before = gs.pop(0)
+ name = gs.pop()
+ y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+ y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+ seg_indices = gs[4:20]
+ if seg_indices[0] is None:
+ mask = None
+ else:
+ seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+ m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+ m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+ m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+ mask = np.zeros([height, width])
+ if y2 > y1 and x2 > x1:
+ mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+ content = m.group()
+ if before:
+ objs.append(dict(content=before))
+ content = content[len(before):]
+ while unique_labels and name in seen:
+ name = (name or '') + "'"
+ seen.add(name)
+ objs.append(dict(
+ content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+ text = text[len(before) + len(content):]
+
+ if text:
+ objs.append(dict(content=text))
+
+ return objs

 #########

 if __name__ == "__main__":
+ demo.queue(max_size=10).launch(debug=True)
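
For quick local testing, a minimal sketch of driving the segmentation path directly, outside the Gradio UI. This is not part of the commit: it assumes the inference.PaliGemmaModel dependency loads and the ./examples images exist locally, and it uses only names defined in app.py.

# Hypothetical usage sketch (not part of app.py): call parse_segmentation directly.
# Importing app also instantiates PaliGemmaModel and builds the Gradio Blocks.
import PIL.Image
from app import parse_segmentation

img = PIL.Image.open("./examples/cart1.jpg")
# Returns (base_image, annotations) in the format gr.AnnotatedImage expects.
base_image, annotations = parse_segmentation(img, "segment cells", max_new_tokens=100)
for region, label in annotations:
    # region is either a [height, width] mask array or an (x1, y1, x2, y2) box.
    print(label, getattr(region, "shape", region))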