Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +14 -4
my_model/KBVQA.py
CHANGED
@@ -99,6 +99,7 @@ class KBVQA:
|
|
99 |
|
100 |
self.captioner = ImageCaptioningModel()
|
101 |
self.captioner.load_model()
|
|
|
102 |
|
103 |
def get_caption(self, img: Image.Image) -> str:
|
104 |
"""
|
@@ -110,8 +111,9 @@ class KBVQA:
|
|
110 |
Returns:
|
111 |
str: The generated caption for the image.
|
112 |
"""
|
113 |
-
|
114 |
-
|
|
|
115 |
|
116 |
def load_detector(self, model: str) -> None:
|
117 |
"""
|
@@ -123,6 +125,7 @@ class KBVQA:
|
|
123 |
|
124 |
self.detector = ObjectDetector()
|
125 |
self.detector.load_model(model)
|
|
|
126 |
|
127 |
def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
|
128 |
"""
|
@@ -136,8 +139,11 @@ class KBVQA:
|
|
136 |
"""
|
137 |
|
138 |
image = self.detector.process_image(img)
|
|
|
139 |
detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=st.session_state['confidence_level'])
|
|
|
140 |
image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
|
|
|
141 |
return image_with_boxes, detected_objects_string
|
142 |
|
143 |
def load_fine_tuned_model(self) -> None:
|
@@ -150,6 +156,8 @@ class KBVQA:
|
|
150 |
low_cpu_mem_usage=True,
|
151 |
quantization_config=self.bnb_config,
|
152 |
token=self.access_token)
|
|
|
|
|
153 |
|
154 |
self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
|
155 |
use_fast=self.use_fast,
|
@@ -157,7 +165,7 @@ class KBVQA:
|
|
157 |
trust_remote_code=self.trust_remote,
|
158 |
add_eos_token=self.add_eos_token,
|
159 |
token=self.access_token)
|
160 |
-
|
161 |
|
162 |
@property
|
163 |
def all_models_loaded(self):
|
@@ -225,7 +233,7 @@ class KBVQA:
|
|
225 |
Returns:
|
226 |
str: The generated answer to the question.
|
227 |
"""
|
228 |
-
|
229 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
230 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
231 |
self.current_prompt_length = num_tokens
|
@@ -234,8 +242,10 @@ class KBVQA:
|
|
234 |
return
|
235 |
|
236 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
|
|
237 |
input_ids = model_inputs["input_ids"]
|
238 |
output_ids = self.kbvqa_model.generate(input_ids)
|
|
|
239 |
index = input_ids.shape[1] # needed to avoid printing the input prompt
|
240 |
history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
|
241 |
output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
|
|
|
99 |
|
100 |
self.captioner = ImageCaptioningModel()
|
101 |
self.captioner.load_model()
|
102 |
+
free_gpu_resources()
|
103 |
|
104 |
def get_caption(self, img: Image.Image) -> str:
|
105 |
"""
|
|
|
111 |
Returns:
|
112 |
str: The generated caption for the image.
|
113 |
"""
|
114 |
+
caption = self.captioner.generate_caption(img)
|
115 |
+
free_gpu_resources()
|
116 |
+
return caption
|
117 |
|
118 |
def load_detector(self, model: str) -> None:
|
119 |
"""
|
|
|
125 |
|
126 |
self.detector = ObjectDetector()
|
127 |
self.detector.load_model(model)
|
128 |
+
free_gpu_resources()
|
129 |
|
130 |
def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
|
131 |
"""
|
|
|
139 |
"""
|
140 |
|
141 |
image = self.detector.process_image(img)
|
142 |
+
free_gpu_resources()
|
143 |
detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=st.session_state['confidence_level'])
|
144 |
+
free_gpu_resources()
|
145 |
image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
|
146 |
+
free_gpu_resources()
|
147 |
return image_with_boxes, detected_objects_string
|
148 |
|
149 |
def load_fine_tuned_model(self) -> None:
|
|
|
156 |
low_cpu_mem_usage=True,
|
157 |
quantization_config=self.bnb_config,
|
158 |
token=self.access_token)
|
159 |
+
|
160 |
+
free_gpu_resources()
|
161 |
|
162 |
self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
|
163 |
use_fast=self.use_fast,
|
|
|
165 |
trust_remote_code=self.trust_remote,
|
166 |
add_eos_token=self.add_eos_token,
|
167 |
token=self.access_token)
|
168 |
+
free_gpu_resources()
|
169 |
|
170 |
@property
|
171 |
def all_models_loaded(self):
|
|
|
233 |
Returns:
|
234 |
str: The generated answer to the question.
|
235 |
"""
|
236 |
+
free_gpu_resources()
|
237 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
238 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
239 |
self.current_prompt_length = num_tokens
|
|
|
242 |
return
|
243 |
|
244 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
245 |
+
free_gpu_resources()
|
246 |
input_ids = model_inputs["input_ids"]
|
247 |
output_ids = self.kbvqa_model.generate(input_ids)
|
248 |
+
free_gpu_resources()
|
249 |
index = input_ids.shape[1] # needed to avoid printing the input prompt
|
250 |
history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
|
251 |
output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
|