"""Gradio demo: visual question answering with a fine-tuned moondream2 model.

The model's answer is streamed into a Markdown panel. When the answer contains
a bounding box written as ``[x1, y1, x2, y2]`` with coordinates expressed as
fractions of the image width/height, the box is drawn in red on a resized copy
of the uploaded image and shown in a second panel.
"""

import argparse
import re
from threading import Thread

import gradio as gr
import torch
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

parser = argparse.ArgumentParser()
parser.add_argument("--cpu", action="store_true", help="Use CPU for computation")
# parse_known_args() honours --cpu from the real command line while ignoring
# extra arguments that notebook kernels inject into sys.argv. The previous
# parse_args([]) discarded the command line entirely, so --cpu never worked.
args, _unknown_args = parser.parse_known_args()

# Determine device based on availability and the --cpu flag.
DEVICE = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16  # CPU doesn't support float16
LATEST_REVISION = "2024-05-20"
MODEL_ID = "yeshavyas27/moondream-ft"

# The smaller edge of the annotated image is resized to this many pixels
# (torchvision Resize with an int size keeps the aspect ratio).
ANNOTATION_SIZE = 768

# NOTE(review): the tokenizer comes from the upstream base model at a pinned
# revision, while the weights come from the fine-tune with no pinned revision;
# this assumes the fine-tune kept the base vocabulary - TODO confirm.
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision=LATEST_REVISION)
# trust_remote_code=True executes Python code shipped in the MODEL_ID repo.
moondream = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=DTYPE
).to(device=DEVICE)
moondream.eval()

# One number: optional minus sign, digits, optional fractional part (so both
# "0.25" and "1" are accepted).
_NUMBER = r"(-?\d+(?:\.\d+)?)"
# A bracketed, comma-separated group of exactly four numbers: "[a, b, c, d]".
BBOX_PATTERN = re.compile(rf"\[\s*{_NUMBER}\s*,\s*{_NUMBER}\s*,\s*{_NUMBER}\s*,\s*{_NUMBER}\s*\]")


def answer_question(img, prompt):
    """Stream the model's answer to ``prompt`` about ``img``.

    Yields the accumulated answer text after every new chunk so Gradio can
    render it progressively.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    image_embeds = moondream.encode_image(img)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    # Generation runs in a background thread; the streamer hands decoded text
    # back to this generator as it is produced.
    thread = Thread(
        target=moondream.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer stops iterating when generation ends; join so the worker
    # thread is not left dangling.
    thread.join()


def extract_floats(text):
    """Return the first ``[a, b, c, d]`` group of four numbers in ``text``.

    Returns:
        A list of four floats, or None if ``text`` contains no such group.
    """
    match = BBOX_PATTERN.search(text)
    if match is None:
        return None
    return [float(num) for num in match.groups()]


def extract_bbox(text):
    """Return ``(x1, y1, x2, y2)`` parsed from ``text``, or None if absent."""
    floats = extract_floats(text)
    return tuple(floats) if floats is not None else None


def process_answer(img, answer):
    """Draw the bounding box found in ``answer`` (if any) onto ``img``.

    Box coordinates are treated as fractions of the image width/height and
    scaled to the resized image.

    Returns:
        A ``gr.update`` showing the annotated image, or hiding the panel when
        there is no image or no bounding box in the answer.
    """
    bbox = extract_bbox(answer) if answer else None
    if img is None or bbox is None:
        return gr.update(visible=False, value=None)

    x1, y1, x2, y2 = bbox
    # Resize returns a new image, so the uploaded image is not mutated.
    draw_image = Resize(ANNOTATION_SIZE)(img)
    width, height = draw_image.size
    # Order the corners: Pillow rejects rectangles whose end precedes start.
    left, right = sorted((int(x1 * width), int(x2 * width)))
    top, bottom = sorted((int(y1 * height), int(y2 * height)))
    ImageDraw.Draw(draw_image).rectangle((left, top, right, bottom), outline="red", width=3)
    return gr.update(visible=True, value=draw_image)


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🌔 VQA Visual Question Answering
        """
    )
    with gr.Row():
        prompt = gr.Textbox(label="Input Prompt", placeholder="Type here...", scale=4)
        submit = gr.Button("Submit")
    with gr.Row():
        img = gr.Image(type="pil", label="Upload an Image")
        with gr.Column():
            output = gr.Markdown(label="Response")
            ann = gr.Image(visible=False, label="Annotated Image")

    submit.click(answer_question, [img, prompt], output)
    prompt.submit(answer_question, [img, prompt], output)
    # Re-check for a bounding box every time the streamed answer updates.
    output.change(process_answer, [img, output], ann, show_progress=False)

if __name__ == "__main__":
    demo.queue().launch(debug=True)