yeecin committed on
Commit
0862b0a
1 Parent(s): 2280605

Upload app.py

Files changed (1)
  1. app.py +89 -29
app.py CHANGED
@@ -1,35 +1,95 @@
- import os
  import gradio as gr
- from transformers import BlipProcessor, BlipForConditionalGeneration
- from PIL import Image
- from transformers import CLIPProcessor, ChineseCLIPVisionModel, AutoProcessor
-
- # Set the HF_HOME and HF_ENDPOINT environment variables
- # os.environ['HF_HOME'] = 'D:/AI/OCR/img2text/models'
- # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
-
-
- # model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
- # processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
- # Load the model and processor
- # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
- processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
- model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
- def generate_caption(image):
-     inputs = processor(image, return_tensors="pt")
-     outputs = model(**inputs)
-     description = processor.decode(outputs[0], skip_special_tokens=True)
-     return description
-
- # Create the Gradio interface
- gradio_app = gr.Interface(
-     fn=generate_caption,
-     inputs=gr.Image(type="pil"),
-     outputs="text",
-     title="Image Caption Generator",
-     description="Upload an image to generate a caption."
- )
-
- if __name__ == "__main__":
-     gradio_app.launch()
+ # import os
+ # import gradio as gr
+ # from transformers import BlipProcessor, BlipForConditionalGeneration
+ # from PIL import Image
+ # from transformers import CLIPProcessor, ChineseCLIPVisionModel, AutoProcessor
+ #
+ # # Set the HF_HOME and HF_ENDPOINT environment variables
+ # # os.environ['HF_HOME'] = 'D:/AI/OCR/img2text/models'
+ # # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+ #
+ #
+ # # model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
+ # # processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
+ # # Load the model and processor
+ # # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ # # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+ # processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
+ # model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
+ # def generate_caption(image):
+ #     # Make sure image is a PIL.Image
+ #     if not isinstance(image, Image.Image):
+ #         raise ValueError("Input must be a PIL.Image")
+ #
+ #     inputs = processor(image, return_tensors="pt")
+ #     input_ids = inputs.get("input_ids")
+ #     if input_ids is None:
+ #         raise ValueError("Processor did not return input_ids")
+ #
+ #     outputs = model.generate(input_ids=input_ids, max_length=50)
+ #     description = processor.decode(outputs[0], skip_special_tokens=True)
+ #     return description
+ #
+ # # Create the Gradio interface
+ # gradio_app = gr.Interface(
+ #     fn=generate_caption,
+ #     inputs=gr.Image(type="pil"),
+ #     outputs="text",
+ #     title="Image Caption Generator",
+ #     description="Upload an image to generate a caption."
+ # )
+ #
+ # if __name__ == "__main__":
+ #     gradio_app.launch()
  import gradio as gr
+ import torch
+ import os
+ from transformers import BlipForConditionalGeneration, BlipProcessor, GenerationConfig
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+ _MODEL_PATH = 'IDEA-CCNL/Taiyi-BLIP-750M-Chinese'
+ HF_TOKEN = os.getenv('HF_TOKEN')
+ processor = BlipProcessor.from_pretrained(_MODEL_PATH, use_auth_token=HF_TOKEN)
+ model = BlipForConditionalGeneration.from_pretrained(
+     _MODEL_PATH, use_auth_token=HF_TOKEN).eval().to(device)

+ def inference(raw_image, model_n, strategy):
+     if model_n == 'Image Captioning':
+         input = processor(raw_image, return_tensors="pt").to(device)
+         with torch.no_grad():
+             if strategy == "Beam search":
+                 config = GenerationConfig(
+                     do_sample=False,
+                     num_beams=3,
+                     max_length=50,
+                     min_length=5,
+                 )
+                 captions = model.generate(**input, generation_config=config)
+             else:
+                 config = GenerationConfig(
+                     do_sample=True,
+                     top_p=0.9,
+                     max_length=50,
+                     min_length=5,
+                 )
+                 captions = model.generate(**input, generation_config=config)
+         caption = processor.decode(captions[0], skip_special_tokens=True)
+         caption = caption.replace(' ', '')
+         print(caption)
+         return 'caption: ' + caption

+ inputs = [
+     gr.Image(type='pil', label="Upload Image"),
+     gr.Radio(choices=['Image Captioning'], value="Image Captioning", label="Task"),
+     gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Nucleus sampling", label="Caption Decoding Strategy")
+ ]
+ outputs = gr.Textbox(label="Output")

+ title = "BLIP"
+ description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (Salesforce Research). To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://github.com/IDEA-CCNL/Fengshenbang-LM' target='_blank'>Github Repo</a></p>"

+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[
+     ['demo.jpg', "Image Captioning", "Nucleus sampling"]
+ ]).launch()
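
A quick local sanity check (an editor's sketch, not part of the commit): the snippet below exercises the same captioning path outside Gradio, assuming the Taiyi-BLIP checkpoint loaded above and a placeholder image file demo.jpg. The two GenerationConfig objects mirror the "Beam search" and "Nucleus sampling" options exposed by the radio button.

import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor, GenerationConfig

# Same checkpoint app.py loads; runs on GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_id = 'IDEA-CCNL/Taiyi-BLIP-750M-Chinese'
processor = BlipProcessor.from_pretrained(model_id)
model = BlipForConditionalGeneration.from_pretrained(model_id).eval().to(device)

# 'demo.jpg' is a placeholder path; substitute any local image.
image = Image.open('demo.jpg').convert('RGB')
inputs = processor(image, return_tensors='pt').to(device)

strategies = {
    'Beam search': GenerationConfig(do_sample=False, num_beams=3, max_length=50, min_length=5),
    'Nucleus sampling': GenerationConfig(do_sample=True, top_p=0.9, max_length=50, min_length=5),
}

with torch.no_grad():
    for name, config in strategies.items():
        ids = model.generate(**inputs, generation_config=config)
        # The Chinese tokenizer emits space-separated tokens; app.py strips the spaces too.
        caption = processor.decode(ids[0], skip_special_tokens=True).replace(' ', '')
        print(f'{name}: {caption}')

Beam search is deterministic and tends toward safer captions; nucleus sampling (the app's default) trades that for variety, which is why repeated runs can caption the same image differently.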