Spaces:
Sleeping
Sleeping
Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +6 -0
my_model/KBVQA.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
import os
|
4 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
5 |
from typing import Optional
|
|
|
6 |
from my_model.captioner.image_captioning import ImageCaptioningModel
|
7 |
from my_model.object_detection import ObjectDetector
|
8 |
|
@@ -141,20 +142,25 @@ class KBVQA():
|
|
141 |
return output_text.capitalize()
|
142 |
|
143 |
def prepare_kbvqa_model(detection_model):
|
|
|
144 |
kbvqa = KBVQA()
|
145 |
# Progress bar for model loading
|
146 |
with st.spinner('Loading models...'):
|
|
|
147 |
progress_bar = st.progress(0)
|
148 |
kbvqa.load_detector(detection_model)
|
149 |
progress_bar.progress(33)
|
150 |
kbvqa.load_caption_model()
|
|
|
151 |
progress_bar.progress(66)
|
152 |
kbvqa.load_fine_tuned_model()
|
|
|
153 |
progress_bar.progress(100)
|
154 |
|
155 |
if kbvqa.all_models_loaded:
|
156 |
st.success('Model loaded successfully!')
|
157 |
kbvqa.kbvqa_model.eval()
|
|
|
158 |
return kbvqa
|
159 |
|
160 |
|
|
|
3 |
import os
|
4 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
5 |
from typing import Optional
|
6 |
+
from my_model.utilities import free_gpu_resources
|
7 |
from my_model.captioner.image_captioning import ImageCaptioningModel
|
8 |
from my_model.object_detection import ObjectDetector
|
9 |
|
|
|
142 |
return output_text.capitalize()
|
143 |
|
144 |
def prepare_kbvqa_model(detection_model):
|
145 |
+
free_gpu_resources()
|
146 |
kbvqa = KBVQA()
|
147 |
# Progress bar for model loading
|
148 |
with st.spinner('Loading models...'):
|
149 |
+
|
150 |
progress_bar = st.progress(0)
|
151 |
kbvqa.load_detector(detection_model)
|
152 |
progress_bar.progress(33)
|
153 |
kbvqa.load_caption_model()
|
154 |
+
free_gpu_resources()
|
155 |
progress_bar.progress(66)
|
156 |
kbvqa.load_fine_tuned_model()
|
157 |
+
free_gpu_resources()
|
158 |
progress_bar.progress(100)
|
159 |
|
160 |
if kbvqa.all_models_loaded:
|
161 |
st.success('Model loaded successfully!')
|
162 |
kbvqa.kbvqa_model.eval()
|
163 |
+
free_gpu_resources()
|
164 |
return kbvqa
|
165 |
|
166 |
|