Spaces:
Running
on
Zero
Running
on
Zero
praeclarumjj3
committed on
Commit
·
15341f5
1
Parent(s):
c5a315a
:zap: Fix version
Browse files
app.py
CHANGED
@@ -249,51 +249,49 @@ def regenerate(state, image_process_mode):
|
|
249 |
state.skip_next = False
|
250 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
251 |
|
252 |
-
# @spaces.GPU
|
253 |
-
# def get_interm_outs(state):
|
254 |
-
|
255 |
-
|
256 |
@spaces.GPU
|
257 |
-
def
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
if
|
264 |
-
if len(images)
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
|
274 |
-
else:
|
275 |
-
inp_images = inp_images.to(model.device, dtype=torch.float16)
|
276 |
else:
|
277 |
-
inp_images =
|
278 |
-
image_sizes = None
|
279 |
-
image_args = {"images": inp_images, "image_sizes": image_sizes}
|
280 |
else:
|
281 |
inp_images = None
|
282 |
-
|
|
|
|
|
|
|
|
|
283 |
|
284 |
-
|
285 |
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
|
295 |
-
|
|
|
296 |
|
|
|
|
|
297 |
prompt = state.get_prompt()
|
298 |
images = state.get_images(return_pil=True)
|
299 |
#prompt, image_args = process_image(prompt, images)
|
@@ -451,7 +449,7 @@ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as dem
|
|
451 |
inter_vis_btn.click(
|
452 |
generate,
|
453 |
[state],
|
454 |
-
[depth_box, seg_box, gen_box
|
455 |
)
|
456 |
|
457 |
clear_btn.click(
|
|
|
249 |
state.skip_next = False
|
250 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
251 |
|
|
|
|
|
|
|
|
|
252 |
@spaces.GPU
def get_interm_outs(state):
    """Build the depth, segmentation and generation visualizations for a conversation.

    The prompt and PIL images are read from ``state``, the images are
    preprocessed and moved to the model's device, and the model's
    ``get_visual_interpretations`` output is converted into three results.

    Args:
        state: Conversation object providing ``get_prompt()`` and
            ``get_images(return_pil=True)``.

    Returns:
        A ``(depth_outs, seg_outs, gen_outs)`` tuple as produced by
        ``get_depth_images``, ``get_seg_images`` and ``get_gen_images``.

    Raises:
        ValueError: If the conversation has no image, or if the number of
            images differs from the number of image tokens in the prompt.
    """
    prompt = state.get_prompt()
    images = state.get_images(return_pil=True)

    # The depth and segmentation outputs are derived from the first image
    # (image_sizes[0] / images[0] below), so fail early with a clear message
    # instead of hitting an unbound-variable error after the model call.
    if not images:
        raise ValueError("At least one image is required to visualize intermediate outputs")

    if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
        raise ValueError("Number of images does not match number of <image> tokens in prompt")

    image_sizes = [image.size for image in images]
    inp_images = process_images(images, image_processor, model.config)

    # Both a list result and a single object with ``.to`` are handled here.
    # The list branch must iterate the *processed* images: the raw PIL inputs
    # in ``images`` have no ``.to`` method.
    if isinstance(inp_images, list):
        inp_images = [image.to(model.device, dtype=torch.float16) for image in inp_images]
    else:
        inp_images = inp_images.to(model.device, dtype=torch.float16)

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        .unsqueeze(0)
        .to(model.device)
    )

    interm_outs = model.get_visual_interpretations(
        input_ids,
        images=inp_images,
        image_sizes=image_sizes,
    )

    depth_outs = get_depth_images(interm_outs, image_sizes[0])
    seg_outs = get_seg_images(interm_outs, images[0])
    gen_outs = get_gen_images(interm_outs)

    return depth_outs, seg_outs, gen_outs
|
291 |
+
|
292 |
|
293 |
+
@spaces.GPU
|
294 |
+
def generate(state, temperature, top_p, max_output_tokens):
|
295 |
prompt = state.get_prompt()
|
296 |
images = state.get_images(return_pil=True)
|
297 |
#prompt, image_args = process_image(prompt, images)
|
|
|
449 |
# The visualization button supplies only `state` and fills three output boxes,
# which matches get_interm_outs(state) -> (depth, seg, gen). `generate` takes
# (state, temperature, top_p, max_output_tokens) and cannot be called with
# this single input.
inter_vis_btn.click(
    get_interm_outs,
    [state],
    [depth_box, seg_box, gen_box],
)
|
454 |
|
455 |
clear_btn.click(
|