dwb2023 committed
Commit a78c4d2
1 Parent(s): 839e917

Update app.py

Files changed (1)
  1. app.py +263 -46

app.py CHANGED
@@ -1,24 +1,89 @@
  import functools
  import re
  import PIL.Image
  import gradio as gr
  import jax
  import jax.numpy as jnp
- import numpy as np
  import flax.linen as nn
- from inference import PaliGemmaModel, VAEModel

- COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

- # Instantiate the models
- pali_gemma_model = PaliGemmaModel()
- vae_model = VAEModel('vae-oid.npz')

- ##### Parse segmentation output tokens into masks
- ##### Also returns bounding boxes with their labels

- def parse_segmentation(input_image, input_text, max_new_tokens=100):
- out = pali_gemma_model.infer(image=input_image, text=input_text, max_new_tokens=max_new_tokens)
  objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
  labels = set(obj.get('name') for obj in objs if obj.get('name'))
  color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
@@ -37,14 +102,166 @@ def parse_segmentation(input_image, input_text, max_new_tokens=100):
  has_annotations = bool(annotated_img[1])
  return annotated_img

- INTRO_TEXT = "🔬🧠 CellVision AI -- Intelligent Cell Imaging Analysis 🤖🧫"
- IMAGE_PROMPT = """
- Describe the morphological characteristics and visible interactions between different cell types.
- Assess the biological context to identify signs of cancer and the presence of antigens.
- """

  def extract_objs(text, width, height, unique_labels=False):
- """Returns objs for a string with "<loc>" and "<seg>" tokens."""
  objs = []
  seen = set()
  while text:
@@ -56,14 +273,14 @@ def extract_objs(text, width, height, unique_labels=False):
  before = gs.pop(0)
  name = gs.pop()
  y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
-
- y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
  seg_indices = gs[4:20]
  if seg_indices[0] is None:
  mask = None
  else:
  seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
- m64, = vae_model.reconstruct_masks(seg_indices[None])[..., 0]
  m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
  m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
  mask = np.zeros([height, width])
@@ -86,12 +303,13 @@ def extract_objs(text, width, height, unique_labels=False):

  return objs

- _SEGMENT_DETECT_RE = re.compile(
- r'(.*?)' +
- r'<loc(\d{4})>' * 4 + r'\s*' +
- '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
- r'\s*([^;<>]+)? ?(?:; )?',
- )

  with gr.Blocks(css="style.css") as demo:
  gr.Markdown(INTRO_TEXT)
@@ -100,27 +318,25 @@ with gr.Blocks(css="style.css") as demo:
  with gr.Column():
  image = gr.Image(type="pil")
  seg_input = gr.Text(label="Entities to Segment/Detect")
-
  with gr.Column():
  annotated_image = gr.AnnotatedImage(label="Output")

- seg_btn = gr.Button("Submit")
- examples = [
- ["./examples/cart1.jpg", "segment cells"],
- ["./examples/cart1.jpg", "detect cells"],
- ["./examples/cart2.jpg", "segment cells"],
- ["./examples/cart2.jpg", "detect cells"],
- ["./examples/cart3.jpg", "segment cells"],
- ["./examples/cart3.jpg", "detect cells"]
- ]
  gr.Examples(
  examples=examples,
  inputs=[image, seg_input],
  )
  seg_inputs = [
  image,
- seg_input,
- ]
  seg_outputs = [
  annotated_image
  ]
@@ -133,6 +349,7 @@ with gr.Blocks(css="style.css") as demo:
  with gr.Column():
  image = gr.Image(type="pil")
  text_input = gr.Text(label="Input Text")
  text_output = gr.Text(label="Text Output")
  chat_btn = gr.Button()
  tokens = gr.Slider(
@@ -148,25 +365,25 @@ with gr.Blocks(css="style.css") as demo:
  image,
  text_input,
  tokens
- ]
  chat_outputs = [
  text_output
  ]
  chat_btn.click(
- fn=pali_gemma_model.infer,
  inputs=chat_inputs,
  outputs=chat_outputs,
  )
-
- examples = [
- ["./examples/cart1.jpg", IMAGE_PROMPT],
- ["./examples/cart2.jpg", IMAGE_PROMPT],
- ["./examples/cart3.jpg", IMAGE_PROMPT]
- ]
  gr.Examples(
  examples=examples,
  inputs=chat_inputs,
  )

  if __name__ == "__main__":
- demo.queue(max_size=10).launch(debug=True)
 
+ """
+ CellVision AI - Intelligent Cell Imaging Analysis
+
+ This module provides a Gradio web application for performing intelligent cell imaging analysis
+ using the PaliGemma model from Google. The app allows users to segment or detect cells in images
+ and generate descriptive text based on the input image and prompt.
+
+ Dependencies:
+ - gradio
+ - transformers
+ - torch
+ - jax
+ - flax
+ - spaces
+ - PIL
+ - numpy
+ - huggingface_hub
+
+ """
+
+ import os
  import functools
  import re
+
  import PIL.Image
  import gradio as gr
+ import numpy as np
+ import torch
  import jax
  import jax.numpy as jnp
  import flax.linen as nn

+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
+ from huggingface_hub import login
+ import spaces
+
+ # Perform login using the token
+ hf_token = os.getenv("HF_TOKEN")
+ login(token=hf_token, add_to_git_credential=True)

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ model_id = "google/paligemma-3b-mix-448"
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+ @spaces.GPU
+ def infer(
+ image: PIL.Image.Image,
+ text: str,
+ max_new_tokens: int
+ ) -> str:
+ """
+ Perform inference using the PaliGemma model.
+
+ Args:
+ image (PIL.Image.Image): Input image.
+ text (str): Input text prompt.
+ max_new_tokens (int): Maximum number of new tokens to generate.
+
+ Returns:
+ str: Generated text based on the input image and prompt.
+ """
+ inputs = processor(text=text, images=image, return_tensors="pt").to(device)
+ with torch.inference_mode():
+ generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=max_new_tokens,
+ do_sample=False
+ )
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)
+ return result[0][len(text):].lstrip("\n")
+
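As an aside (not part of the diff): a minimal sketch of calling infer() directly, assuming the weights downloaded successfully; the prompt is only an example, and the image is one bundled with the Space.

sample = PIL.Image.open("./examples/cart1.jpg")  # example image shipped with the repo
print(infer(sample, "describe the image", max_new_tokens=50))  # greedy decoding; the prompt is stripped from the output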
+ def parse_segmentation(input_image, input_text):
+ """
+ Parse segmentation output tokens into masks and bounding boxes.
+
+ Args:
+ input_image (PIL.Image.Image): Input image.
+ input_text (str): Input text specifying entities to segment or detect.
+
+ Returns:
+ tuple: A tuple containing the annotated image and a boolean indicating if annotations are present.
+ """
+ out = infer(input_image, input_text, max_new_tokens=100)
  objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
  labels = set(obj.get('name') for obj in objs if obj.get('name'))
  color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}

  has_annotations = bool(annotated_img[1])
  return annotated_img

+
+ ### Postprocessing Utils for Segmentation Tokens
+
+ _MODEL_PATH = 'vae-oid.npz'
+
+ _SEGMENT_DETECT_RE = re.compile(
+ r'(.*?)' +
+ r'<loc(\d{4})>' * 4 + r'\s*' +
+ '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+ r'\s*([^;<>]+)? ?(?:; )?',
+ )
+
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+
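For orientation (not part of the diff), here is an illustrative, made-up model answer and how _SEGMENT_DETECT_RE splits it; real answers come from infer() above.

sample_out = "<loc0100><loc0200><loc0800><loc0900>" + "<seg010>" * 16 + " cell; "  # hypothetical output
groups = _SEGMENT_DETECT_RE.match(sample_out).groups()
# groups[0]    -> any text before the box ("" here)
# groups[1:5]  -> y1, x1, y2, x2 as strings in 0..1023; extract_objs divides by 1024 and scales to pixel coordinates
# groups[5:21] -> the sixteen <seg> codebook indices, or None when the model only detects
# groups[21]   -> the label, "cell"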
+ def _get_params(checkpoint):
+ """
+ Convert PyTorch checkpoint to Flax params.
+
+ Args:
+ checkpoint (dict): PyTorch checkpoint dictionary.
+
+ Returns:
+ dict: Flax parameters.
+ """
+ def transp(kernel):
+ return np.transpose(kernel, (2, 3, 1, 0))
+
+ def conv(name):
+ return {
+ 'bias': checkpoint[name + '.bias'],
+ 'kernel': transp(checkpoint[name + '.weight']),
+ }
+
+ def resblock(name):
+ return {
+ 'Conv_0': conv(name + '.0'),
+ 'Conv_1': conv(name + '.2'),
+ 'Conv_2': conv(name + '.4'),
+ }
+
+ return {
+ '_embeddings': checkpoint['_vq_vae._embedding'],
+ 'Conv_0': conv('decoder.0'),
+ 'ResBlock_0': resblock('decoder.2.net'),
+ 'ResBlock_1': resblock('decoder.3.net'),
+ 'ConvTranspose_0': conv('decoder.4'),
+ 'ConvTranspose_1': conv('decoder.6'),
+ 'ConvTranspose_2': conv('decoder.8'),
+ 'ConvTranspose_3': conv('decoder.10'),
+ 'Conv_1': conv('decoder.12'),
+ }
+
+
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+ """
+ Get quantized values from codebook indices.
+
+ Args:
+ codebook_indices (jax.numpy.ndarray): Codebook indices.
+ embeddings (jax.numpy.ndarray): Embeddings.
+
+ Returns:
+ jax.numpy.ndarray: Quantized values.
+ """
+ batch_size, num_tokens = codebook_indices.shape
+ assert num_tokens == 16, codebook_indices.shape
+ unused_num_embeddings, embedding_dim = embeddings.shape
+
+ encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+ encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+ return encodings
+
+
+ @functools.cache
+ def _get_reconstruct_masks():
+ """
+ Reconstruct masks from codebook indices.
+
+ Returns:
+ function: A function that expects indices shaped `[B, 16]` of dtype int32, each
+ ranging from 0 to 127 (inclusive), and returns decoded masks sized
+ `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
+ """
+
+ class ResBlock(nn.Module):
+ features: int
+
+ @nn.compact
+ def __call__(self, x):
+ original_x = x
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+ x = nn.relu(x)
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+ x = nn.relu(x)
+ x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+ return x + original_x
+
+ class Decoder(nn.Module):
+ """Upscales quantized vectors to mask."""
+
+ @nn.compact
+ def __call__(self, x):
+ num_res_blocks = 2
+ dim = 128
+ num_upsample_layers = 4
+
+ x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+ x = nn.relu(x)
+
+ for _ in range(num_res_blocks):
+ x = ResBlock(features=dim)(x)
+
+ for _ in range(num_upsample_layers):
+ x = nn.ConvTranspose(
+ features=dim,
+ kernel_size=(4, 4),
+ strides=(2, 2),
+ padding=2,
+ transpose_kernel=True,
+ )(x)
+ x = nn.relu(x)
+ dim //= 2
+
+ x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+ return x
+
+ def reconstruct_masks(codebook_indices):
+ """
+ Reconstruct masks from codebook indices.
+
+ Args:
+ codebook_indices (jax.numpy.ndarray): Codebook indices.
+
+ Returns:
+ jax.numpy.ndarray: Reconstructed masks.
+ """
+ quantized = _quantized_values_from_codebook_indices(
+ codebook_indices, params['_embeddings']
+ )
+ return Decoder().apply({'params': params}, quantized)
+
+ with open(_MODEL_PATH, 'rb') as f:
+ params = _get_params(dict(np.load(f)))
+
+ return jax.jit(reconstruct_masks, backend='cpu')
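A small sketch (not in the diff) of how the cached decoder returned above is called, following its docstring; it assumes vae-oid.npz is present so the checkpoint can be loaded.

codes = np.zeros((1, 16), dtype=np.int32)  # 16 codebook indices per mask, values 0..127
masks = _get_reconstruct_masks()(codes)    # -> array of shape [1, 64, 64, 1], float32 in [-1, 1]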
 
  def extract_objs(text, width, height, unique_labels=False):
+ """
+ Extract objects from text containing "<loc>" and "<seg>" tokens.
+
+ Args:
+ text (str): Input text containing "<loc>" and "<seg>" tokens.
+ width (int): Width of the image.
+ height (int): Height of the image.
+ unique_labels (bool, optional): Whether to enforce unique labels. Defaults to False.
+
+ Returns:
+ list: List of extracted objects.
+ """
  objs = []
  seen = set()
  while text:

  before = gs.pop(0)
  name = gs.pop()
  y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+ y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
  seg_indices = gs[4:20]
  if seg_indices[0] is None:
  mask = None
  else:
  seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+ m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
  m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
  m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
  mask = np.zeros([height, width])

  return objs

+ #########
+
+ INTRO_TEXT="🔬🧠 CellVision AI -- Intelligent Cell Imaging Analysis 🤖🧫"
+ IMAGE_PROMPT="""
+ Describe the morphological characteristics and visible interactions between different cell types.
+ Assess the biological context to identify signs of cancer and the presence of antigens.
+ """

  with gr.Blocks(css="style.css") as demo:
  gr.Markdown(INTRO_TEXT)

  with gr.Column():
  image = gr.Image(type="pil")
  seg_input = gr.Text(label="Entities to Segment/Detect")
+
  with gr.Column():
  annotated_image = gr.AnnotatedImage(label="Output")

+ seg_btn = gr.Button("Submit")
+ examples = [["./examples/cart1.jpg", "segment cells"],
+ ["./examples/cart1.jpg", "detect cells"],
+ ["./examples/cart2.jpg", "segment cells"],
+ ["./examples/cart2.jpg", "detect cells"],
+ ["./examples/cart3.jpg", "segment cells"],
+ ["./examples/cart3.jpg", "detect cells"]]
  gr.Examples(
  examples=examples,
  inputs=[image, seg_input],
  )
  seg_inputs = [
  image,
+ seg_input
+ ]
  seg_outputs = [
  annotated_image
  ]

  with gr.Column():
  image = gr.Image(type="pil")
  text_input = gr.Text(label="Input Text")
+
  text_output = gr.Text(label="Text Output")
  chat_btn = gr.Button()
  tokens = gr.Slider(

  image,
  text_input,
  tokens
+ ]
  chat_outputs = [
  text_output
  ]
  chat_btn.click(
+ fn=infer,
  inputs=chat_inputs,
  outputs=chat_outputs,
  )
+
+ examples = [["./examples/cart1.jpg", IMAGE_PROMPT],
+ ["./examples/cart2.jpg", IMAGE_PROMPT],
+ ["./examples/cart3.jpg", IMAGE_PROMPT]]
  gr.Examples(
  examples=examples,
  inputs=chat_inputs,
  )

+ #########
+
  if __name__ == "__main__":
+ demo.queue(max_size=10).launch(debug=True)
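For context, a minimal smoke test of the two entry points this commit adds; the local import is hypothetical and assumes HF_TOKEN is set in the environment and vae-oid.npz sits next to app.py.

import PIL.Image
from app import infer, parse_segmentation  # hypothetical import of this module; loads the model at import time

img = PIL.Image.open("./examples/cart1.jpg")
print(infer(img, "detect cells", max_new_tokens=100))
annotated = parse_segmentation(img, "segment cells")  # annotated image for gr.AnnotatedImage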