praeclarumjj3 committed on
Commit 15341f5 · 1 Parent(s): c5a315a

:zap: Fix version

Files changed (1)
  1. app.py +36 -38
app.py CHANGED
@@ -249,51 +249,49 @@ def regenerate(state, image_process_mode):
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

-# @spaces.GPU
-# def get_interm_outs(state):
-
-
 @spaces.GPU
-def generate(state, temperature, top_p, max_output_tokens, is_inter=False):
-    if is_inter:
-        prompt = state.get_prompt()
-        images = state.get_images(return_pil=True)
-        #prompt, image_args = process_image(prompt, images)
-
-        if images is not None and len(images) > 0:
-            if len(images) > 0:
-                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
-
-                #images = [load_image_from_base64(image) for image in images]
-                image_sizes = [image.size for image in images]
-                inp_images = process_images(images, image_processor, model.config)
-
-                if type(inp_images) is list:
-                    inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
-                else:
-                    inp_images = inp_images.to(model.device, dtype=torch.float16)
+def get_interm_outs(state):
+    prompt = state.get_prompt()
+    images = state.get_images(return_pil=True)
+    #prompt, image_args = process_image(prompt, images)
+
+    if images is not None and len(images) > 0:
+        if len(images) > 0:
+            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+                raise ValueError("Number of images does not match number of <image> tokens in prompt")
+
+            #images = [load_image_from_base64(image) for image in images]
+            image_sizes = [image.size for image in images]
+            inp_images = process_images(images, image_processor, model.config)
+
+            if type(inp_images) is list:
+                inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
             else:
-                inp_images = None
-                image_sizes = None
-            image_args = {"images": inp_images, "image_sizes": image_sizes}
+                inp_images = inp_images.to(model.device, dtype=torch.float16)
         else:
             inp_images = None
-            image_args = {}
+            image_sizes = None
+        image_args = {"images": inp_images, "image_sizes": image_sizes}
+    else:
+        inp_images = None
+        image_args = {}

-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

-        interm_outs = model.get_visual_interpretations(
-            input_ids,
-            **image_args
-        )
-
-        depth_outs = get_depth_images(interm_outs, image_sizes[0])
-        seg_outs = get_seg_images(interm_outs, images[0])
-        gen_outs = get_gen_images(interm_outs)
+    interm_outs = model.get_visual_interpretations(
+        input_ids,
+        **image_args
+    )
+
+    depth_outs = get_depth_images(interm_outs, image_sizes[0])
+    seg_outs = get_seg_images(interm_outs, images[0])
+    gen_outs = get_gen_images(interm_outs)

-        return depth_outs, seg_outs, gen_outs
+    return depth_outs, seg_outs, gen_outs
+

+@spaces.GPU
+def generate(state, temperature, top_p, max_output_tokens):
     prompt = state.get_prompt()
     images = state.get_images(return_pil=True)
     #prompt, image_args = process_image(prompt, images)
@@ -451,7 +449,7 @@ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as dem
     inter_vis_btn.click(
         generate,
         [state],
-        [depth_box, seg_box, gen_box, True],
+        [depth_box, seg_box, gen_box],
     )

     clear_btn.click(
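
For context, a minimal, self-contained sketch of the click wiring this commit touches: a handler that takes the conversation state and returns three visualizations mapped to three output components. The stub handler, component types, and labels below are assumptions for illustration only; in the committed app.py the handler registered on `inter_vis_btn` is `generate`, the actual visualizations come from `get_interm_outs` via `model.get_visual_interpretations(...)`, and the handler is decorated with `@spaces.GPU`.

```python
import gradio as gr

# Hypothetical stand-in for get_interm_outs(state) in app.py, which runs
# model.get_visual_interpretations(...) and derives depth, segmentation,
# and generation previews from the intermediate outputs.
def get_interm_outs(state):
    depth_outs, seg_outs, gen_outs = None, None, None  # placeholders
    return depth_outs, seg_outs, gen_outs

with gr.Blocks(title="OLA-VLM") as demo:
    state = gr.State()
    # Component types and labels are assumed; the real app defines its own
    # layout inside the gr.Blocks context.
    inter_vis_btn = gr.Button("Visualize intermediate outputs")
    depth_box = gr.Image(label="Depth")
    seg_box = gr.Image(label="Segmentation")
    gen_box = gr.Image(label="Generation")

    # One handler in, three components out: the same input/output shape as
    # the inter_vis_btn.click(...) call updated by this commit.
    inter_vis_btn.click(get_interm_outs, [state], [depth_box, seg_box, gen_box])

if __name__ == "__main__":
    demo.launch()
```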