yeecin committed
Commit 20f03a5 (parent: 0862b0a)

Upload app.py

Files changed (1)
  1. app.py +22 -14
app.py CHANGED
@@ -14,6 +14,7 @@
 # # Load the model and processor
 # # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 # # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
 # processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
 # model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
 # def generate_caption(image):
@@ -50,46 +51,53 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 _MODEL_PATH = 'IDEA-CCNL/Taiyi-BLIP-750M-Chinese'
 HF_TOKEN = os.getenv('HF_TOKEN')
-processor = BlipProcessor.from_pretrained(_MODEL_PATH, use_auth_token=HF_TOKEN)
-model = BlipForConditionalGeneration.from_pretrained(
-    _MODEL_PATH, use_auth_token=HF_TOKEN).eval().to(device)
+
+processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese", use_auth_token=HF_TOKEN)
+model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese", use_auth_token=HF_TOKEN).eval().to(device)
+
+# processor = BlipProcessor.from_pretrained(_MODEL_PATH, use_auth_token=HF_TOKEN)
+# model = BlipForConditionalGeneration.from_pretrained(
+#     _MODEL_PATH, use_auth_token=HF_TOKEN).eval().to(device)
 
 def inference(raw_image, model_n, strategy):
     if model_n == 'Image Captioning':
-        input = processor(raw_image, return_tensors="pt").to(device)
+        inputs = processor(raw_image, return_tensors="pt").to(device)
         with torch.no_grad():
             if strategy == "Beam search":
+                # Beam search: expand several candidates per step, keep the k most probable sequences, and continue until generation ends
                 config = GenerationConfig(
                     do_sample=False,
                     num_beams=3,
                     max_length=50,
                     min_length=5,
                 )
-                captions = model.generate(**input, generation_config=config)
+                captions = model.generate(**inputs, generation_config=config)
             else:
+                # Nucleus sampling (top-p sampling): keep only the smallest set of tokens whose cumulative probability exceeds p, renormalize into a new distribution, and sample from it, which yields more diverse results
                 config = GenerationConfig(
                     do_sample=True,
                     top_p=0.9,
                     max_length=50,
                     min_length=5,
                 )
-                captions = model.generate(**input, generation_config=config)
+                captions = model.generate(**inputs, generation_config=config)
         caption = processor.decode(captions[0], skip_special_tokens=True)
         caption = caption.replace(' ', '')
         print(caption)
-        return 'caption: ' + caption
+        return caption
 
 inputs = [
     gr.Image(type='pil', label="Upload Image"),
-    gr.Radio(choices=['Image Captioning'], value="Image Captioning", label="Task"),
-    gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Nucleus sampling", label="Caption Decoding Strategy")
+    gr.Radio(choices=['Image Captioning'], value="Image Captioning", label="Task"),  # task selector; image captioning is currently the only task
+    gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Nucleus sampling", label="Caption Decoding Strategy")  # two decoding strategies: Beam search is more accurate, Nucleus sampling more varied
 ]
 outputs = gr.Textbox(label="Output")
 
-title = "BLIP"
-description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (Salesforce Research). To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
-article = "<p style='text-align: center'><a href='https://github.com/IDEA-CCNL/Fengshenbang-LM' target='_blank'>Github Repo</a></p>"
+title = "图片描述生成器"  # "Image Caption Generator"
 
-gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[
+gradio_app = gr.Interface(inference, inputs, outputs, title=title, examples=[
     ['demo.jpg', "Image Captioning", "Nucleus sampling"]
-]).launch()
+])
+
+if __name__ == "__main__":
+    gradio_app.launch()
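
To exercise the updated inference() path outside of Gradio, a minimal sketch follows; it is not part of the commit. It assumes app.py is importable from the working directory, that the repo's demo.jpg is present, and that HF_TOKEN is set if the checkpoint requires authentication; importing app loads the Taiyi-BLIP model, so the first call is slow.

# Minimal local check of the new inference() function (sketch only; file names and
# environment are assumptions, not part of this commit).
from PIL import Image
from app import inference  # safe to import: launch() is now behind the __main__ guard

img = Image.open("demo.jpg").convert("RGB")

# Beam search is deterministic: repeated calls return the same caption.
print(inference(img, "Image Captioning", "Beam search"))

# Nucleus sampling (top_p=0.9) is stochastic: repeated calls may differ.
print(inference(img, "Image Captioning", "Nucleus sampling"))

Because launch() now runs only under the __main__ guard, app.py can be imported (for example by tests) without starting a Gradio server, while running the file directly still launches the demo.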