jiajunlong committed
Commit ced6b81
1 Parent(s): 1a607ef

Update generate_model.py

Files changed (1)
  1. generate_model.py +534 -5
generate_model.py CHANGED
@@ -10,10 +10,538 @@ from PIL import Image
 import torch
 from transformers import AutoTokenizer
 
-from modeling_tinyllava_elm import TinyLlavaForConditionalGeneration
-from configuration import *
-from conversion import *
-from utils import *
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from PIL import Image
+from io import BytesIO
+import base64
+
+import torch
+from transformers import StoppingCriteria
+
+import math
+import ast
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+IMAGE_PLACEHOLDER = "<image-placeholder>"
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    TINY_LLAMA = auto()
+    QWEN_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+            if 'mmtag' in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image>\n" + init_msg)
+
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i == 0: message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+        elif self.sep_style == SeparatorStyle.TINY_LLAMA:
+            sep = "</s>"
+            wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
+            wrap_user = lambda msg: f"<|user|>\n{msg}\n"
+            wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i % 2 == 0:
+                        message = wrap_user(message)
+                        if i == 0:
+                            message = wrap_sys(self.system) + message
+                        ret += self.sep + message
+                    else:
+                        message = wrap_assistant(message) + self.sep2
+                        ret += message
+                else:
+                    ret += "<|assistant|>\n"
+            ret = ret.lstrip(self.sep)
+        elif self.sep_style == SeparatorStyle.QWEN_2:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    from PIL import Image
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(pil_img.mode, (width, width), background_color)
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(pil_img.mode, (height, height), background_color)
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+                        image = expand2square(image)
+                    elif image_process_mode in ["Default", "Crop"]:
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((336, 336))
+                    else:
+                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if longest_edge != max(image.size):
+                        if H > W:
+                            H, W = longest_edge, shortest_edge
+                        else:
+                            H, W = shortest_edge, longest_edge
+                        image = image.resize((W, H))
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="PNG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace('<image>', '').strip()
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+
+
+conv_phi_v0 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="phi",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|endoftext|>",
+)
+
+
+
+def select_best_resolution(original_size, possible_resolutions):
+    """
+    Selects the best resolution from a list of possible resolutions based on the original size.
+
+    Args:
+        original_size (tuple): The original size of the image in the format (width, height).
+        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+    Returns:
+        tuple: The best fit resolution in the format (width, height).
+    """
+    original_width, original_height = original_size
+    best_fit = None
+    max_effective_resolution = 0
+    min_wasted_resolution = float('inf')
+
+    for width, height in possible_resolutions:
+        scale = min(width / original_width, height / original_height)
+        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+        wasted_resolution = (width * height) - effective_resolution
+
+        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+            max_effective_resolution = effective_resolution
+            min_wasted_resolution = wasted_resolution
+            best_fit = (width, height)
+
+    return best_fit
+
+
+## added by llava-1.6
+def resize_and_pad_image(image, target_resolution):
+    """
+    Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        target_resolution (tuple): The target resolution (width, height) of the image.
+
+    Returns:
+        PIL.Image.Image: The resized and padded image.
+    """
+    original_width, original_height = image.size
+    target_width, target_height = target_resolution
+
+    scale_w = target_width / original_width
+    scale_h = target_height / original_height
+
+    if scale_w < scale_h:
+        new_width = target_width
+        new_height = min(math.ceil(original_height * scale_w), target_height)
+    else:
+        new_height = target_height
+        new_width = min(math.ceil(original_width * scale_h), target_width)
+
+    # Resize the image
+    resized_image = image.resize((new_width, new_height))
+
+    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+    paste_x = (target_width - new_width) // 2
+    paste_y = (target_height - new_height) // 2
+    new_image.paste(resized_image, (paste_x, paste_y))
+
+    return new_image
+
+
+## added by llava-1.6
+def divide_to_patches(image, patch_size):
+    """
+    Divides an image into patches of a specified size.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        patch_size (int): The size of each patch.
+
+    Returns:
+        list: A list of PIL.Image.Image objects representing the patches.
+    """
+    patches = []
+    width, height = image.size
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            box = (j, i, j + patch_size, i + patch_size)
+            patch = image.crop(box)
+            patches.append(patch)
+
+    return patches
+
+
+## added by llava-1.6
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    """
+    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (tuple): The size of the input image in the format (width, height).
+        grid_pinpoints (str): A string representation of a list of possible resolutions.
+        patch_size (int): The size of each image patch.
+
+    Returns:
+        tuple: The shape of the image patch grid in the format (width, height).
+    """
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)
+    width, height = select_best_resolution(image_size, possible_resolutions)
+    return width // patch_size, height // patch_size
+
+
+## added by llava-1.6
+def process_anyres_image(image, processor, grid_pinpoints):
+    """
+    Process an image with variable resolutions.
+
+    Args:
+        image (PIL.Image.Image): The input image to be processed.
+        processor: The image processor object.
+        grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+    Returns:
+        torch.Tensor: A tensor containing the processed image patches.
+    """
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)
+    best_resolution = select_best_resolution(image.size, possible_resolutions)
+    image_padded = resize_and_pad_image(image, best_resolution)
+
+    patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+    image_patches = [image_original_resize] + patches
+    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+                     for image_patch in image_patches]
+    return torch.stack(image_patches, dim=0)
+
+
+def load_image_from_base64(image):
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def process_images(images, image_processor, model_cfg):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+    new_images = []
+    if image_aspect_ratio == 'pad':
+        for image in images:
+            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            new_images.append(image)
+    elif image_aspect_ratio == "anyres":
+        for image in images:
+            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors='pt')['pixel_values']
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f'Unsupported tensor type: {return_tensors}')
+    return input_ids
+
+
+def get_model_name_from_path(model_path):
+    model_path = model_path.strip("/")
+    model_paths = model_path.split("/")
+    if model_paths[-1].startswith('checkpoint-'):
+        return model_paths[-2] + "_" + model_paths[-1]
+    else:
+        return model_paths[-1]
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+    def __init__(self, keywords, tokenizer, input_ids):
+        self.keywords = keywords
+        self.keyword_ids = []
+        self.max_keyword_len = 0
+        for keyword in keywords:
+            cur_keyword_ids = tokenizer(keyword).input_ids
+            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+                cur_keyword_ids = cur_keyword_ids[1:]
+            if len(cur_keyword_ids) > self.max_keyword_len:
+                self.max_keyword_len = len(cur_keyword_ids)
+            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+        self.tokenizer = tokenizer
+        self.start_len = input_ids.shape[1]
+
+    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+        for keyword_id in self.keyword_ids:
+            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
+                return True
+        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+        for keyword in self.keywords:
+            if keyword in outputs:
+                return True
+        return False
+
+    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        outputs = []
+        for i in range(output_ids.shape[0]):
+            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
+        return all(outputs)
 
 
 
@@ -59,9 +587,10 @@ def generate(
     if isinstance(model, str):
         checkpoint_path = model
         # print(f'loading model from {checkpoint_path}...')
-        model = TinyLlavaForConditionalGeneration.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             checkpoint_path,
             torch_dtype=torch.float16,
+            trust_remote_code=True
         )
         # print('model load over')
        config = model.config
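
For context, a minimal sketch of how the helpers added in this commit could be wired together at inference time. The checkpoint path, the use of AutoImageProcessor, the device choice, and the assumption that the remote-code model's generate() accepts an images keyword (as in LLaVA-style models) are illustrative assumptions, not part of the commit itself.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoImageProcessor

checkpoint = "path/to/tinyllava-checkpoint"  # hypothetical local or Hub path
device = "cuda"                              # float16 weights, so a GPU is assumed

model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, trust_remote_code=True
).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
# Assumes the repo ships an image-processor config alongside the weights.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

# Build a phi-style prompt with the conversation template defined above.
conv = conv_phi_v0.copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe this image.")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Splice the IMAGE_TOKEN_INDEX placeholder into the token ids.
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors='pt').unsqueeze(0).to(device)

# Preprocess the image according to the model config (pad / anyres / default).
image = Image.open("example.jpg").convert('RGB')
image_tensor = process_images([image], image_processor, model.config).to(
    device, dtype=torch.float16)

stopping = KeywordsStoppingCriteria([conv.sep2], tokenizer, input_ids)
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,          # assumes a LLaVA-style generate() signature
        max_new_tokens=256,
        stopping_criteria=[stopping],
    )
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])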