Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -12,17 +12,23 @@ from ultralytics import YOLO
|
|
12 |
import cv2
|
13 |
import chromadb
|
14 |
|
15 |
-
# CLIP 모델 로드
|
16 |
@st.cache_resource
|
17 |
def load_clip_model():
|
18 |
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
|
19 |
tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
|
|
|
|
|
|
|
|
|
|
|
20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
model.to(device)
|
|
|
22 |
return model, preprocess_val, tokenizer, device
|
23 |
|
24 |
clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
25 |
|
|
|
26 |
@st.cache_resource
|
27 |
def load_yolo_model():
|
28 |
return YOLO("./accessaries.pt")
|
@@ -116,12 +122,12 @@ if 'selected_category' not in st.session_state:
|
|
116 |
st.session_state.selected_category = None
|
117 |
|
118 |
# Streamlit app
|
119 |
-
st.title("
|
120 |
|
121 |
# 단계별 처리
|
122 |
if st.session_state.step == 'input':
|
123 |
st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
|
124 |
-
if st.button("Detect
|
125 |
if st.session_state.query_image_url:
|
126 |
query_image = load_image_from_url(st.session_state.query_image_url)
|
127 |
if query_image is not None:
|
|
|
12 |
import cv2
|
13 |
import chromadb
|
14 |
|
|
|
15 |
@st.cache_resource
|
16 |
def load_clip_model():
|
17 |
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
|
18 |
tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
|
19 |
+
|
20 |
+
# 파인튜닝한 모델의 state_dict 불러오기
|
21 |
+
state_dict = torch.load('./accessory_clip.pt', map_location=torch.device('cpu'))
|
22 |
+
model.load_state_dict(state_dict) # 모델에 state_dict 적용
|
23 |
+
|
24 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
model.to(device)
|
26 |
+
|
27 |
return model, preprocess_val, tokenizer, device
|
28 |
|
29 |
clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
30 |
|
31 |
+
|
32 |
@st.cache_resource
|
33 |
def load_yolo_model():
|
34 |
return YOLO("./accessaries.pt")
|
|
|
122 |
st.session_state.selected_category = None
|
123 |
|
124 |
# Streamlit app
|
125 |
+
st.title("Accessary Search App")
|
126 |
|
127 |
# 단계별 처리
|
128 |
if st.session_state.step == 'input':
|
129 |
st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
|
130 |
+
if st.button("Detect accesseary"):
|
131 |
if st.session_state.query_image_url:
|
132 |
query_image = load_image_from_url(st.session_state.query_image_url)
|
133 |
if query_image is not None:
|