Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +7 -3
my_model/KBVQA.py
CHANGED
@@ -224,7 +224,7 @@ class KBVQA:
|
|
224 |
return p
|
225 |
|
226 |
@staticmethod
|
227 |
-
def trim_objects(
|
228 |
"""
|
229 |
Trim the last object from the detected objects string.
|
230 |
|
@@ -257,7 +257,9 @@ class KBVQA:
|
|
257 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
258 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
259 |
self.current_prompt_length = num_tokens
|
260 |
-
|
|
|
|
|
261 |
while self.current_prompt_length > self.max_context_window:
|
262 |
detected_objects_str = self.trim_objects(detected_objects_str)
|
263 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
@@ -265,7 +267,9 @@ class KBVQA:
|
|
265 |
|
266 |
if detected_objects_str == "":
|
267 |
break # Break if no objects are left
|
268 |
-
|
|
|
|
|
269 |
|
270 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
271 |
free_gpu_resources()
|
|
|
224 |
return p
|
225 |
|
226 |
@staticmethod
|
227 |
+
def trim_objects(detected_objects_str):
|
228 |
"""
|
229 |
Trim the last object from the detected objects string.
|
230 |
|
|
|
257 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
258 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
259 |
self.current_prompt_length = num_tokens
|
260 |
+
if self.current_prompt_length > self.max_context_window:
|
261 |
+
trim = True
|
262 |
+
st.warning(f"Prompt length is {self.current_prompt_length} which is larger than the maximum context window of LLaMA-2, objects detected with low confidence will be removed one at a time until the prompt length is within the maximum context window ...")
|
263 |
while self.current_prompt_length > self.max_context_window:
|
264 |
detected_objects_str = self.trim_objects(detected_objects_str)
|
265 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
|
|
267 |
|
268 |
if detected_objects_str == "":
|
269 |
break # Break if no objects are left
|
270 |
+
if trim:
|
271 |
+
st.warning(f"New prompt length is: {self.current_prompt_length}")
|
272 |
+
trim = False
|
273 |
|
274 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
275 |
free_gpu_resources()
|