Abhaykoul committed on
Commit
6757f4d
1 Parent(s): 33d8c37

Create app.py

Files changed (1)
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
from __future__ import annotations

import os
import hashlib
import torch
from threading import Thread
from transformers import AutoModel, AutoProcessor, TextIteratorStreamer
import gradio as gr

# Initialize the model and processor
def initialize_model_and_processor():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained(
        "OEvortex/HelpingAI-Vision", torch_dtype=torch.float16, trust_remote_code=True
    ).to(device)
    processor = AutoProcessor.from_pretrained("OEvortex/HelpingAI-Vision", trust_remote_code=True)
    return model, processor

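# Note: float16 weights are only reliably supported on GPU; on a CPU-only machine a
# common fallback is to pick the dtype from the device (untested sketch):
#
#   dtype = torch.float16 if torch.cuda.is_available() else torch.float32
#   model = AutoModel.from_pretrained("OEvortex/HelpingAI-Vision", torch_dtype=dtype,
#                                     trust_remote_code=True)
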
# Run an image through the vision tower and cache the projected features on disk,
# keyed by the image hash and the crop/token settings. (Defined here but not wired
# into answer_question below, which calls the processor directly.)
def cached_vision_process(image, max_crops, num_tokens):
    image_hash = hashlib.sha256(image.tobytes()).hexdigest()
    cache_path = f"visual_cache/{image_hash}-{max_crops}-{num_tokens}.pt"
    if os.path.exists(cache_path):
        return torch.load(cache_path).to(model.device, dtype=model.dtype)
    else:
        processor_outputs = processor.image_processor([image], max_crops)
        pixel_values = [value.to(model.device, model.dtype) for value in processor_outputs["pixel_values"]]
        coords = [value.to(model.device, model.dtype) for value in processor_outputs["coords"]]
        image_outputs = model.vision_model(pixel_values, coords, num_tokens)
        image_features = model.multi_modal_projector(image_outputs)
        os.makedirs("visual_cache", exist_ok=True)
        torch.save(image_features, cache_path)
        return image_features.to(model.device, model.dtype)

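# Note: since answer_question lets the processor recompute image features on every
# request, wiring this cache in would mean passing its output to generation as
# "image_features" instead (untested sketch, left as-is here).
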
# Answer a question about an image, streaming tokens into the UI as they arrive
def answer_question(image, question, max_crops, num_tokens, sample, temperature, top_k):
    if not question.strip() or image is None:
        # This function is a generator, so the message must be yielded,
        # not returned, for Gradio to display it.
        yield "Please provide both an image and a question."
        return

    # The <image> placeholder marks where the processor splices in the image features.
    prompt = f"""user
<image>
{question}
assistant
"""
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
    with torch.inference_mode():
        inputs = processor(prompt, [image], model, max_crops=max_crops, num_tokens=num_tokens)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "image_features": inputs["image_features"],
        "streamer": streamer,
        "max_length": 1000,
        "use_cache": True,
        "eos_token_id": processor.tokenizer.eos_token_id,
        "pad_token_id": processor.tokenizer.eos_token_id,
        "temperature": temperature,
        "do_sample": sample,
        "top_k": top_k,
    }
    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    output_started = False
    for new_text in streamer:
        # Skip the echoed prompt until the assistant turn begins.
        if not output_started:
            if "assistant" in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            yield buffer
    # Flush whatever is left once the stream ends.
    yield buffer

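# answer_question is a plain generator, so it can also be exercised outside Gradio,
# e.g. (sketch; "cat.jpg" is a placeholder path):
#
#   from PIL import Image
#   for partial in answer_question(Image.open("cat.jpg"), "Describe this?", 0, 728, False, 0, 0):
#       print(partial)
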
# Instantiate the model and processor once at startup
model, processor = initialize_model_and_processor()

# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Question", placeholder="e.g. Describe this?", scale=4)
            submit = gr.Button("Send", scale=1)
        with gr.Row():
            max_crops = gr.Slider(minimum=0, maximum=200, step=5, value=0, label="Max crops")
            num_tokens = gr.Slider(minimum=728, maximum=2184, step=10, value=728, label="Number of image tokens")
        with gr.Row():
            img = gr.Image(type="pil", label="Upload or Drag an Image")
            output = gr.TextArea(label="Answer")
        with gr.Row():
            sample = gr.Checkbox(label="Sample", value=False)
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0, label="Temperature")
            top_k = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Top-K")

    submit.click(answer_question, [img, prompt, max_crops, num_tokens, sample, temperature, top_k], output)
    prompt.submit(answer_question, [img, prompt, max_crops, num_tokens, sample, temperature, top_k], output)

demo.queue().launch(debug=True)
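
To try the Space locally: install the dependencies visible from the imports (torch, transformers, gradio; the model's trust_remote_code path may pull in more) and run "python app.py". The app launches a Gradio server and streams the answer into the Answer box as tokens are generated.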