# mmcls-retriever/app.py
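"""Streamlit demo: image-to-image retrieval over a small ImageNet val subset,
built on MMClassification's ``ImageToImageRetriever``."""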
import itertools
import logging
import math
import os
import os.path as osp

import mmengine
import numpy as np
import requests
import streamlit as st
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import list_from_file
from mmengine.logging.logger import MMFormatter
from mmengine.registry import init_default_scope
from PIL import Image

from mmcls import list_models as list_models_
from mmcls.apis.model import ModelHub, init_model
@st.cache()
def prepare_data():
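    # Unzip the bundled ImageNet val subset; `-n` skips files that already
    # exist, and `st.cache()` makes this run only once per server process.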
import subprocess
subprocess.run(['unzip', '-n', 'imagenet-val.zip'])
@st.cache()
def load_demo_image():
response = requests.get(
'https://github.com/open-mmlab/mmclassification/blob/master/demo/bird.JPEG?raw=true', # noqa
stream=True).raw
img = Image.open(response).convert('RGB')
return img
@st.cache()
def list_models(*args, **kwargs):
return sorted(list_models_(*args, **kwargs))
DATA_ROOT = '.'
ANNO_FILE = 'meta/val.txt'
LOG_FILE = 'demo.log'
CACHED_PATH = 'cache'
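# Expected layout (assumed from the code below): images live under
# `DATA_ROOT/val`, `ANNO_FILE` lists one gallery image per line, and
# `CACHED_PATH` holds precomputed prototype features named
# '<model_name>_prototype.pt'.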
def get_model(model_name, pretrained=True):
metainfo = ModelHub.get(model_name)
if pretrained:
if metainfo.weights is None:
raise ValueError(
f"The model {model_name} doesn't have pretrained weights.")
ckpt = metainfo.weights
else:
ckpt = None
cfg = metainfo.config
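    # Load only the backbone weights from the classification checkpoint;
    # the classification head is not needed for retrieval.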
cfg.model.backbone.init_cfg = dict(
type='Pretrained', checkpoint=ckpt, prefix='backbone')
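    # Rebuild the model config as an ImageToImageRetriever: the classifier's
    # backbone (and neck, if present) becomes the image encoder.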
new_model_cfg = dict()
new_model_cfg['type'] = 'ImageToImageRetriever'
if hasattr(cfg.model, 'neck') and cfg.model.neck is not None:
new_model_cfg['image_encoder'] = [cfg.model.backbone, cfg.model.neck]
else:
new_model_cfg['image_encoder'] = cfg.model.backbone
cfg.model = new_model_cfg
    # Point the retriever at the precomputed prototype (gallery features)
    # cached for this model.
cached_path = f'{CACHED_PATH}/{model_name}_prototype.pt' # noqa
cfg.model.prototype = cached_path
    model = init_model(cfg, None, device='cpu')
    with st.spinner(f'Loading model {model_name} on the server... '
                    'This may be slow the first time.'):
        model.init_weights()
    st.success('Model loaded!')
    with st.spinner('Preparing the prototype for all gallery images... '
                    'This may be slow the first time.'):
model.prepare_prototype()
return model
def get_pred(name, img):
logger = mmengine.logging.MMLogger.get_current_instance()
file_handler = logging.FileHandler(LOG_FILE, 'w')
    # The file handler records a full timestamp and writes logs without
    # ANSI color codes to avoid garbled text in the log file.
file_handler.setFormatter(
MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
file_handler.setLevel('INFO')
logger.handlers.append(file_handler)
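    # Make sure registry lookups below resolve to mmcls modules.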
init_default_scope('mmcls')
model = get_model(name)
cfg = model.cfg
# build the data pipeline
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
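    # Adapt the first pipeline step to the input type: a file path needs
    # `LoadImageFromFile`, while an ndarray or PIL image is already loaded.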
if isinstance(img, str):
if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
data = dict(img_path=img)
elif isinstance(img, np.ndarray):
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
test_pipeline_cfg.pop(0)
data = dict(img=img)
elif isinstance(img, Image.Image):
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
test_pipeline_cfg[0] = dict(type='ToNumpy', keys=['img'])
data = dict(img=img)
test_pipeline = Compose(test_pipeline_cfg)
data = test_pipeline(data)
data = default_collate([data])
    # Run inference once and reuse the predicted labels and scores.
    pred_label = model.val_step(data)[0].pred_label
    labels = pred_label.label
    scores = pred_label.score[labels]
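    # Each line of ANNO_FILE is assumed to be '<relative/path> [label]'; only
    # the path is used, and `labels` index into this list in the same order
    # the prototype was built.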
image_list = list_from_file(osp.join(DATA_ROOT, ANNO_FILE))
data_root = osp.join(DATA_ROOT, 'val')
result_list = [(osp.join(data_root, image_list[idx].rsplit()[0]), score)
for idx, score in zip(labels, scores)]
return result_list
def app():
prepare_data()
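    # Only models with a precomputed prototype in CACHED_PATH are selectable.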
model_name = st.sidebar.selectbox(
"Model:",
[m.split('_prototype.pt')[0] for m in os.listdir(CACHED_PATH)])
st.markdown(
"<h1>Image To Image Retrieval</h1>",
unsafe_allow_html=True,
)
    st.write(
        'This is a demo of image-to-image retrieval over roughly 3k images '
        'from a small ImageNet validation subset, built with the '
        'mmclassification APIs. You can try more features in '
        '[mmclassification](https://github.com/open-mmlab/mmclassification).')
    file = st.file_uploader(
        'Please upload your own image or use the provided demo image:')
container1 = st.container()
if file:
raw_img = Image.open(file).convert('RGB')
else:
raw_img = load_demo_image()
container1.header('Image')
w, h = raw_img.size
scaling_factor = 360 / w
resized_image = raw_img.resize(
(int(w * scaling_factor), int(h * scaling_factor)))
container1.image(resized_image, use_column_width='auto')
button = container1.button('Search')
st.header('Results')
    topk = st.sidebar.number_input('Top-k (1-50)', min_value=1, max_value=50)
    # Run the search when the button is clicked or when topk is adjusted.
if button or topk > 1:
result_list = get_pred(model_name, raw_img)
        # Automatically choose the number of images per row, capped at 5.
col = min(int(math.sqrt(topk)), 5)
row = math.ceil(topk / col)
grid = []
for i in range(row):
with st.container():
grid.append(st.columns(col))
        # Flatten the row-major grid of columns and keep the first `topk` cells.
        grid = list(itertools.chain.from_iterable(grid))[:topk]
        for cell, (image_path, score) in zip(grid, result_list[:topk]):
            image = Image.open(image_path).convert('RGB')
            # Resize the retrieved image (not the query) before displaying it.
            w, h = image.size
            scaling_factor = 360 / w
            resized_image = image.resize(
                (int(w * scaling_factor), int(h * scaling_factor)))
            cell.caption('Score: {:.4f}'.format(float(score)))
            cell.image(resized_image)
if __name__ == '__main__':
app()