Spaces: Running on Zero
praeclarumjj3 committed
Commit 20b4d0d • 1 Parent(s): 297e5e9

:zap: Fix version
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: π
 colorFrom: blue
 colorTo: purple
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.42.0
 app_file: app.py
 pinned: false
 license: apache-2.0
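These fields are the Space's YAML front matter; pinning sdk_version fixes the Gradio release the Space is built against, which is what the ":zap: Fix version" message refers to. If one wanted the app to notice runtime drift from this pin, a small guard could sit near the top of app.py (a hedged sketch; EXPECTED_GRADIO is an illustrative name, not part of this commit):

    import gradio as gr

    # Hypothetical guard: warn when the installed Gradio differs from the
    # sdk_version pinned in the README front matter (4.42.0 here).
    EXPECTED_GRADIO = "4.42.0"
    if gr.__version__ != EXPECTED_GRADIO:
        print(f"Warning: Gradio {gr.__version__} installed, README pins {EXPECTED_GRADIO}")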
app.py CHANGED
@@ -1,8 +1,7 @@
 import gradio as gr
-import spaces
 import torch
 import numpy as np
-
+import spaces
 from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
 
 from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
@@ -23,6 +22,14 @@ import math
 from transformers import TextIteratorStreamer
 from threading import Thread
 
+import subprocess
+# Install flash attention, skipping CUDA build if necessary
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
 def make_grid(pil_images, layer_indices=None):
     new_images = []
     new_captions = []
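The added block installs flash-attn when the Space boots; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells the flash-attn build to skip compiling CUDA kernels (falling back to a prebuilt wheel), which matters on ZeroGPU where no GPU is attached at startup. Note that env= replaces the child process's entire environment. A sketch of a variant that inherits the parent environment instead (an assumption about intent, not what the commit ships):

    import os
    import subprocess

    # Same install, but merged into the parent environment so PATH,
    # HOME, etc. stay visible to pip inside the shell.
    env = os.environ.copy()
    env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env=env,
        shell=True,
        check=True,  # fail loudly if the install does not succeed
    )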
@@ -242,48 +249,51 @@ def regenerate(state, image_process_mode):
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
-@spaces.GPU
-def get_interm_outs(state):
-
-    images = state.get_images(return_pil=True)
-    #prompt, image_args = process_image(prompt, images)
-
-    if images is not None and len(images) > 0:
-        if len(images) > 0:
-            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                raise ValueError("Number of images does not match number of <image> tokens in prompt")
-
-            #images = [load_image_from_base64(image) for image in images]
-            image_sizes = [image.size for image in images]
-            inp_images = process_images(images, image_processor, model.config)
-        else:
-            inp_images =
-
-        image_args = {"images": inp_images, "image_sizes": image_sizes}
-    else:
-        inp_images = None
-        image_args = {}
-
-@spaces.GPU
-def generate(state, temperature, top_p, max_output_tokens):
+# @spaces.GPU
+# def get_interm_outs(state):
+
+@spaces.GPU
+def generate(state, temperature, top_p, max_output_tokens, is_inter=False):
+    if is_inter:
+        prompt = state.get_prompt()
+        images = state.get_images(return_pil=True)
+        #prompt, image_args = process_image(prompt, images)
+
+        if images is not None and len(images) > 0:
+            if len(images) > 0:
+                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
+
+                #images = [load_image_from_base64(image) for image in images]
+                image_sizes = [image.size for image in images]
+                inp_images = process_images(images, image_processor, model.config)
+
+                if type(inp_images) is list:
+                    inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
+                else:
+                    inp_images = inp_images.to(model.device, dtype=torch.float16)
+            else:
+                inp_images = None
+                image_sizes = None
+            image_args = {"images": inp_images, "image_sizes": image_sizes}
+        else:
+            inp_images = None
+            image_args = {}
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+        interm_outs = model.get_visual_interpretations(
+            input_ids,
+            **image_args
+        )
+
+        depth_outs = get_depth_images(interm_outs, image_sizes[0])
+        seg_outs = get_seg_images(interm_outs, images[0])
+        gen_outs = get_gen_images(interm_outs)
+
+        return depth_outs, seg_outs, gen_outs
+
     prompt = state.get_prompt()
     images = state.get_images(return_pil=True)
     #prompt, image_args = process_image(prompt, images)
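A subtlety in the added preprocessing: process_images can hand back either a list of per-image tensors or a single batched tensor, hence the branch before moving inputs to the GPU in float16 (note the list branch in the diff iterates over images, the raw PIL list, rather than the processed inp_images). A minimal sketch of the intended idiom, under that assumption:

    import torch

    def to_model_device(inp_images, device, dtype=torch.float16):
        # process_images may return a list (images of mixed sizes) or one
        # stacked tensor (uniform sizes); move both forms to the device.
        if isinstance(inp_images, list):
            return [im.to(device, dtype=dtype) for im in inp_images]
        return inp_images.to(device, dtype=dtype)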
@@ -439,9 +449,9 @@ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as demo:
     btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
 
     inter_vis_btn.click(
-        get_interm_outs,
+        generate,
         [state],
-        [depth_box, seg_box, gen_box],
+        [depth_box, seg_box, gen_box, True],
     )
 
     clear_btn.click(
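On the wiring: gr.Button.click takes (fn, inputs, outputs), and outputs must be Gradio components, so the literal True appended to the outputs list will not reach generate's is_inter parameter. A common way to bind a constant argument is functools.partial (a hedged sketch; the sampling values below are placeholders introduced here, since this click passes only state as input):

    from functools import partial

    # Hypothetical rewiring: fix is_inter=True up front instead of
    # appending True to the outputs list. Placeholder sampling values
    # are bound because the click sends only `state`.
    inter_vis_btn.click(
        partial(generate, temperature=0.2, top_p=0.7,
                max_output_tokens=512, is_inter=True),
        [state],
        [depth_box, seg_box, gen_box],
    )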
|