update chat
multimodal/open_flamingo/chat/conversation.py  CHANGED
@@ -366,11 +366,25 @@ class Chat:
         image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
         image_start_index_list = [[x] for x in image_start_index_list]
         image_nums = [1] * len(input_ids)
+        added_bbox_list = []
+        with torch.inference_mode():
+            text_outputs = self.model.generate(
+                batch_images,
+                input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=20,
+                # min_new_tokens=8,
+                num_beams=1,
+                # length_penalty=0,
+                image_start_index_list=image_start_index_list,
+                image_nums=image_nums,
+                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+            )
         # and torch.cuda.amp.autocast(dtype=torch.float16)
         with torch.no_grad():
-            outputs = model(
-                vision_x=
-                lang_x=
+            outputs = self.model(
+                vision_x=batch_images,
+                lang_x=input_ids,
                 attention_mask=attention_mask,
                 image_nums=image_nums,
                 image_start_index_list=image_start_index_list,
@@ -411,7 +425,7 @@ class Chat:
         # # conv.messages[-1][1] = output_text
         # print(
         #     f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
-        output_text =
+        output_text = self.tokenizer.decode(text_outputs[0])
         return output_text, out_image
 
     def upload_img(self, image, conv, img_list):
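For reference, a minimal sketch (not part of the commit) of how the generation path introduced above fits together, assuming an OpenFlamingo-style model whose generate() accepts the keyword arguments shown in the diff; the helper name generate_answer and its signature are hypothetical.

    import torch

    def generate_answer(model, tokenizer, batch_images, input_ids, attention_mask,
                        media_token_id, added_bbox_list=None, max_new_tokens=20):
        """Hypothetical helper mirroring the flow added in this commit."""
        # One <image> token per sample; record the position right after it.
        image_start_index_list = (
            (input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1
        ).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)

        # Greedy decoding (num_beams=1), matching the settings in the diff.
        with torch.inference_mode():
            text_outputs = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list if added_bbox_list else None,
            )

        # Decode the generated ids into the assistant's reply.
        return tokenizer.decode(text_outputs[0])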