g-h-chen committed on
Commit
e48a6a5
1 Parent(s): 41f653e

upload generation_utils.py

Files changed (1)
  1. generation_utils.py +270 -0
generation_utils.py ADDED
@@ -0,0 +1,270 @@
+ from typing import List
+ from queue import Queue
+
+ import torch
+ from PIL import Image
+ from copy import deepcopy
+ import requests, os
+
+ IMAGE_TOKEN_INDEX = -200
+ blacklist = ['<image>', '<s>', '</s>']
+ max_num_images = 3  # phi has a context length limit of 2048 and each image occupies 576 tokens.
+
+ def input_moderation(texts: list[list[str]]):
+     # perform input moderation on each message
+     for text_pair in texts:
+         # in-place operation
+         for b in blacklist:
+             text_pair[0] = text_pair[0].replace(b, '')
+             if text_pair[1] is not None:
+                 text_pair[1] = text_pair[1].replace(b, '')
+
+     return texts
+
+ def insert_image_placeholder(t, num_images, placeholder='<image>', sep='\n'):
+     for _ in range(num_images):
+         t = f"{placeholder}{sep}" + t
+     return t
+
+ def get_conv(texts):
+     ret = []
+
+     for conv in texts:
+         ret.append({'from': 'human', 'value': conv[0]})
+         ret.append({'from': 'gpt', 'value': conv[1]})  # this is None for the last one
+
+     return ret
+
+ # copied from llava
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+     prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for chunk in prompt.split('<image>')]
+
+     def insert_separator(X, sep):
+         return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+     input_ids = []
+     offset = 0
+     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+         offset = 1
+         input_ids.append(prompt_chunks[0][0])
+
+     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+         input_ids.extend(x[offset:])
+
+     if return_tensors is not None:
+         if return_tensors == 'pt':
+             return torch.tensor(input_ids, dtype=torch.long)
+         raise ValueError(f'Unsupported tensor type: {return_tensors}')
+     return input_ids
+
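+ # Worked example (illustrative): for the prompt '<image>\nhi', prompt.split('<image>')
+ # yields ['', '\nhi'], so the returned ids are
+ # [IMAGE_TOKEN_INDEX] + tokenizer('\nhi', add_special_tokens=False).input_ids,
+ # i.e. each '<image>' placeholder collapses to a single -200 id for the model to expand.
+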
+ def preprocess(tokenizer, data: list, return_tensors='pt'):
+     '''
+     [
+         {
+             'from': 'human',
+             'value': xxx,
+         },
+         {
+             'from': 'gpt',
+             'value': xxx
+         }
+     ]
+     '''
+     # needs update
+     if not isinstance(data, list):
+         raise ValueError('data must be a list of conversation turns')
+
+     # this is per model (tokenizer)
+     return preprocess_allava(tokenizer, data, return_tensors=return_tensors)
+
+
+ def preprocess_allava(tokenizer, convs: list, return_tensors) -> list:  # tokenize and concat the conversations
+     input_ids = torch.tensor([1]).long()  # start from the hard-coded BOS token id
+
+     for ind, conv in enumerate(convs):
+
+         if ind % 2 == 0:  # human
+             h = conv['value'].strip()
+             h = f"<|user|>\n{h}<|end|>\n"
+             cur_input_ids = tokenizer_image_token(prompt=h, tokenizer=tokenizer, return_tensors=return_tensors)
+             input_ids = torch.cat([input_ids, cur_input_ids])
+
+         else:  # gpt
+             g = conv['value']
+             if g is not None:
+                 g = f"<|assistant|>\n{g}<|end|>\n"
+                 cur_input_ids = tokenizer(g, add_special_tokens=False, truncation=True, return_tensors='pt').input_ids[0]
+                 input_ids = torch.cat([input_ids, cur_input_ids])
+             else:
+                 # last turn: append the assistant prefix so generation starts from the assistant role
+                 g = '<|assistant|>\n'
+                 cur_input_ids = tokenizer(g, add_special_tokens=False, return_tensors='pt').input_ids[0]
+                 input_ids = torch.cat([input_ids, cur_input_ids])
+
+     return input_ids
+
+
+ # copied from llava
+ def get_image_tensors(processor, images, device):
+     list_image_tensors = []
+     crop_size = processor.crop_size
+     for fp in images:
+         if fp is None:  # None is used as a placeholder
+             list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(device))
+             continue
+         elif isinstance(fp, str):
+             image = Image.open(fp).convert('RGB')
+         elif isinstance(fp, Image.Image):
+             image = fp  # already an image
+         else:
+             raise TypeError(f'Unsupported type {type(fp)}')
+
+         # pad to a square (image_aspect_ratio == 'pad'); this is the way of preprocessing
+         # images we used in training, so we impose it here
+         def expand2square(pil_img, background_color):
+             width, height = pil_img.size
+             if pil_img.mode == 'L':
+                 pil_img = pil_img.convert('RGB')
+
+             if width == height:
+                 return pil_img
+             elif width > height:
+                 result = Image.new(pil_img.mode, (width, width), background_color)
+                 result.paste(pil_img, (0, (width - height) // 2))
+                 return result
+             else:
+                 result = Image.new(pil_img.mode, (height, height), background_color)
+                 result.paste(pil_img, ((height - width) // 2, 0))
+                 return result
+
+         image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+         image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+         list_image_tensors.append(image.to(device))
+     return list_image_tensors
+
+
+ def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
+     '''
+     texts: a str, or a list of [human, gpt] text pairs (gpt is None for the current turn)
+     '''
+
+     ############################
+     # 1. preprocess texts
+     ############################
+     if isinstance(texts, str):
+         texts = [[texts, None]]
+     else:
+         assert isinstance(texts, list) and isinstance(texts[0], list), 'texts must be a list of lists'
+
+     if history is not None:
+         texts = history + texts  # concat them together
+
+     texts = input_moderation(texts)
+
+     ############################
+     # 2. preprocess images
+     ############################
+     if isinstance(images, str) or isinstance(images, Image.Image):
+         images = [images]
+
+     valid_images = []
+     if images is None:
+         images = [None]
+
+     for img in images:
+         try:
+             if isinstance(img, Image.Image):  # already a loaded image
+                 pass
+             elif os.path.exists(img):  # make sure that the path exists
+                 img = Image.open(img).convert('RGB')
+             else:  # else it must be a URL
+                 img = Image.open(requests.get(img, stream=True).raw)
+
+             valid_images.append(img)
+         except Exception:
+             continue
+
+     images = valid_images
+
+     if images == []:
+         images = [None]
+
+     assert len(images) <= max_num_images, f'Currently at most {max_num_images} images are supported'
+
+     ############################
+     # 3. collate conv
+     ############################
+
+     history = deepcopy(texts)  # history is the texts without <image> placeholders
+
+     # insert <image>
+     image_place_holder_inserted = insert_image_placeholder(texts[0][0], len(images) if None not in images else 0)  # only insert the placeholders for user input at the 1st round
+     texts[0][0] = image_place_holder_inserted
+
+     # collate strings into conv
+     conv = get_conv(texts)
+
+     # make input ids
+     input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)
+
+     list_image_tensors = get_image_tensors(processor, images, device)
+     image_tensors = torch.stack(list_image_tensors)
+
+     try:
+         dtype = torch.bfloat16
+         # if your hardware does not support bf16, the following line raises an error
+         torch.tensor(1, dtype=dtype).cuda()
+     except Exception:
+         # default to fp16
+         dtype = torch.float16
+
+     if return_history:
+         return input_ids, image_tensors, history
+
+     return input_ids, image_tensors, None
+
+
+ class TextIterStreamer:
+     def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+         self.tokenizer = tokenizer
+         self.skip_prompt = skip_prompt
+         self.skip_special_tokens = skip_special_tokens
+         self.tokens = []
+         self.text_queue = Queue()
+         self.next_tokens_are_prompt = True
+
+     def put(self, value):
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+         else:
+             if len(value.shape) > 1:
+                 value = value[0]
+             self.tokens.extend(value.tolist())
+             # re-decode the whole sequence so partial words are rendered consistently
+             self.text_queue.put(
+                 self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+     def end(self):
+         self.text_queue.put(None)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.text_queue.get()
+         if value is None:
+             raise StopIteration()
+         else:
+             return value
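
A minimal usage sketch of the helpers above, assuming an ALLaVA-style multimodal model, tokenizer and image processor have already been loaded with transformers; the input values, the `images=` keyword of `model.generate`, and the generation settings are illustrative assumptions, not part of this file:

    from threading import Thread

    # `model`, `tokenizer` and `processor` are assumed to be an already-loaded
    # ALLaVA-style causal LM, its tokenizer and its CLIP image processor.
    input_ids, image_tensors, history = build_allava_input(
        tokenizer, processor,
        texts='What is shown in this picture?',
        images='example.jpg',  # illustrative local path or URL
        return_history=True,
        device='cuda',
    )

    streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # run generation in a background thread and consume the streamer as text arrives;
    # whether generate() accepts an `images=` kwarg depends on the model implementation
    thread = Thread(target=model.generate,
                    kwargs=dict(inputs=input_ids,
                                images=image_tensors,
                                streamer=streamer,
                                max_new_tokens=512))
    thread.start()
    for text in streamer:
        print(text)  # each item is the full decoded response so far
    thread.join()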