liuyizhang commited on
Commit
bd50af0
1 Parent(s): 0cc37e5

add kosmos-2

Browse files
Files changed (3) hide show
  1. app.py +104 -26
  2. kosmos_utils.py +233 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -3,7 +3,8 @@ import warnings
3
  warnings.filterwarnings('ignore')
4
 
5
  import subprocess, io, os, sys, time
6
- os.system("pip install gradio==3.36.1")
 
7
  import gradio as gr
8
 
9
  from loguru import logger
@@ -50,7 +51,7 @@ from io import BytesIO
50
  from diffusers import StableDiffusionInpaintPipeline
51
  from huggingface_hub import hf_hub_download
52
 
53
- from utils import computer_info
54
  # relate anything
55
  from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
56
  from ram_train_eval import RamModel,RamPredictor
@@ -61,6 +62,10 @@ from lama_cleaner.helper import (
61
  resize_max_size,
62
  )
63
 
 
 
 
 
64
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
65
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
66
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
@@ -81,6 +86,8 @@ sd_model = None
81
  lama_cleaner_model= None
82
  lama_cleaner_model_device = device
83
  ram_model = None
 
 
84
 
85
  def get_sam_vit_h_4b8939():
86
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
@@ -254,6 +261,7 @@ def set_device():
254
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
255
  else:
256
  device = 'cpu'
 
257
 
258
  def load_groundingdino_model():
259
  # initialize groundingdino model
@@ -366,6 +374,8 @@ class Ram_Predictor(RamPredictor):
366
  def load_ram_model():
367
  # load ram model
368
  global ram_model
 
 
369
  model_path = "./checkpoints/ram_epoch12.pth"
370
  ram_config = dict(
371
  model=dict(
@@ -510,19 +520,23 @@ mask_source_draw = "draw a mask on input image"
510
  mask_source_segment = "type what to detect below"
511
 
512
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
513
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
514
-
 
 
 
 
515
  if (task_type == 'relate anything'):
516
  output_images = relate_anything(input_image['image'], num_relation)
517
- return output_images, gr.Gallery.update(label='relate images')
518
 
519
  text_prompt = text_prompt.strip()
520
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
521
  if text_prompt == '':
522
- return [], gr.Gallery.update(label='Detection prompt is not found!😂😂😂😂')
523
 
524
  if input_image is None:
525
- return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂')
526
 
527
  file_temp = int(time.time())
528
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
@@ -562,7 +576,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
562
  )
563
  if boxes_filt.size(0) == 0:
564
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
565
- return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
566
  boxes_filt_ori = copy.deepcopy(boxes_filt)
567
 
568
  pred_dict = {
@@ -613,7 +627,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
613
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
614
  if task_type == 'detection' or task_type == 'segment':
615
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
616
- return output_images, gr.Gallery.update(label='result images')
617
  elif task_type == 'inpainting' or task_type == 'remove':
618
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
619
  task_type = 'remove'
@@ -678,27 +692,48 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
678
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
679
  output_images.append(image_inpainting)
680
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
681
- return output_images, gr.Gallery.update(label='result images')
682
  else:
683
  logger.info(f"task_type:{task_type} error!")
684
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
685
- return output_images, gr.Gallery.update(label='result images')
686
 
687
  def change_radio_display(task_type, mask_source_radio):
688
  text_prompt_visible = True
689
  inpaint_prompt_visible = False
690
  mask_source_radio_visible = False
691
  num_relation_visible = False
692
- if task_type == "inpainting":
693
- inpaint_prompt_visible = True
694
- if task_type == "inpainting" or task_type == "remove":
695
- mask_source_radio_visible = True
696
- if mask_source_radio == mask_source_draw:
697
- text_prompt_visible = False
698
- if task_type == "relate anything":
699
  text_prompt_visible = False
700
- num_relation_visible = True
701
- return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
 
703
  def get_model_device(module):
704
  try:
@@ -723,12 +758,18 @@ if __name__ == "__main__":
723
  print(f'args = {args}')
724
 
725
  set_device()
726
- get_sam_vit_h_4b8939()
727
  load_groundingdino_model()
728
- load_sam_model()
 
 
 
729
  load_sd_model()
730
  load_lama_cleaner_model()
731
  load_ram_model()
 
 
 
732
 
733
  if os.environ.get('IS_MY_DEBUG') is None:
734
  os.system("pip list")
@@ -744,7 +785,7 @@ if __name__ == "__main__":
744
  with gr.Row():
745
  with gr.Column():
746
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
747
- task_type = gr.Radio(["detection", "segment", "inpainting", "remove", "relate anything"], value="detection",
748
  label='Task type', visible=True)
749
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
750
  value=mask_source_segment, label="Mask from",
@@ -752,6 +793,9 @@ if __name__ == "__main__":
752
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
753
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
754
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
 
 
 
755
  run_button = gr.Button(label="Run", visible=True)
756
  with gr.Accordion("Advanced options", open=False) as advanced_options:
757
  box_threshold = gr.Slider(
@@ -773,16 +817,50 @@ if __name__ == "__main__":
773
  with gr.Column():
774
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
775
  ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
 
777
  run_button.click(fn=run_anything_task, inputs=[
778
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[image_gallery, image_gallery], show_progress=True, queue=True)
 
 
779
 
780
- mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
781
- task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
 
 
 
 
782
 
783
  DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
784
  DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
785
  DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner). <br>'
 
786
  DESCRIPTION += f'Thanks for their excellent work.'
787
  DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \
788
  <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
 
3
  warnings.filterwarnings('ignore')
4
 
5
  import subprocess, io, os, sys, time
6
+ # os.system("pip install gradio==3.36.1")
7
+ os.system("pip install gradio==3.41.2")
8
  import gradio as gr
9
 
10
  from loguru import logger
 
51
  from diffusers import StableDiffusionInpaintPipeline
52
  from huggingface_hub import hf_hub_download
53
 
54
+ from utils import computer_info
55
  # relate anything
56
  from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
57
  from ram_train_eval import RamModel,RamPredictor
 
62
  resize_max_size,
63
  )
64
 
65
+ # from transformers import AutoProcessor, AutoModelForVision2Seq
66
+ import ast
67
+ from kosmos_utils import *
68
+
69
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
70
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
71
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
 
86
  lama_cleaner_model= None
87
  lama_cleaner_model_device = device
88
  ram_model = None
89
+ kosmos_model = None
90
+ kosmos_processor = None
91
 
92
  def get_sam_vit_h_4b8939():
93
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
 
261
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
262
  else:
263
  device = 'cpu'
264
+ print(f'device={device}')
265
 
266
  def load_groundingdino_model():
267
  # initialize groundingdino model
 
374
  def load_ram_model():
375
  # load ram model
376
  global ram_model
377
+ if os.environ.get('IS_MY_DEBUG') is not None:
378
+ return
379
  model_path = "./checkpoints/ram_epoch12.pth"
380
  ram_config = dict(
381
  model=dict(
 
520
  mask_source_segment = "type what to detect below"
521
 
522
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
523
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
524
+ if (task_type == 'Kosmos-2'):
525
+ global kosmos_model, kosmos_processor
526
+ kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(input_image, kosmos_input, kosmos_model, kosmos_processor)
527
+ return None, None, kosmos_image, kosmos_text, kosmos_entities
528
+
529
  if (task_type == 'relate anything'):
530
  output_images = relate_anything(input_image['image'], num_relation)
531
+ return output_images, gr.Gallery.update(label='relate images'), None, None, None
532
 
533
  text_prompt = text_prompt.strip()
534
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
535
  if text_prompt == '':
536
+ return [], gr.Gallery.update(label='Detection prompt is not found!😂😂😂😂'), None, None, None
537
 
538
  if input_image is None:
539
+ return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), None, None, None
540
 
541
  file_temp = int(time.time())
542
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
 
576
  )
577
  if boxes_filt.size(0) == 0:
578
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
579
+ return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂'), None, None, None
580
  boxes_filt_ori = copy.deepcopy(boxes_filt)
581
 
582
  pred_dict = {
 
627
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
628
  if task_type == 'detection' or task_type == 'segment':
629
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
630
+ return output_images, gr.Gallery.update(label='result images'), None, None, None
631
  elif task_type == 'inpainting' or task_type == 'remove':
632
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
633
  task_type = 'remove'
 
692
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
693
  output_images.append(image_inpainting)
694
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
695
+ return output_images, gr.Gallery.update(label='result images'), None, None, None
696
  else:
697
  logger.info(f"task_type:{task_type} error!")
698
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
699
+ return output_images, gr.Gallery.update(label='result images'), None, None, None
700
 
701
  def change_radio_display(task_type, mask_source_radio):
702
  text_prompt_visible = True
703
  inpaint_prompt_visible = False
704
  mask_source_radio_visible = False
705
  num_relation_visible = False
706
+
707
+ image_gallery_visible = True
708
+ kosmos_input_visible = False
709
+ kosmos_output_visible = False
710
+ kosmos_text_output_visible = False
711
+
712
+ if task_type == "Kosmos-2":
713
  text_prompt_visible = False
714
+ image_gallery_visible = False
715
+ kosmos_input_visible = True
716
+ kosmos_output_visible = True
717
+ kosmos_text_output_visible = True
718
+ else:
719
+ if task_type == "inpainting":
720
+ inpaint_prompt_visible = True
721
+ if task_type == "inpainting" or task_type == "remove":
722
+ mask_source_radio_visible = True
723
+ if mask_source_radio == mask_source_draw:
724
+ text_prompt_visible = False
725
+ if task_type == "relate anything":
726
+ text_prompt_visible = False
727
+ num_relation_visible = True
728
+
729
+ return (gr.Textbox.update(visible=text_prompt_visible),
730
+ gr.Textbox.update(visible=inpaint_prompt_visible),
731
+ gr.Radio.update(visible=mask_source_radio_visible),
732
+ gr.Slider.update(visible=num_relation_visible),
733
+ gr.Gallery.update(visible=image_gallery_visible),
734
+ gr.Radio.update(visible=kosmos_input_visible),
735
+ gr.Image.update(visible=kosmos_output_visible),
736
+ gr.HighlightedText.update(visible=kosmos_text_output_visible))
737
 
738
  def get_model_device(module):
739
  try:
 
758
  print(f'args = {args}')
759
 
760
  set_device()
761
+
762
  load_groundingdino_model()
763
+ if os.environ.get('IS_MY_DEBUG') is None:
764
+ get_sam_vit_h_4b8939()
765
+ load_sam_model()
766
+
767
  load_sd_model()
768
  load_lama_cleaner_model()
769
  load_ram_model()
770
+
771
+ if os.environ.get('IS_MY_DEBUG') is None:
772
+ kosmos_model, kosmos_processor = load_kosmos_model(device)
773
 
774
  if os.environ.get('IS_MY_DEBUG') is None:
775
  os.system("pip list")
 
785
  with gr.Row():
786
  with gr.Column():
787
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
788
+ task_type = gr.Radio(["detection", "segment", "inpainting", "remove", "relate anything", "Kosmos-2"], value="detection",
789
  label='Task type', visible=True)
790
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
791
  value=mask_source_segment, label="Mask from",
 
793
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
794
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
795
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
796
+
797
+ kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
798
+
799
  run_button = gr.Button(label="Run", visible=True)
800
  with gr.Accordion("Advanced options", open=False) as advanced_options:
801
  box_threshold = gr.Slider(
 
817
  with gr.Column():
818
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
819
  ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
820
+ kosmos_output = gr.Image(type="pil", label="result images", visible=False)
821
+ kosmos_text_output = gr.HighlightedText(
822
+ label="Generated Description",
823
+ combine_adjacent=False,
824
+ show_legend=True,
825
+ visible=False,
826
+ ).style(color_map=color_map)
827
+ # record which text span (label) is selected
828
+ selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
829
+
830
+ # record the current `entities`
831
+ entity_output = gr.Textbox(visible=False)
832
+
833
+ # get the current selected span label
834
+ def get_text_span_label(evt: gr.SelectData):
835
+ if evt.value[-1] is None:
836
+ return -1
837
+ return int(evt.value[-1])
838
+ # and set this information to `selected`
839
+ kosmos_text_output.select(get_text_span_label, None, selected)
840
+
841
+ # update output image when we change the span (enity) selection
842
+ def update_output_image(img_input, image_output, entities, idx):
843
+ entities = ast.literal_eval(entities)
844
+ updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
845
+ return updated_image
846
+ selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
847
 
848
  run_button.click(fn=run_anything_task, inputs=[
849
+ input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
850
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
851
+ outputs=[image_gallery, image_gallery, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
852
 
853
+ mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
854
+ outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
855
+ task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
856
+ outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
857
+ image_gallery, kosmos_input, kosmos_output, kosmos_text_output
858
+ ])
859
 
860
  DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
861
  DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
862
  DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner). <br>'
863
+ DESCRIPTION += f'Kosmos-2 from [RelateAnything](https://huggingface.co/spaces/ydshieh/Kosmos-2). <br>'
864
  DESCRIPTION += f'Thanks for their excellent work.'
865
  DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \
866
  <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
kosmos_utils.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import os
4
+ import requests
5
+ import torch
6
+ import torchvision.transforms as torchvision_T
7
+ from PIL import Image
8
+ from transformers import AutoProcessor, AutoModelForVision2Seq
9
+ import cv2
10
+ import ast
11
+
12
+ colors = [
13
+ (0, 255, 0),
14
+ (0, 0, 255),
15
+ (255, 255, 0),
16
+ (255, 0, 255),
17
+ (0, 255, 255),
18
+ (114, 128, 250),
19
+ (0, 165, 255),
20
+ (0, 128, 0),
21
+ (144, 238, 144),
22
+ (238, 238, 175),
23
+ (255, 191, 0),
24
+ (0, 128, 0),
25
+ (226, 43, 138),
26
+ (255, 0, 255),
27
+ (0, 215, 255),
28
+ (255, 0, 0),
29
+ ]
30
+
31
+ color_map = {
32
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for color_id, color in enumerate(colors)
33
+ }
34
+
35
+
36
+ def is_overlapping(rect1, rect2):
37
+ x1, y1, x2, y2 = rect1
38
+ x3, y3, x4, y4 = rect2
39
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
40
+
41
+
42
+ def draw_entity_boxes_on_image(image, entities, show=False, save_path=None, entity_index=-1):
43
+ """_summary_
44
+ Args:
45
+ image (_type_): image or image path
46
+ collect_entity_location (_type_): _description_
47
+ """
48
+ if isinstance(image, Image.Image):
49
+ image_h = image.height
50
+ image_w = image.width
51
+ image = np.array(image)[:, :, [2, 1, 0]]
52
+ elif isinstance(image, str):
53
+ if os.path.exists(image):
54
+ pil_img = Image.open(image).convert("RGB")
55
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
56
+ image_h = pil_img.height
57
+ image_w = pil_img.width
58
+ else:
59
+ raise ValueError(f"invaild image path, {image}")
60
+ elif isinstance(image, torch.Tensor):
61
+ # pdb.set_trace()
62
+ image_tensor = image.cpu()
63
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
64
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
65
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
66
+ pil_img = torchvision_T.ToPILImage()(image_tensor)
67
+ image_h = pil_img.height
68
+ image_w = pil_img.width
69
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
70
+ else:
71
+ raise ValueError(f"invaild image format, {type(image)} for {image}")
72
+
73
+ if len(entities) == 0:
74
+ return image
75
+
76
+ indices = list(range(len(entities)))
77
+ if entity_index >= 0:
78
+ indices = [entity_index]
79
+
80
+ # Not to show too many bboxes
81
+ entities = entities[:len(color_map)]
82
+
83
+ new_image = image.copy()
84
+ previous_bboxes = []
85
+ # size of text
86
+ text_size = 1
87
+ # thickness of text
88
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
89
+ box_line = 3
90
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
91
+ base_height = int(text_height * 0.675)
92
+ text_offset_original = text_height - base_height
93
+ text_spaces = 3
94
+
95
+ # num_bboxes = sum(len(x[-1]) for x in entities)
96
+ used_colors = colors # random.sample(colors, k=num_bboxes)
97
+
98
+ color_id = -1
99
+ for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
100
+ color_id += 1
101
+ if entity_idx not in indices:
102
+ continue
103
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
104
+ # if start is None and bbox_id > 0:
105
+ # color_id += 1
106
+ orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
107
+
108
+ # draw bbox
109
+ # random color
110
+ color = used_colors[color_id] # tuple(np.random.randint(0, 255, size=3).tolist())
111
+ new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
112
+
113
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
114
+
115
+ x1 = orig_x1 - l_o
116
+ y1 = orig_y1 - l_o
117
+
118
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
119
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
120
+ x1 = orig_x1 + r_o
121
+
122
+ # add text background
123
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
124
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
125
+
126
+ for prev_bbox in previous_bboxes:
127
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
128
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
129
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
130
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
131
+
132
+ if text_bg_y2 >= image_h:
133
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
134
+ text_bg_y2 = image_h
135
+ y1 = image_h
136
+ break
137
+
138
+ alpha = 0.5
139
+ for i in range(text_bg_y1, text_bg_y2):
140
+ for j in range(text_bg_x1, text_bg_x2):
141
+ if i < image_h and j < image_w:
142
+ if j < text_bg_x1 + 1.35 * c_width:
143
+ # original color
144
+ bg_color = color
145
+ else:
146
+ # white
147
+ bg_color = [255, 255, 255]
148
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)
149
+
150
+ cv2.putText(
151
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
152
+ )
153
+ # previous_locations.append((x1, y1))
154
+ previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
155
+
156
+ pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
157
+ if save_path:
158
+ pil_image.save(save_path)
159
+ if show:
160
+ pil_image.show()
161
+
162
+ return pil_image
163
+
164
+ def load_kosmos_model(device):
165
+ ckpt = "ydshieh/kosmos-2-patch14-224"
166
+ kosmos_model = AutoModelForVision2Seq.from_pretrained(ckpt, trust_remote_code=True).to(device)
167
+ kosmos_processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
168
+ return kosmos_model, kosmos_processor
169
+
170
+ def kosmos_generate_predictions(image_input, text_input, kosmos_model, kosmos_processor):
171
+ if kosmos_model is None:
172
+ return None, None, None
173
+
174
+ # Save the image and load it again to match the original Kosmos-2 demo.
175
+ # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
176
+ user_image_path = "/tmp/user_input_test_image.jpg"
177
+ image_input.save(user_image_path)
178
+ # This might give different results from the original argument `image_input`
179
+ image_input = Image.open(user_image_path)
180
+
181
+ if text_input == "Brief":
182
+ text_input = "<grounding>An image of"
183
+ elif text_input == "Detailed":
184
+ text_input = "<grounding>Describe this image in detail:"
185
+ else:
186
+ text_input = f"<grounding>{text_input}"
187
+
188
+ inputs = kosmos_processor(text=text_input, images=image_input, return_tensors="pt")
189
+
190
+ generated_ids = kosmos_model.generate(
191
+ pixel_values=inputs["pixel_values"].to("cuda"),
192
+ input_ids=inputs["input_ids"][:, :-1].to("cuda"),
193
+ attention_mask=inputs["attention_mask"][:, :-1].to("cuda"),
194
+ img_features=None,
195
+ img_attn_mask=inputs["img_attn_mask"][:, :-1].to("cuda"),
196
+ use_cache=True,
197
+ max_new_tokens=128,
198
+ )
199
+ generated_text = kosmos_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
200
+
201
+ # By default, the generated text is cleanup and the entities are extracted.
202
+ processed_text, entities = kosmos_processor.post_process_generation(generated_text)
203
+
204
+ annotated_image = draw_entity_boxes_on_image(image_input, entities, show=False)
205
+
206
+ color_id = -1
207
+ entity_info = []
208
+ filtered_entities = []
209
+ for entity in entities:
210
+ entity_name, (start, end), bboxes = entity
211
+ if start == end:
212
+ # skip bounding bbox without a `phrase` associated
213
+ continue
214
+ color_id += 1
215
+ # for bbox_id, _ in enumerate(bboxes):
216
+ # if start is None and bbox_id > 0:
217
+ # color_id += 1
218
+ entity_info.append(((start, end), color_id))
219
+ filtered_entities.append(entity)
220
+
221
+ colored_text = []
222
+ prev_start = 0
223
+ end = 0
224
+ for idx, ((start, end), color_id) in enumerate(entity_info):
225
+ if start > prev_start:
226
+ colored_text.append((processed_text[prev_start:start], None))
227
+ colored_text.append((processed_text[start:end], f"{color_id}"))
228
+ prev_start = end
229
+
230
+ if end < len(processed_text):
231
+ colored_text.append((processed_text[end:len(processed_text)], None))
232
+
233
+ return annotated_image, colored_text, str(filtered_entities)
requirements.txt CHANGED
@@ -23,6 +23,7 @@ numba
23
  scipy
24
  safetensors
25
  pynvml
 
26
 
27
  lama-cleaner==1.1.2
28
  openmim==0.1.5
 
23
  scipy
24
  safetensors
25
  pynvml
26
+ sentencepiece
27
 
28
  lama-cleaner==1.1.2
29
  openmim==0.1.5