chendl committed on
Commit
7f11231
1 Parent(s): 1fb7e67

update chat

Browse files
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -366,11 +366,25 @@ class Chat:
366
  image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
367
  image_start_index_list = [[x] for x in image_start_index_list]
368
  image_nums = [1] * len(input_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  # and torch.cuda.amp.autocast(dtype=torch.float16)
370
  with torch.no_grad():
371
- outputs = model(
372
- vision_x=vision_x,
373
- lang_x=lang_x,
374
  attention_mask=attention_mask,
375
  image_nums=image_nums,
376
  image_start_index_list=image_start_index_list,
@@ -411,7 +425,7 @@ class Chat:
411
  # # conv.messages[-1][1] = output_text
412
  # print(
413
  # f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
414
- output_text = "here"
415
  return output_text, out_image
416
 
417
  def upload_img(self, image, conv, img_list):
 
366
  image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
367
  image_start_index_list = [[x] for x in image_start_index_list]
368
  image_nums = [1] * len(input_ids)
369
+ added_bbox_list = []
370
+ with torch.inference_mode():
371
+ text_outputs = self.model.generate(
372
+ batch_images,
373
+ input_ids,
374
+ attention_mask=attention_mask,
375
+ max_new_tokens=20,
376
+ # min_new_tokens=8,
377
+ num_beams=1,
378
+ # length_penalty=0,
379
+ image_start_index_list=image_start_index_list,
380
+ image_nums=image_nums,
381
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
382
+ )
383
  # and torch.cuda.amp.autocast(dtype=torch.float16)
384
  with torch.no_grad():
385
+ outputs = self.model(
386
+ vision_x=batch_images,
387
+ lang_x=input_ids,
388
  attention_mask=attention_mask,
389
  image_nums=image_nums,
390
  image_start_index_list=image_start_index_list,
 
425
  # # conv.messages[-1][1] = output_text
426
  # print(
427
  # f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
428
+ output_text = self.tokenizer.decode(text_outputs[0])
429
  return output_text, out_image
430
 
431
  def upload_img(self, image, conv, img_list):