Update my_model/tabs/run_inference.py
Browse files
my_model/tabs/run_inference.py
CHANGED
@@ -4,6 +4,7 @@ import bitsandbytes
|
|
4 |
import accelerate
|
5 |
import scipy
|
6 |
import copy
|
|
|
7 |
from PIL import Image
|
8 |
import torch.nn as nn
|
9 |
import pandas as pd
|
@@ -32,6 +33,7 @@ class InferenceRunner(StateManager):
|
|
32 |
# Display sample images as clickable thumbnails
|
33 |
self.col1.write("Choose from sample images:")
|
34 |
cols = self.col1.columns(len(self.sample_images))
|
|
|
35 |
for idx, sample_image_path in enumerate(self.sample_images):
|
36 |
with cols[idx]:
|
37 |
image = Image.open(sample_image_path)
|
@@ -108,7 +110,7 @@ class InferenceRunner(StateManager):
|
|
108 |
with st.container():
|
109 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
110 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
111 |
-
|
112 |
if st.session_state.button_label == "Load Model":
|
113 |
if self.is_model_loaded():
|
114 |
free_gpu_resources()
|
@@ -121,10 +123,12 @@ class InferenceRunner(StateManager):
|
|
121 |
|
122 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
123 |
force_reload_full_model = True
|
|
|
124 |
|
125 |
if load_fine_tuned_model:
|
126 |
free_gpu_resources()
|
127 |
self.load_model()
|
|
|
128 |
st.session_state['loading_in_progress'] = False
|
129 |
|
130 |
elif fine_tuned_model_already_loaded:
|
@@ -139,8 +143,11 @@ class InferenceRunner(StateManager):
|
|
139 |
|
140 |
elif force_reload_full_model:
|
141 |
free_gpu_resources()
|
|
|
142 |
self.force_reload_model()
|
|
|
143 |
st.session_state['loading_in_progress'] = False
|
|
|
144 |
|
145 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
146 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
@@ -148,8 +155,9 @@ class InferenceRunner(StateManager):
|
|
148 |
|
149 |
|
150 |
if self.is_model_loaded():
|
151 |
-
st.session_state['
|
152 |
free_gpu_resources()
|
|
|
153 |
self.image_qa_app(self.get_model())
|
154 |
st.write(st.session_state['loading_in_progress'])
|
155 |
|
|
|
4 |
import accelerate
|
5 |
import scipy
|
6 |
import copy
|
7 |
+
import time
|
8 |
from PIL import Image
|
9 |
import torch.nn as nn
|
10 |
import pandas as pd
|
|
|
33 |
# Display sample images as clickable thumbnails
|
34 |
self.col1.write("Choose from sample images:")
|
35 |
cols = self.col1.columns(len(self.sample_images))
|
36 |
+
st.write(st.session_state['loading_in_progress'])
|
37 |
for idx, sample_image_path in enumerate(self.sample_images):
|
38 |
with cols[idx]:
|
39 |
image = Image.open(sample_image_path)
|
|
|
110 |
with st.container():
|
111 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
112 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
113 |
+
t1=time.time()
|
114 |
if st.session_state.button_label == "Load Model":
|
115 |
if self.is_model_loaded():
|
116 |
free_gpu_resources()
|
|
|
123 |
|
124 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
125 |
force_reload_full_model = True
|
126 |
+
t1=time.time()
|
127 |
|
128 |
if load_fine_tuned_model:
|
129 |
free_gpu_resources()
|
130 |
self.load_model()
|
131 |
+
|
132 |
st.session_state['loading_in_progress'] = False
|
133 |
|
134 |
elif fine_tuned_model_already_loaded:
|
|
|
143 |
|
144 |
elif force_reload_full_model:
|
145 |
free_gpu_resources()
|
146 |
+
|
147 |
self.force_reload_model()
|
148 |
+
|
149 |
st.session_state['loading_in_progress'] = False
|
150 |
+
st.session_state['model_loaded'] = True
|
151 |
|
152 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
153 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
|
|
155 |
|
156 |
|
157 |
if self.is_model_loaded():
|
158 |
+
st.session_state['time_taken_to_load_model'] = time.time()-t1
|
159 |
free_gpu_resources()
|
160 |
+
st.session_state['loading_in_progress'] = False
|
161 |
self.image_qa_app(self.get_model())
|
162 |
st.write(st.session_state['loading_in_progress'])
|
163 |
|