jiajunlong committed on
Commit
1a607ef
1 Parent(s): f90ff9d

Delete utils.py

Files changed (1)
  utils.py +0 -259
utils.py DELETED
@@ -1,259 +0,0 @@
from PIL import Image
from io import BytesIO
import base64

import torch
from transformers import StoppingCriteria

import math
import ast

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


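# Illustrative usage sketch (not part of the original utils.py). The candidate
# grid below mirrors the common 336-pixel LLaVA-1.6 pinpoints and is an
# assumption, not read from this repo's config.
def _demo_select_best_resolution():
    candidates = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
    # A 1000x800 image keeps the most effective pixels, with the least padding
    # waste, when mapped onto the 672x672 candidate.
    assert select_best_resolution((1000, 800), candidates) == (672, 672)

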
## added by llava-1.6
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


## added by llava-1.6
def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


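# Illustrative usage sketch (not part of the original utils.py): pad a synthetic
# 1000x800 image to 672x672, then cut it into a 2x2 grid of 336x336 patches.
def _demo_resize_pad_and_divide():
    image = Image.new('RGB', (1000, 800), (255, 0, 0))
    padded = resize_and_pad_image(image, (672, 672))
    patches = divide_to_patches(padded, 336)
    assert padded.size == (672, 672)
    assert len(patches) == 4 and patches[0].size == (336, 336)

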
## added by llava-1.6
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after preprocessing an image of arbitrary resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str or list): A list of possible resolutions, or its string representation.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


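# Illustrative usage sketch (not part of the original utils.py): with 336-pixel
# patches and the candidate grid from the earlier sketch, a 1000x800 image is
# tiled as a 2x2 patch grid.
def _demo_get_anyres_image_grid_shape():
    grid = "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]"
    assert get_anyres_image_grid_shape((1000, 800), grid, 336) == (2, 2)

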
## added by llava-1.6
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str or list): A list of possible resolutions, or its string representation.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    # Prepend the globally resized view to the local crops, then run every view
    # through the processor.
    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


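# Illustrative usage sketch (not part of the original utils.py), assuming a
# CLIP-style processor whose crop_size['height'] and size['shortest_edge'] are
# both 336; the checkpoint name below is an assumption, not read from this repo.
def _demo_process_anyres_image():
    from transformers import CLIPImageProcessor
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
    grid = "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]"
    patches = process_anyres_image(Image.new('RGB', (1000, 800)), processor, grid)
    print(patches.shape)  # expected: torch.Size([5, 3, 336, 336]) - global view + 2x2 crops

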
def load_image_from_base64(image):
    # Decode a base64-encoded image string into a PIL image.
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    # Pad the shorter side with background_color so the image becomes square,
    # keeping the original content centered.
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


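# Illustrative usage sketch (not part of the original utils.py): a landscape
# 1000x800 image is letterboxed into a 1000x1000 square; process_images() below
# applies the same padding per image when image_aspect_ratio == 'pad'.
def _demo_expand2square():
    squared = expand2square(Image.new('RGB', (1000, 800)), (127, 127, 127))
    assert squared.size == (1000, 1000)

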
def process_images(images, image_processor, model_cfg):
    # Preprocess a list of PIL images according to model_cfg.image_aspect_ratio:
    # 'pad' squares each image with the processor's mean color, 'anyres' builds
    # multi-patch inputs, and anything else falls back to the processor default.
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


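# Illustrative usage sketch (not part of the original utils.py): a hypothetical
# SimpleNamespace stands in for the model config, and the CLIP-style processor
# checkpoint is again an assumption rather than this repo's actual setup.
def _demo_process_images():
    from types import SimpleNamespace
    from transformers import CLIPImageProcessor
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
    cfg = SimpleNamespace(image_aspect_ratio='pad')
    batch = process_images([Image.new('RGB', (1000, 800))], processor, cfg)
    print(batch.shape)  # expected: torch.Size([1, 3, 336, 336])

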
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    # Tokenize a prompt containing '<image>' placeholders, splicing image_token_index
    # into the token stream wherever a placeholder appeared.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


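# Illustrative usage sketch (not part of the original utils.py): a minimal,
# hypothetical stand-in tokenizer shows how the '<image>' placeholder becomes a
# single IMAGE_TOKEN_INDEX (-200) spliced between the text token chunks.
def _demo_tokenizer_image_token():
    from types import SimpleNamespace

    class _StubTokenizer:
        bos_token_id = 1

        def __call__(self, text):
            # Prepend a BOS id and map each whitespace-separated word to a dummy id.
            return SimpleNamespace(input_ids=[self.bos_token_id] + [100 + i for i, _ in enumerate(text.split())])

    ids = tokenizer_image_token("USER: <image> describe the scene", _StubTokenizer())
    assert ids[0] == 1 and ids.count(IMAGE_TOKEN_INDEX) == 1

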
def get_model_name_from_path(model_path):
    # Derive a short model name from a path; checkpoint subdirectories keep the
    # parent folder name as a prefix.
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


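# Illustrative usage sketch (not part of the original utils.py); the paths are
# hypothetical examples.
def _demo_get_model_name_from_path():
    assert get_model_name_from_path("/models/llava-v1.5-7b/") == "llava-v1.5-7b"
    assert get_model_name_from_path("/models/llava-v1.5-7b/checkpoint-2000") == "llava-v1.5-7b_checkpoint-2000"

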
class KeywordsStoppingCriteria(StoppingCriteria):
    # Stop generation once any of the given keywords appears at the end of the
    # newly generated tokens (checked per batch element).
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        # First check for an exact token-id match at the end of the sequence,
        # then fall back to a substring match on the decoded tail.
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
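

# Illustrative usage sketch (not part of the original utils.py): a hypothetical
# stand-in tokenizer shows the criterion firing once the keyword ids end the
# sequence. In real use the instance is passed to
# model.generate(..., stopping_criteria=[criteria]).
def _demo_keywords_stopping_criteria():
    from types import SimpleNamespace

    class _StubTokenizer:
        bos_token_id = 1

        def __call__(self, text):
            # Pretend "###" tokenizes to BOS followed by ids 42 and 43.
            return SimpleNamespace(input_ids=[1, 42, 43])

        def batch_decode(self, ids, skip_special_tokens=True):
            return [""]

    prompt_ids = torch.tensor([[1, 7, 7]])
    criteria = KeywordsStoppingCriteria(["###"], _StubTokenizer(), prompt_ids)
    generated = torch.tensor([[1, 7, 7, 5, 42, 43]])
    assert criteria(generated, scores=None)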