RamAnanth1 committed on
Commit
d019ade
1 Parent(s): 5d898d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -66
app.py CHANGED
@@ -11,6 +11,7 @@ model, vis_processors, _ = load_model_and_preprocess(
11
  model_type=model_type,
12
  is_eval=True,
13
  device=device,
 
14
  )
15
 
16
  def infer(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
@@ -53,73 +54,74 @@ with gr.Blocks(theme=theme, analytics_enabled=False,css=css) as demo:
53
  The demo is based on the official <a href="https://github.com/salesforce/LAVIS/tree/main/projects/instructblip" style="text-decoration: underline;" target="_blank"> Github </a> implementation
54
  """
55
  )
56
- with gr.Column(scale=3):
57
- image_input = gr.Image(type="pil")
58
- prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2)
59
- output = gr.Textbox(label="Output")
60
- submit = gr.Button("Run", variant="primary")
 
61
 
62
- with gr.Column(scale=1):
63
- min_len = gr.Slider(
64
- minimum=1,
65
- maximum=50,
66
- value=1,
67
- step=1,
68
- interactive=True,
69
- label="Min Length",
70
- )
71
-
72
- max_len = gr.Slider(
73
- minimum=10,
74
- maximum=500,
75
- value=250,
76
- step=5,
77
- interactive=True,
78
- label="Max Length",
79
- )
80
-
81
- sampling = gr.Radio(
82
- choices=["Beam search", "Nucleus sampling"],
83
- value="Beam search",
84
- label="Text Decoding Method",
85
- interactive=True,
86
- )
87
-
88
- top_p = gr.Slider(
89
- minimum=0.5,
90
- maximum=1.0,
91
- value=0.9,
92
- step=0.1,
93
- interactive=True,
94
- label="Top p",
95
- )
96
-
97
- beam_size = gr.Slider(
98
- minimum=1,
99
- maximum=10,
100
- value=5,
101
- step=1,
102
- interactive=True,
103
- label="Beam Size",
104
- )
105
-
106
- len_penalty = gr.Slider(
107
- minimum=-1,
108
- maximum=2,
109
- value=1,
110
- step=0.2,
111
- interactive=True,
112
- label="Length Penalty",
113
- )
114
-
115
- repetition_penalty = gr.Slider(
116
- minimum=-1,
117
- maximum=3,
118
- value=1,
119
- step=0.2,
120
- interactive=True,
121
- label="Repetition Penalty",
122
- )
123
 
124
  submit.click(infer, inputs=[image_input, prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling], outputs=[output])
125
 
 
11
  model_type=model_type,
12
  is_eval=True,
13
  device=device,
14
+ dtype=torch.float16
15
  )
16
 
17
  def infer(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
 
54
  The demo is based on the official <a href="https://github.com/salesforce/LAVIS/tree/main/projects/instructblip" style="text-decoration: underline;" target="_blank"> Github </a> implementation
55
  """
56
  )
57
+ with gr.Row():
58
+ with gr.Column(scale=3):
59
+ image_input = gr.Image(type="pil")
60
+ prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2)
61
+ output = gr.Textbox(label="Output")
62
+ submit = gr.Button("Run", variant="primary")
63
 
64
+ with gr.Column(scale=1):
65
+ min_len = gr.Slider(
66
+ minimum=1,
67
+ maximum=50,
68
+ value=1,
69
+ step=1,
70
+ interactive=True,
71
+ label="Min Length",
72
+ )
73
+
74
+ max_len = gr.Slider(
75
+ minimum=10,
76
+ maximum=500,
77
+ value=250,
78
+ step=5,
79
+ interactive=True,
80
+ label="Max Length",
81
+ )
82
+
83
+ sampling = gr.Radio(
84
+ choices=["Beam search", "Nucleus sampling"],
85
+ value="Beam search",
86
+ label="Text Decoding Method",
87
+ interactive=True,
88
+ )
89
+
90
+ top_p = gr.Slider(
91
+ minimum=0.5,
92
+ maximum=1.0,
93
+ value=0.9,
94
+ step=0.1,
95
+ interactive=True,
96
+ label="Top p",
97
+ )
98
+
99
+ beam_size = gr.Slider(
100
+ minimum=1,
101
+ maximum=10,
102
+ value=5,
103
+ step=1,
104
+ interactive=True,
105
+ label="Beam Size",
106
+ )
107
+
108
+ len_penalty = gr.Slider(
109
+ minimum=-1,
110
+ maximum=2,
111
+ value=1,
112
+ step=0.2,
113
+ interactive=True,
114
+ label="Length Penalty",
115
+ )
116
+
117
+ repetition_penalty = gr.Slider(
118
+ minimum=-1,
119
+ maximum=3,
120
+ value=1,
121
+ step=0.2,
122
+ interactive=True,
123
+ label="Repetition Penalty",
124
+ )
125
 
126
  submit.click(infer, inputs=[image_input, prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling], outputs=[output])
127