yeecin committed · Commit 20f03a5 · 1 Parent(s): 0862b0a
Upload app.py

app.py CHANGED
@@ -14,6 +14,7 @@
 # # Load the model and processor
 # # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 # # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
 # processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
 # model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
 # def generate_caption(image):
@@ -50,46 +51,53 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 _MODEL_PATH = 'IDEA-CCNL/Taiyi-BLIP-750M-Chinese'
 HF_TOKEN = os.getenv('HF_TOKEN')
-
-
-
+
+processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese", use_auth_token=HF_TOKEN)
+model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese", use_auth_token=HF_TOKEN).eval().to(device)
+
+# processor = BlipProcessor.from_pretrained(_MODEL_PATH, use_auth_token=HF_TOKEN)
+# model = BlipForConditionalGeneration.from_pretrained(
+#     _MODEL_PATH, use_auth_token=HF_TOKEN).eval().to(device)
 
 def inference(raw_image, model_n, strategy):
     if model_n == 'Image Captioning':
-
+        inputs = processor(raw_image, return_tensors="pt").to(device)
         with torch.no_grad():
             if strategy == "Beam search":
+                # Beam search: keep the num_beams highest-scoring candidate sequences at each step until generation ends
                 config = GenerationConfig(
                     do_sample=False,
                     num_beams=3,
                     max_length=50,
                     min_length=5,
                 )
-                captions = model.generate(**
+                captions = model.generate(**inputs, generation_config=config)
             else:
+                # Nucleus sampling (top-p): sample from the smallest token set whose cumulative probability exceeds p, renormalized, for more varied captions
                 config = GenerationConfig(
                     do_sample=True,
                     top_p=0.9,
                     max_length=50,
                     min_length=5,
                 )
-                captions = model.generate(**
+                captions = model.generate(**inputs, generation_config=config)
         caption = processor.decode(captions[0], skip_special_tokens=True)
         caption = caption.replace(' ', '')
         print(caption)
-        return
+        return caption
 
 inputs = [
     gr.Image(type='pil', label="Upload Image"),
-    gr.Radio(choices=['Image Captioning'], value="Image Captioning", label="Task")
-    gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Nucleus sampling", label="Caption Decoding Strategy")
+    gr.Radio(choices=['Image Captioning'], value="Image Captioning", label="Task"),  # task selection; currently only image captioning
+    gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Nucleus sampling", label="Caption Decoding Strategy")  # Beam search is more accurate, Nucleus sampling more varied
 ]
 outputs = gr.Textbox(label="Output")
 
-title = "
-description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (Salesforce Research). To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
-article = "<p style='text-align: center'><a href='https://github.com/IDEA-CCNL/Fengshenbang-LM' target='_blank'>Github Repo</a></p>"
+title = "图片描述生成器"
 
-gr.Interface(inference, inputs, outputs, title=title,
+gradio_app = gr.Interface(inference, inputs, outputs, title=title, examples=[
     ['demo.jpg', "Image Captioning", "Nucleus sampling"]
-])
+])
+
+if __name__ == "__main__":
+    gradio_app.launch()
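For reference, here is a minimal standalone sketch (not part of the commit) of the captioning path the new app.py wires up: load the Taiyi-BLIP processor and model, preprocess a PIL image, and decode with either the beam-search or the nucleus-sampling GenerationConfig shown above. The helper name caption_image is illustrative, and HF_TOKEN is only needed if the checkpoint requires authentication.

```python
import os

import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor, GenerationConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
token = os.getenv('HF_TOKEN')  # may be None for a public checkpoint

processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese", use_auth_token=token)
model = BlipForConditionalGeneration.from_pretrained(
    "IDEA-CCNL/Taiyi-BLIP-750M-Chinese", use_auth_token=token).eval().to(device)


def caption_image(image: Image.Image, strategy: str = "Nucleus sampling") -> str:
    """Generate one Chinese caption for a PIL image, mirroring app.py's inference()."""
    inputs = processor(image, return_tensors="pt").to(device)
    if strategy == "Beam search":
        # deterministic: keep the 3 best beams, caption length 5-50 tokens
        config = GenerationConfig(do_sample=False, num_beams=3, max_length=50, min_length=5)
    else:
        # stochastic: sample from the top-p (0.9) nucleus for more varied captions
        config = GenerationConfig(do_sample=True, top_p=0.9, max_length=50, min_length=5)
    with torch.no_grad():
        ids = model.generate(**inputs, generation_config=config)
    # the tokenizer emits spaces between Chinese characters; strip them as app.py does
    return processor.decode(ids[0], skip_special_tokens=True).replace(' ', '')


if __name__ == "__main__":
    print(caption_image(Image.open("demo.jpg").convert("RGB"), "Beam search"))
```

Keeping the model in eval() and wrapping generation in torch.no_grad(), as the commit does, avoids gradient bookkeeping during inference.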
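The two strategies exposed in the UI differ only in their GenerationConfig: beam search is deterministic and keeps the three highest-scoring beams, while nucleus sampling draws the next token from the smallest set whose cumulative probability exceeds top_p=0.9. A toy snippet (illustrative only, with a made-up next-token distribution) of what that top-p truncation amounts to:

```python
import torch

probs = torch.tensor([0.42, 0.30, 0.15, 0.08, 0.03, 0.02])  # hypothetical next-token probabilities
sorted_p, order = probs.sort(descending=True)
cutoff = int((sorted_p.cumsum(0) < 0.9).sum()) + 1  # smallest prefix reaching 0.9 mass (4 tokens here)
nucleus = order[:cutoff]
renormalized = probs[nucleus] / probs[nucleus].sum()
next_token = nucleus[torch.multinomial(renormalized, 1)]  # sample only from the nucleus
print(nucleus.tolist(), renormalized.tolist(), next_token.item())
```

Lowering top_p narrows the nucleus toward greedy decoding, while values near 1.0 approach plain sampling, which is why the 0.9 default in this commit trades some precision for more varied captions.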