praeclarumjj3 committed
Commit 20b4d0d • 1 Parent(s): 297e5e9

:zap: Fix version

Files changed (2):
  1. README.md +1 -1
  2. app.py +48 -38
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🔍
 colorFrom: blue
 colorTo: purple
 sdk: gradio
-sdk_version: 4.16.0
+sdk_version: 4.42.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py CHANGED
@@ -1,8 +1,7 @@
 import gradio as gr
-import spaces
 import torch
 import numpy as np
-
+import spaces
 from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
 
 from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
@@ -23,6 +22,14 @@ import math
 from transformers import TextIteratorStreamer
 from threading import Thread
 
+import subprocess
+# Install flash attention, skipping CUDA build if necessary
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
 def make_grid(pil_images, layer_indices=None):
     new_images = []
     new_captions = []
@@ -242,48 +249,51 @@ def regenerate(state, image_process_mode):
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
-@spaces.GPU
-def get_interm_outs(state):
-    prompt = state.get_prompt()
-    images = state.get_images(return_pil=True)
-    #prompt, image_args = process_image(prompt, images)
-
-    if images is not None and len(images) > 0:
-        if len(images) > 0:
-            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                raise ValueError("Number of images does not match number of <image> tokens in prompt")
-
-            #images = [load_image_from_base64(image) for image in images]
-            image_sizes = [image.size for image in images]
-            inp_images = process_images(images, image_processor, model.config)
-
-            if type(inp_images) is list:
-                inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
-            else:
-                inp_images = inp_images.to(model.device, dtype=torch.float16)
-        else:
-            inp_images = None
-            image_sizes = None
-        image_args = {"images": inp_images, "image_sizes": image_sizes}
-    else:
-        inp_images = None
-        image_args = {}
-
-    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-
-    interm_outs = model.get_visual_interpretations(
-        input_ids,
-        **image_args
-    )
-
-    depth_outs = get_depth_images(interm_outs, image_sizes[0])
-    seg_outs = get_seg_images(interm_outs, images[0])
-    gen_outs = get_gen_images(interm_outs)
-
-    return depth_outs, seg_outs, gen_outs
+# @spaces.GPU
+# def get_interm_outs(state):
 
 @spaces.GPU
-def generate(state, temperature, top_p, max_output_tokens):
+def generate(state, temperature, top_p, max_output_tokens, is_inter=False):
+    if is_inter:
+        prompt = state.get_prompt()
+        images = state.get_images(return_pil=True)
+        #prompt, image_args = process_image(prompt, images)
+
+        if images is not None and len(images) > 0:
+            if len(images) > 0:
+                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
+
+                #images = [load_image_from_base64(image) for image in images]
+                image_sizes = [image.size for image in images]
+                inp_images = process_images(images, image_processor, model.config)
+
+                if type(inp_images) is list:
+                    inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
+                else:
+                    inp_images = inp_images.to(model.device, dtype=torch.float16)
+            else:
+                inp_images = None
+                image_sizes = None
+            image_args = {"images": inp_images, "image_sizes": image_sizes}
+        else:
+            inp_images = None
+            image_args = {}
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+        interm_outs = model.get_visual_interpretations(
+            input_ids,
+            **image_args
+        )
+
+        depth_outs = get_depth_images(interm_outs, image_sizes[0])
+        seg_outs = get_seg_images(interm_outs, images[0])
+        gen_outs = get_gen_images(interm_outs)
+
+        return depth_outs, seg_outs, gen_outs
 
     prompt = state.get_prompt()
     images = state.get_images(return_pil=True)
     #prompt, image_args = process_image(prompt, images)
@@ -439,9 +449,9 @@ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as dem
     btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
 
     inter_vis_btn.click(
-        get_interm_outs,
+        generate,
         [state],
-        [depth_box, seg_box, gen_box],
+        [depth_box, seg_box, gen_box, True],
     )
 
     clear_btn.click(
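A note on the install snippet added in the second hunk: passing a bare dict as env= to subprocess.run replaces the child's entire environment rather than extending it, so PATH, HOME, and pip's cache settings are all dropped, and the command only resolves pip through the shell's fallback search path. A more robust variant (a sketch, not what the commit ships) copies the current environment first:

import os
import subprocess

# Extend the current environment instead of replacing it, so pip and any
# build tooling still resolve through PATH/HOME.
env = os.environ.copy()
env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"  # skip building CUDA kernels from source

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env=env,
    shell=True,
    check=True,  # raise CalledProcessError if the install fails
)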
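Separately, a pre-existing bug carries over verbatim into the new generate: in the list branch, the comprehension iterates the raw PIL images (which have no .to method) instead of the processed tensors, so that branch would raise AttributeError whenever process_images returns a list. The presumably intended line (a sketch, not part of the commit) is:

# Move the processed tensors, not the raw PIL images, to the device:
inp_images = [image.to(model.device, dtype=torch.float16) for image in inp_images]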
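Finally, the rewired inter_vis_btn.click has a wiring problem: Gradio expects every entry in the outputs list to be a component, so the literal True in [depth_box, seg_box, gen_box, True] never reaches generate as is_inter=True, and [state] alone leaves temperature, top_p, and max_output_tokens unbound. If the goal is to reuse generate for the intermediate-visualization path, one conventional pattern (a sketch; the temperature, top_p, and max_output_tokens components are assumed to exist elsewhere in the Blocks layout) is to pre-bind the flag:

from functools import partial

# Pre-bind is_inter=True so this click takes the intermediate-output path;
# Gradio passes input component values positionally, so the sampling
# controls must be listed as inputs as well.
inter_vis_btn.click(
    partial(generate, is_inter=True),
    [state, temperature, top_p, max_output_tokens],  # assumed existing components
    [depth_box, seg_box, gen_box],                   # outputs must be components
)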