huyingfan committed on
Commit
7e02c9f
1 Parent(s): bd0b8c1
Files changed (2) hide show
  1. app.py +191 -0
  2. imagenet-val.zip +3 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import math
3
+ import os.path as osp
4
+
5
+ import numpy as np
6
+ import requests
7
+ import streamlit as st
8
+ from mmengine.dataset import Compose, default_collate
9
+ from mmengine.fileio import list_from_file
10
+ from mmengine.registry import init_default_scope
11
+ from PIL import Image
12
+ import mmengine
13
+ import logging
14
+ from mmengine.logging.logger import MMFormatter
15
+ from mmcls import list_models as list_models_
16
+ from mmcls.apis.model import ModelHub, init_model
17
+
18
+
19
@st.cache()
def prepare_data():
    """Extract the bundled ``imagenet-val.zip`` into the working directory.

    The archive is unpacked with ``unzip -n``, which skips files that already
    exist instead of overwriting them. The exit status of ``unzip`` is not
    inspected.
    """
    from subprocess import run

    unzip_cmd = ['unzip', '-n', 'imagenet-val.zip']
    run(unzip_cmd)
23
+
24
+
25
@st.cache()
def load_demo_image():
    """Download the demo bird picture and return it as an RGB PIL image.

    Returns:
        PIL.Image.Image: The downloaded picture converted to RGB mode.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    from io import BytesIO

    url = 'https://github.com/open-mmlab/mmclassification/blob/master/demo/bird.JPEG?raw=true'  # noqa
    # A timeout keeps a stalled connection from blocking the app forever, and
    # the status check stops an error page from being parsed as an image.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Decode from the fully downloaded bytes so the connection is released
    # before PIL starts reading.
    img = Image.open(BytesIO(response.content)).convert('RGB')
    return img
32
+
33
+
34
@st.cache()
def list_models(*args, **kwargs):
    """Return the model names reported by ``mmcls.list_models``, sorted.

    All positional and keyword arguments are forwarded unchanged to
    ``mmcls.list_models``; its result is returned as a new list in ascending
    order.
    """
    names = list(list_models_(*args, **kwargs))
    names.sort()
    return names
37
+
38
+
39
# Directory that the relative data paths below are resolved against.
DATA_ROOT = '.'
# Annotation list, relative to DATA_ROOT.
ANNO_FILE = 'meta/val.txt'
# File that receives the log records written while a query is processed.
LOG_FILE = 'demo.log'
42
+
43
+
44
def get_model(model_name, pretrained=True, device='cuda'):
    """Build an ``ImageToImageRetriever`` around a model from the mmcls hub.

    The backbone of the hub config (plus its neck, when one is configured)
    becomes the retriever's image encoder, and the retriever's prototype is
    pointed at ``cache/{model_name}_prototype.pt``.

    Args:
        model_name (str): Name of the model to look up in ``ModelHub``.
        pretrained (bool): Whether to initialise the backbone from the
            weights listed in the hub metainfo. Defaults to True.
        device (str): Device passed to ``init_model``. Defaults to 'cuda'.

    Returns:
        The model returned by ``init_model``, after ``init_weights()`` and
        ``prepare_prototype()`` have been called on it.

    Raises:
        ValueError: If ``pretrained`` is True but the hub metainfo lists no
            weights for ``model_name``.
    """
    metainfo = ModelHub.get(model_name)

    if pretrained and metainfo.weights is None:
        raise ValueError(
            f"The model {model_name} doesn't have pretrained weights.")
    ckpt = metainfo.weights if pretrained else None

    cfg = metainfo.config
    backbone_cfg = cfg.model.backbone
    if ckpt is not None:
        # Only request a checkpoint load when there is a checkpoint: a
        # 'Pretrained' init_cfg with checkpoint=None cannot be loaded.
        backbone_cfg.init_cfg = dict(
            type='Pretrained', checkpoint=ckpt, prefix='backbone')

    new_model_cfg = dict(type='ImageToImageRetriever')
    if hasattr(cfg.model, 'neck') and cfg.model.neck is not None:
        new_model_cfg['image_encoder'] = [backbone_cfg, cfg.model.neck]
    else:
        new_model_cfg['image_encoder'] = backbone_cfg
    cfg.model = new_model_cfg

    # prepare prototype
    cached_path = f'cache/{model_name}_prototype.pt'  # noqa
    cfg.model.prototype = cached_path

    model = init_model(cfg, None, device=device)
    with st.spinner(f'Loading model {model_name} on the server...This is '
                    'slow at the first time.'):
        model.init_weights()
    st.success('Model loaded!')

    with st.spinner('Preparing prototype for all image...This is '
                    'slow at the first time.'):
        model.prepare_prototype()

    return model
82
+
83
+
84
def get_pred(name, img):
    """Retrieve the gallery images that the model matches to a query image.

    Args:
        name (str): Model name handed to ``get_model``.
        img (str | np.ndarray | PIL.Image.Image): The query image, given as
            a file path, an array, or a PIL image.

    Returns:
        list[tuple]: ``(image_path, score)`` pairs, one per predicted label,
        in the order of the model's predicted labels. Each path is built
        from the first whitespace-separated field of the matching line in
        the annotation file.

    Raises:
        TypeError: If ``img`` is none of the supported types.
    """
    # Reject unsupported inputs up front instead of failing later on an
    # unbound ``data`` variable after the model has already been built.
    if not isinstance(img, (str, np.ndarray, Image.Image)):
        raise TypeError(
            'img must be a file path, a numpy array or a PIL image, '
            f'but got {type(img).__name__}.')

    logger = mmengine.logging.MMLogger.get_current_instance()
    file_handler = logging.FileHandler(LOG_FILE, 'w')
    # The file handler writes timestamps as year/month/day hour:minute:second
    # and disables colour so no escape codes end up in the log file.
    file_handler.setFormatter(
        MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
    file_handler.setLevel('INFO')
    logger.addHandler(file_handler)

    try:
        init_default_scope('mmcls')

        model = get_model(name)

        cfg = model.cfg
        # build the data pipeline
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        if isinstance(img, str):
            if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
                test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
            data = dict(img_path=img)
        elif isinstance(img, np.ndarray):
            if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
                test_pipeline_cfg.pop(0)
            data = dict(img=img)
        else:
            # PIL image: swap the file loader for an in-memory conversion.
            if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
                test_pipeline_cfg[0] = dict(type='ToNumpy', keys=['img'])
            data = dict(img=img)

        test_pipeline = Compose(test_pipeline_cfg)
        data = test_pipeline(data)
        data = default_collate([data])

        # Run inference once and reuse the result for labels and scores.
        prediction = model.val_step(data)[0].pred_label
        labels = prediction.label
        scores = prediction.score[labels]

        image_list = list_from_file(osp.join(DATA_ROOT, ANNO_FILE))
        data_root = osp.join(DATA_ROOT, 'val')
        return [(osp.join(data_root, image_list[idx].rsplit()[0]), score)
                for idx, score in zip(labels, scores)]
    finally:
        # Detach and close the per-call handler so repeated searches do not
        # pile up handlers and open file descriptors on the shared logger.
        logger.removeHandler(file_handler)
        file_handler.close()
128
+
129
+
130
def app():
    """Render the Streamlit page for image-to-image retrieval.

    The page shows a model selector and a top-k input in the sidebar, an
    uploader for the query image (the demo image is used when nothing is
    uploaded), a preview of the query scaled to 360 pixels wide, and a grid
    with the retrieved images and their scores. A search runs when the
    'Search' button is pressed or when top-k is greater than 1.
    """
    prepare_data()

    model_name = st.sidebar.selectbox("Model:", ['resnet50_8xb32_in1k'])

    st.markdown(
        "<h1 style='text-align: center;'>Image To Image Retrieval</h1>",
        unsafe_allow_html=True,
    )

    file = st.file_uploader(
        'Please upload your own image or use the provided:')

    container1 = st.container()
    if file:
        raw_img = Image.open(file).convert('RGB')
    else:
        raw_img = load_demo_image()

    container1.header('Image')

    # Scale the preview to a fixed width of 360 px, keeping the aspect ratio.
    w, h = raw_img.size
    scaling_factor = 360 / w
    resized_image = raw_img.resize(
        (int(w * scaling_factor), int(h * scaling_factor)))

    container1.image(resized_image, use_column_width='auto')
    button = container1.button('Search')

    st.header('Results')

    topk = st.sidebar.number_input('Topk(1-50)', min_value=1, max_value=50)

    # search on both selection of topk and button
    if button or topk > 1:

        result_list = get_pred(model_name, raw_img)
        # auto adjust number of images in a row but 5 at most.
        col = min(int(math.sqrt(topk)), 5)
        row = math.ceil(topk / col)

        grid = []
        for _ in range(row):
            with st.container():
                grid.append(st.columns(col))

        # Flatten the rows into one list of cells and drop the surplus ones.
        grid = list(itertools.chain.from_iterable(grid))[:topk]

        for cell, (image_path, score) in zip(grid, result_list[:topk]):
            image = Image.open(image_path).convert('RGB')

            cell.caption('Score: {:.4f}'.format(float(score)))
            cell.image(image)
188
+
189
+
190
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    app()
imagenet-val.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bdbaf0b147e7d6a598619a2a4b4da9821f671f86b95848d7cfaf7e0112cf378
3
+ size 401104218