RamAnanth1 commited on
Commit
732325f
1 Parent(s): 0d4f09e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
from lavis.models import load_model_and_preprocess
import torch

# Prefer the GPU when present; "cpu" is passed as a plain string, which
# torch and LAVIS both accept.
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

# InstructBLIP: BLIP-2 vision encoder + instruction-tuned Flan-T5-XL LM.
model_name = "blip2_t5_instruct"
model_type = "flant5xl"

# BUG FIX: the original passed args.model_name / args.model_type, but no
# `args` object exists in this script (NameError at startup). Use the
# module-level constants defined above instead.
model, vis_processors, _ = load_model_and_preprocess(
    name=model_name,
    model_type=model_type,
    is_eval=True,        # inference only — disables dropout etc.
    device=device,
)
def infer(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
    """Generate a text response for one image/prompt pair with InstructBLIP.

    `decoding_method` is the radio-button label: "Nucleus sampling" switches
    the generator to sampling; anything else means beam search.
    Returns the first (and only) generated string.
    """
    nucleus = decoding_method == "Nucleus sampling"
    # Debug trace of every incoming UI value (kept from the original app).
    print(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, nucleus)

    # PIL image -> normalized tensor with a leading batch dim, on `device`.
    pixels = vis_processors["eval"](image).unsqueeze(0).to(device)

    batch = {"image": pixels, "prompt": prompt}

    captions = model.generate(
        batch,
        length_penalty=float(len_penalty),
        repetition_penalty=float(repetition_penalty),
        num_beams=beam_size,
        max_length=max_len,
        min_length=min_len,
        top_p=top_p,
        use_nucleus_sampling=nucleus,
    )
    return captions[0]
# Font fallback chain: Google-hosted Open Sans, then generic system fonts.
_FONT_STACK = [gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"]

# Monochrome base theme with indigo/blue accents and small corner radii.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_FONT_STACK,
)

# Hide Gradio's intermediate "generating" indicator.
css = ".generating {visibility: hidden}"
# UI layout: a wide column with image/prompt/output and a narrow column of
# generation hyper-parameter controls, wired to `infer` via the Run button.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column(scale=3):
        image_input = gr.Image(type="pil")
        prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2)
        output = gr.Textbox(label="Output")
        submit = gr.Button("Run", variant="primary")

    with gr.Column(scale=1):
        min_len = gr.Slider(minimum=1, maximum=50, value=1, step=1, interactive=True, label="Min Length")
        max_len = gr.Slider(minimum=10, maximum=500, value=250, step=5, interactive=True, label="Max Length")
        sampling = gr.Radio(
            choices=["Beam search", "Nucleus sampling"],
            value="Beam search",
            label="Text Decoding Method",
            interactive=True,
        )
        top_p = gr.Slider(minimum=0.5, maximum=1.0, value=0.9, step=0.1, interactive=True, label="Top p")
        beam_size = gr.Slider(minimum=1, maximum=10, value=5, step=1, interactive=True, label="Beam Size")
        len_penalty = gr.Slider(minimum=-1, maximum=2, value=1, step=0.2, interactive=True, label="Length Penalty")
        repetition_penalty = gr.Slider(
            minimum=-1, maximum=3, value=1, step=0.2, interactive=True, label="Repetition Penalty"
        )

    # Input order must match infer()'s signature exactly.
    submit.click(
        infer,
        inputs=[image_input, prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling],
        outputs=[output],
    )

demo.queue(concurrency_count=16).launch(debug=True)