chendl committed on
Commit
e82e643
1 Parent(s): 098738d

update chat

app.py CHANGED
@@ -248,12 +248,17 @@ def gradio_ask(user_message, chatbot, chat_state):
 
 
 def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-    llm_message = \
-        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
-                    max_length=2000)[0]
+    llm_message, image = \
+        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
+                    max_length=2000)
 
     chatbot[-1][1] = llm_message
-    return chatbot, chat_state, img_list
+    if image is None:
+        return chatbot, chat_state, img_list
+    else:
+        path = build_image(image)
+        chatbot = chatbot + [[(path,), None]]
+        return chatbot, chat_state, img_list
 
 
 with gr.Blocks() as demo:
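
Note: build_image is called by the new gradio_answer but is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming all it needs to do is persist the returned PIL image and hand Gradio a file path it can render; the name, signature, and temp-file strategy are illustrative, not the repository's actual implementation.

import os
import tempfile


def build_image(image):
    # Hypothetical helper: save the PIL image to a unique temporary PNG and
    # return its path. Gradio's Chatbot component renders a [(path,), None]
    # message as an inline image, so a file path is all gradio_answer needs.
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    image.save(path)
    return path
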
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -260,6 +260,12 @@ def preprocess_conv(data):
         conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
     return conversation
 
+def preprocess_image(sample, image_processor):
+    image = image_processor(sample)
+    if isinstance(image, transformers.image_processing_utils.BatchFeature):
+        image = torch.tensor(image["pixel_values"][0])
+    return image
+
 class Chat:
     def __init__(self, model, vis_processor, tokenizer, vis_embed_size):
         self.model = model
@@ -322,6 +328,7 @@ class Chat:
         # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
         # image_path = input("Please enter the image path: ")
         image = img_list[0].convert("RGB")
+        image_ori = image
         image = image.resize((size, size))
         print(f"image size: {image.size}")
         batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
@@ -370,14 +377,31 @@ class Chat:
             image_start_index_list=image_start_index_list,
             image_nums=image_nums,
         )
-        output_token = outputs[0, input_ids.shape[1]:]
-        output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
-        conv[-1]["value"] = output_text
-        # conv.messages[-1][1] = output_text
-        print(
-            f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
+        boxes = outputs["boxes"]
+        scores = outputs["scores"]
+        if len(scores) > 0:
+            box = boxes[scores.argmax()] / 224
+            print(f"{box}")
+        out_image = None
+
+        if len(boxes) > 0:
+            open_cv_image = np.array(image_ori)
+            # Convert RGB to BGR
+            open_cv_image = open_cv_image[:, :, ::-1].copy()
+            box = box * [width, height, width, height]
+            # for box in boxes:
+            open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
+
 
-        return output_text, output_token.cpu().numpy()
+        # output_token = outputs[0, input_ids.shape[1]:]
+        # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
+        # conv[-1]["value"] = output_text
+        # # conv.messages[-1][1] = output_text
+        # print(
+        #     f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
+        output_text = "here"
+        return output_text, out_image
 
     def upload_img(self, image, conv, img_list):
         img_list.append(image)
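
Taken together, Chat.answer now returns a (text, image-or-None) pair: the detection outputs ("boxes", "scores") replace token decoding, the highest-scoring box is rescaled from the 224x224 model input back to the original image, and drawn with OpenCV. Below is a self-contained sketch of that drawing step, assuming boxes arrive in 224x224 pixel coordinates (hence the division by 224) and that width and height come from the original image, as the hunk above implies; function and argument names are illustrative.

import cv2
import numpy as np
from PIL import Image


def draw_best_box(image_ori, boxes, scores):
    # Sketch of the new drawing path in Chat.answer: return the original image
    # with the highest-scoring box drawn, or None when no box was produced.
    if len(scores) == 0 or len(boxes) == 0:
        return None
    width, height = image_ori.size
    # Rescale the box from 224x224 model space to the original resolution.
    box = boxes[scores.argmax()] / 224 * np.array([width, height, width, height])
    # PIL arrays are RGB; OpenCV draws on BGR, so convert, draw, convert back.
    open_cv_image = np.array(image_ori)[:, :, ::-1].copy()
    pt1 = tuple(int(v) for v in box[:2])
    pt2 = tuple(int(v) for v in box[2:])
    open_cv_image = cv2.rectangle(open_cv_image, pt1, pt2, (255, 0, 0), 2)
    return Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))

With that pair returned, the updated gradio_answer in app.py can append the boxed image to the chat history via build_image whenever the image is not None.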