Manli committed on
Commit d982555
1 Parent(s): 6a1d694

update preprocessing

Files changed (3):
  1. demo.ipynb +3 -5
  2. image_processing_blip_3.py +3 -18
  3. vlm.py +1 -126
demo.ipynb CHANGED
@@ -253,10 +253,10 @@
     "    for fn in sample['image_path']:\n",
     "        img = PIL.Image.open(fn)\n",
     "        display.display(Image(filename=fn, width=300))\n",
-    "        image_list.append(image_processor([img], image_aspect_ratio='anyres')[\"pixel_values\"])\n",
+    "        image_list.append(image_processor([img], image_aspect_ratio='anyres')[\"pixel_values\"].cuda())\n",
     "        image_sizes.append(img.size)\n",
     "    inputs = {\n",
-    "        \"pixel_values\": image_list\n",
+    "        \"pixel_values\": [image_list]\n",
     "    }\n",
     "    for query in sample['question']:\n",
     "        prompt = apply_prompt_template(query)\n",
@@ -266,9 +266,7 @@
     "    for name, value in inputs.items():\n",
     "        if isinstance(value, torch.Tensor):\n",
     "            inputs[name] = value.cuda()\n",
-    "        else:\n",
-    "            inputs[name] = [v.cuda() for v in value]\n",
-    "    generated_text = model.generate(**inputs, image_size=image_sizes,\n",
+    "    generated_text = model.generate(**inputs, image_size=[image_sizes],\n",
     "                                    pad_token_id=tokenizer.pad_token_id,\n",
     "                                    do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1,\n",
     "                                    )\n",
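Taken together, the demo.ipynb hunks above move the .cuda() call into the preprocessing loop and nest pixel_values and image_size one level deeper (batch -> images). A minimal sketch of how the edited cell reads after this commit follows; model, tokenizer, image_processor, apply_prompt_template, and sample are assumed to be defined in earlier notebook cells, and the tokenization step between the two hunks is reconstructed from context rather than shown in this diff.

# Sketch only; the names below come from earlier notebook cells, not from this diff.
import PIL.Image
import torch
from IPython import display
from IPython.display import Image

image_list, image_sizes = [], []
for fn in sample['image_path']:
    img = PIL.Image.open(fn)
    display.display(Image(filename=fn, width=300))
    # Preprocess each image into its 'anyres' patch tensor and move it to the GPU
    # right away, so the old per-list .cuda() branch is no longer needed.
    image_list.append(image_processor([img], image_aspect_ratio='anyres')["pixel_values"].cuda())
    image_sizes.append(img.size)

# pixel_values is now nested per sample: batch -> images -> patch tensors.
inputs = {"pixel_values": [image_list]}
for query in sample['question']:
    prompt = apply_prompt_template(query)
    # Assumed tokenization step from the original notebook (not part of these hunks).
    inputs.update(tokenizer([prompt], return_tensors="pt"))
    for name, value in inputs.items():
        if isinstance(value, torch.Tensor):
            inputs[name] = value.cuda()
    # image_size is wrapped in a list to mirror the nested pixel_values structure.
    generated_text = model.generate(**inputs, image_size=[image_sizes],
                                    pad_token_id=tokenizer.pad_token_id,
                                    do_sample=False, max_new_tokens=1024,
                                    top_p=None, num_beams=1)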
image_processing_blip_3.py CHANGED
@@ -109,26 +109,11 @@ class Blip3ImageProcessor(BaseImageProcessor):
 
         if all(x.shape == new_images[0].shape for x in new_images):
             new_images = torch.stack(new_images, dim=0)
-        if image_aspect_ratio == 'pad':
-            new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0).unsqueeze(0)}, tensor_type=return_tensors)
-        else:
+        if image_aspect_ratio == 'anyres':
             new_images = BatchFeature(data={"pixel_values": new_images}, tensor_type=return_tensors)
+        else:
+            new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)}, tensor_type=return_tensors)
         return new_images
-    # def preprocess(self,
-    #                images: ImageInput,
-    #                return_tensors: Optional[Union[str, TensorType]] = None,
-    #                **kwargs) -> BatchFeature:
-    #     transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
-    #     transforms.extend([
-    #         self.convert_rgb,
-    #         ToTensor(),
-    #         Normalize(mean=self.image_mean, std=self.image_std)
-    #     ])
-    #     composed_transforms = Compose(transforms)
-    #     images_tensor = composed_transforms(images).unsqueeze(0).unsqueeze(1).unsqueeze(0)
-    #     encoded_outputs = BatchFeature(data={"pixel_values": images_tensor}, tensor_type=return_tensors)
-    #     return encoded_outputs
-
 
 class ResizeKeepRatio:
     """ Resize and Keep Ratio
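The image_processing_blip_3.py change flips the branching so that 'anyres' returns the stacked patch tensor unchanged, while every other aspect-ratio mode now gains the extra frame and batch dimensions via unsqueeze(1).unsqueeze(0) (previously unsqueeze(0).unsqueeze(0) under 'pad'). A standalone sketch of the resulting shapes, using a dummy tensor and an illustrative 384x384 resolution rather than real preprocessed images:

# Shape sketch only; the dummy tensor and sizes are illustrative, not from the diff.
import torch
from transformers import BatchFeature

new_images = torch.randn(5, 3, 384, 384)  # N=5 stacked crops, [N, C, H, W]

# 'anyres': pixel_values keeps the flat patch stack -> [N, C, H, W]
anyres = BatchFeature(data={"pixel_values": new_images}, tensor_type="pt")
print(anyres["pixel_values"].shape)   # torch.Size([5, 3, 384, 384])

# other modes: insert frame and batch dims -> [1, N, 1, C, H, W]
other = BatchFeature(data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)}, tensor_type="pt")
print(other["pixel_values"].shape)    # torch.Size([1, 5, 1, 3, 384, 384])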
vlm.py CHANGED
@@ -1043,10 +1043,6 @@ class VLMWithLanguageStream(VLM):
                 multimodal_labels.append(labels[i].clone())
                 continue
 
-            # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
-            for j, img_idx in enumerate(image_token_idxs):
-                image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j # FIXME: different offset for any resolution encoding when has multiple images.
-
             # loop through the image_token_idxs and insert the vision tokens
             new_embed = lang_embeds[i].clone()
             new_attention_mask = (
@@ -1056,9 +1052,6 @@
             new_label = labels[i].clone()
 
             for img_num, img_idx in enumerate(image_token_idxs):
-                if img_num > 0:
-                    # FIXME: hardcoded as such to avoid assertion error, but this only works for single image samples.
-                    break
                 # Get vision token attention mask for padded llava-style any resolution image tokens.
                 if self.image_aspect_ratio =='anyres':
                     num_vis_tokens = vision_tokens[i][img_num].shape[0]
@@ -1078,7 +1071,6 @@
                     vis_attention_mask = torch.ones(
                         num_vis_tokens, dtype=torch.long
                     ).to(attention_mask.device)
-
 
                 new_embed = torch.cat(
                     (
@@ -1275,123 +1267,6 @@ class XGenMMPerceiver(VLMWithLanguageStream):
         """
         return True
 
-    def forward(
-        self,
-        vision_x: Optional[torch.Tensor],
-        lang_x: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        image_size: Optional[Tuple] = None,
-        past_key_values: Optional[
-            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
-        ] = None,
-        past_media_locations: Optional[torch.Tensor] = None,
-        past_vision_tokens: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = False,
-        **kwargs,
-    ):
-        """
-        Args:
-            vision_x: Vision input
-                shape (B, T_img, F, C, H, W) with F=1
-                only F = 1 is supported (single-frame videos)
-                if T_img > the number of media tokens in the corresponding input_ids (lang_x),
-                only the first number of media tokens in lang_x are used
-            lang_x: Language input ids, with media tokens denoting where
-                visual media should be inserted.
-                shape (B, T_txt)
-            attention_mask: Attention mask. Defaults to None.
-            labels: Labels. Defaults to None.
-                shape (B, T_txt)
-            past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
-                list of length = number of decoder layers in the LM
-                exact implementation depends on LM, see Hugging Face docs
-            past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
-                shape (B, T_txt)
-            past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
-            use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
-                If True, includes key_values, media_locations, and vision_tokens in the output.
-        """
-        assert not (past_vision_tokens is None) ^ (
-            past_media_locations is None
-        ), "past_vision_tokens and past_media_locations must both be None or both be not None"
-
-        # convert pixels to vision tokens
-        vision_attention_mask = None
-        if vision_x is not None:
-            if self.image_aspect_ratio == 'anyres':
-                input_dict = dict(image=vision_x, image_size=image_size)
-                vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
-            else:
-                vision_features = self._encode_vision_x(vision_x=vision_x)
-                vision_attn_masks = None
-            # Same for attention masks: [b, Np, v] -> [b*Np, v]
-            if self.anyres_patch_sampling:
-                split_sizes = [feature.shape[0] for feature in vision_features]
-                # Nested splits for multi-image samples.
-                if isinstance(vision_x[0], list):
-                    nt_images = [len(images) for images in vision_x]
-                    split_split_sizes = []
-                    img_id = 0
-                    for nt in nt_images:
-                        split_split_sizes.append(split_sizes[img_id:img_id+nt])
-                        img_id += nt
-                else:
-                    nt_images = [1] * len(vision_x)
-                    split_split_sizes = split_sizes
-                vision_features = torch.cat(vision_features, dim=0)
-                vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
-                vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
-            # TODO: add an option that allows restoring the T dimension for video tokenization.
-            vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
-
-            # Post-processing: Split the batches into groups of patches and concatenate them together.
-            if self.anyres_patch_sampling:
-                # assert isinstance(vision_x, list)
-                if isinstance(vision_x[0], list):
-                    vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
-                    vision_tokens = []
-
-                    for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
-                        patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
-                        flatten_vision_tokens = []
-                        for image_vis_token in patch_vis_token_groups:
-                            image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
-                            flatten_vision_tokens.append(image_vis_token)
-                        vision_tokens_i = flatten_vision_tokens
-                        vision_tokens.append(vision_tokens_i)
-                else:
-                    vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
-                    vision_tokens = []
-                    for patch_vis_tokens in vision_token_groups:
-                        patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
-                        vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
-        else:
-            vision_tokens = None
-
-        # fuse the vision and language tokens
-        new_inputs = self._prepare_inputs_for_forward(
-            vision_tokens=vision_tokens,
-            lang_x=lang_x,
-            attention_mask=attention_mask,
-            vision_attention_mask=vision_attention_mask,
-            labels=labels,
-            past_key_values=past_key_values,
-            past_media_locations=past_media_locations,
-            padding_side="right",
-            past_vision_tokens=past_vision_tokens,
-        )
-        output = self.lang_model(
-            **new_inputs,
-            use_cache=use_cache,
-            past_key_values=past_key_values,
-            **kwargs,
-        )
-
-        # postforward hooks
-        self._post_forward_hook()
-        return output
-
     def generate(
         self,
         vision_x: torch.Tensor,
@@ -1429,7 +1304,7 @@ class XGenMMPerceiver(VLMWithLanguageStream):
         else:
             vision_features = self._encode_vision_x(vision_x=vision_x)
             vision_attn_masks = None
-        # TODO: If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
+        # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
         # Same for attention masks: [b, Np, v] -> [b*Np, v]
         if self.anyres_patch_sampling:
             split_sizes = [feature.shape[0] for feature in vision_features]
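In vlm.py, the commit removes the single-image FIXME workarounds and the XGenMMPerceiver.forward() override, leaving the anyres patch-sampling post-processing in generate() (whose comment loses its TODO prefix) as the place where per-image patch tokens are regrouped. A rough standalone illustration of that split-and-flatten step, with dummy shapes standing in for the vision tokenizer output:

# Dummy-shape sketch of the per-image regrouping in generate(); not repository code.
import torch

split_sizes = [5, 3]          # anyres patch counts for two images in one sample
v, d = 128, 3072              # tokens per patch and hidden size, chosen for illustration
vision_tokens = torch.randn(sum(split_sizes), 1, v, d)   # [Np_total, 1, v, d]

# Split the flat patch batch back into per-image groups, then flatten each
# group's patches into a single token sequence: [Np_i, 1, v, d] -> [Np_i * v, d].
per_image_tokens = [group.flatten(0, 2) for group in torch.split(vision_tokens, split_sizes, dim=0)]

print([t.shape for t in per_image_tokens])  # [torch.Size([640, 3072]), torch.Size([384, 3072])]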