# img2text / app.py
# Uploaded by yeecin ("Upload app.py", commit 86fce4a).
# Hugging Face file-view residue: raw | history | blame | contribute | delete | No virus | 4.65 kB
# import os
# import gradio as gr
# from transformers import BlipProcessor ,BlipForConditionalGeneration
# from PIL import Image
# from transformers import CLIPProcessor, ChineseCLIPVisionModel ,AutoProcessor
#
# # 设置环境变量 HF_HOME 和 HF_ENDPOINT
# # os.environ['HF_HOME'] = 'D:/AI/OCR/img2text/models'
# # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
#
#
# # model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
# # processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
# # 加载模型和处理器
# # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
# model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
# def generate_caption(image):
# # 确保 image 是 PIL.Image 类型
# if not isinstance(image, Image.Image):
# raise ValueError("Input must be a PIL.Image")
#
# inputs = processor(image, return_tensors="pt")
# input_ids = inputs.get("input_ids")
# if input_ids is None:
# raise ValueError("Processor did not return input_ids")
#
# outputs = model.generate(input_ids=input_ids, max_length=50)
# description = processor.decode(outputs[0], skip_special_tokens=True)
# return description
#
# # 创建Gradio接口
# gradio_app = gr.Interface(
# fn=generate_caption,
# inputs=gr.Image(type="pil"),
# outputs="text",
# title="图片描述生成器",
# description="上传一张图片,生成相应的描述。"
# )
#
# if __name__ == "__main__":
# gradio_app.launch()
import gradio as gr
import torch
import os
from transformers import BlipForConditionalGeneration, BlipProcessor, GenerationConfig
# Log the installed PyTorch version at startup.
print(torch.__version__)
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hugging Face Hub id of the captioning checkpoint; the single source of truth
# for both the processor and the model loaded below.
_MODEL_PATH = 'IDEA-CCNL/Taiyi-BLIP-750M-Chinese'
# Optional Hub access token taken from the environment (None when unset).
HF_TOKEN = os.getenv('HF_TOKEN')
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token`; switch once the pinned transformers version supports it.
processor = BlipProcessor.from_pretrained(_MODEL_PATH, use_auth_token=HF_TOKEN)
# eval() switches the model to inference mode before it is moved to `device`.
model = BlipForConditionalGeneration.from_pretrained(
    _MODEL_PATH, use_auth_token=HF_TOKEN).eval().to(device)
def _generation_config(strategy):
    """Return the GenerationConfig for the chosen caption decoding *strategy*.

    Both strategies cap captions at 50 tokens and require at least 5.
    "Beam search" decodes deterministically; any other value falls back to
    nucleus (top-p) sampling, which is the UI's default choice.
    """
    if strategy == "Beam search":
        # Beam search keeps the 3 most probable partial sequences at each step
        # and returns the best finished one -> stable, conservative captions.
        return GenerationConfig(
            do_sample=False,
            num_beams=3,
            max_length=50,
            min_length=5,
        )
    # Nucleus (top-p) sampling: sample only from the smallest set of tokens
    # whose cumulative probability reaches 0.8, giving more varied captions.
    return GenerationConfig(
        do_sample=True,
        top_p=0.8,
        max_length=50,
        min_length=5,
    )


def inference(raw_image, model_n, strategy):
    """Caption *raw_image* with the loaded BLIP model.

    Args:
        raw_image: The uploaded image as delivered by the UI's
            ``gr.Image(type='pil')`` widget, or ``None`` if nothing was uploaded.
        model_n: Selected task; only ``'Image Captioning'`` is supported.
        strategy: ``'Beam search'`` for deterministic beam decoding; any other
            value (the UI offers ``'Nucleus sampling'``) uses top-p sampling.

    Returns:
        The decoded caption with every space removed.

    Raises:
        ValueError: If *model_n* is not a supported task.
        gr.Error: If no image was supplied (shown to the user by Gradio).
    """
    if model_n != 'Image Captioning':
        raise ValueError(f"Unsupported task: {model_n!r}")
    if raw_image is None:
        raise gr.Error("Please upload an image first.")

    inputs = processor(raw_image, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        captions = model.generate(**inputs, generation_config=_generation_config(strategy))
    caption = processor.decode(captions[0], skip_special_tokens=True)
    # NOTE(review): presumably decode() separates Chinese tokens with spaces;
    # strip them all so the caption reads as continuous text.
    caption = caption.replace(' ', '')
    print(caption)
    return caption
# Gradio input widgets, passed to inference() positionally in this order.
inputs = [
    # The image to caption, handed to inference() as a PIL image.
    gr.Image(type='pil', label="Upload Image"),
    # Task selector; image captioning is currently the only task.
    gr.Radio(
        choices=['Image Captioning'],
        value="Image Captioning",
        label="Task",
    ),
    # Decoding strategy: beam search tends to be more accurate, nucleus
    # sampling more varied.
    gr.Radio(
        choices=['Beam search', 'Nucleus sampling'],
        value="Nucleus sampling",
        label="Caption Decoding Strategy",
    ),
]
outputs = gr.Textbox(label="Output")
title = "图片描述生成器"
gradio_app = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    # One pre-filled example row: (image path, task, strategy).
    # Assumes demo.jpg ships alongside this script — confirm in the Space repo.
    examples=[['demo.jpg', "Image Captioning", "Nucleus sampling"]],
)

if __name__ == "__main__":
    gradio_app.launch()