JoJosmin commited on
Commit
af1b60c
1 Parent(s): 84424e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -12,17 +12,23 @@ from ultralytics import YOLO
12
  import cv2
13
  import chromadb
14
 
15
- # CLIP 모델 로드
16
  @st.cache_resource
17
  def load_clip_model():
18
  model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
19
  tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
 
 
 
 
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  model.to(device)
 
22
  return model, preprocess_val, tokenizer, device
23
 
24
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
25
 
 
26
  @st.cache_resource
27
  def load_yolo_model():
28
  return YOLO("./accessaries.pt")
@@ -116,12 +122,12 @@ if 'selected_category' not in st.session_state:
116
  st.session_state.selected_category = None
117
 
118
  # Streamlit app
119
- st.title("Accessory Search App")
120
 
121
  # 단계별 처리
122
  if st.session_state.step == 'input':
123
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
124
- if st.button("Detect accessory"):
125
  if st.session_state.query_image_url:
126
  query_image = load_image_from_url(st.session_state.query_image_url)
127
  if query_image is not None:
 
12
  import cv2
13
  import chromadb
14
 
 
15
  @st.cache_resource
16
  def load_clip_model():
17
  model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
18
  tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
19
+
20
+ # 파인튜닝한 모델의 state_dict 불러오기
21
+ state_dict = torch.load('./accessory_clip.pt', map_location=torch.device('cpu'))
22
+ model.load_state_dict(state_dict) # 모델에 state_dict 적용
23
+
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
26
+
27
  return model, preprocess_val, tokenizer, device
28
 
29
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
30
 
31
+
32
  @st.cache_resource
33
  def load_yolo_model():
34
  return YOLO("./accessaries.pt")
 
122
  st.session_state.selected_category = None
123
 
124
  # Streamlit app
125
+ st.title("Accessary Search App")
126
 
127
  # 단계별 처리
128
  if st.session_state.step == 'input':
129
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
130
+ if st.button("Detect accesseary"):
131
  if st.session_state.query_image_url:
132
  query_image = load_image_from_url(st.session_state.query_image_url)
133
  if query_image is not None: