Spaces:
Running
Running
# coding: utf-8 | |
# Copyright (C) 2023, [Breezedeus](https://github.com/breezedeus). | |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
import os | |
import sys | |
import logging | |
from typing import List | |
import yaml | |
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
from datasets import load_dataset | |
import chromadb | |
from chromadb import Settings | |
from coin_clip.utils import resize_img | |
from coin_clip.chroma_embedding import ChromaEmbeddingFunction | |
from coin_clip.detect import Detector | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Deployment environment: 'hf' selects the Hugging Face configuration file,
# any other value (default 'local') selects the local one.
env = os.environ.get('COIN_ENV', 'local')
config_fp = 'hf_config.yaml' if env == 'hf' else 'local_config.yaml'
logger.info(f'Use config file: {config_fp}')

# Close the config file as soon as it is parsed instead of leaving the handle
# to the garbage collector; decode explicitly rather than with the locale default.
with open(config_fp, encoding='utf-8') as config_file:
    total_config = yaml.safe_load(config_file)

# Coin detector shared by every request.
DETECTOR = Detector(
    model_name=total_config['detector']['model_name'],
    device=total_config['detector']['device'],
)
# Background removal is currently disabled:
# USE_REMOVE_BG = total_config['use_remove_bg']

# Size passed to ``resize_img`` before running detection (defaults to 300).
RESIZED_TO_BEFORE_DETECT = total_config['detector'].get('resized_to', 300)
def prepare_chromadb():
    """Download the prebuilt Chroma databases when running in the ``hf`` environment.

    In the ``hf`` environment, the ``breezedeus/usa-coins-chromadb`` model repo
    is snapshotted into the current directory. In every other environment the
    databases are expected to be present locally already, and nothing is done.
    """
    # Use the same ``env == 'hf'`` test that selects the config file and the
    # dataset loader. Previously only 'local' skipped the download, so any other
    # COIN_ENV value fetched the HF databases while using the local config.
    if env != 'hf':
        return
    from huggingface_hub import snapshot_download

    snapshot_download(repo_type='model', repo_id='breezedeus/usa-coins-chromadb', local_dir='./')
def load_dataset(data_path):
    """Load the ``train`` split of the coin image dataset.

    Args:
        data_path: in the ``hf`` environment, a dataset name/path passed directly
            to ``datasets.load_dataset``; otherwise a local directory loaded
            with the ``imagefolder`` builder.

    Returns:
        The ``train`` split returned by ``datasets.load_dataset``.
    """
    # This function's own name shadows the module-level
    # ``from datasets import load_dataset``. Calling ``load_dataset`` here would
    # therefore re-enter this function and fail on the ``split`` keyword, so the
    # library function is imported under a private alias.
    from datasets import load_dataset as _hf_load_dataset

    logger.info('Load dataset from %s', data_path)
    if env == 'hf':
        dataset = _hf_load_dataset(data_path, split='train')
    else:
        dataset = _hf_load_dataset("imagefolder", data_dir=data_path, split='train')
    return dataset
def detect(images):
    """Run the coin detector on each image and crop its first detection.

    Each image is resized with ``resize_img`` to ``RESIZED_TO_BEFORE_DETECT``
    and passed to ``DETECTOR.detect`` as a numpy array. Only the first entry of
    the detector's output is kept.

    Args:
        images: sequence of PIL images.

    Returns:
        A pair ``(outs, box_images)`` with one entry per input image.
        ``outs[i]`` takes one of two forms:
        - ``{'position': None, 'scores': 0.0}`` when nothing was detected;
        - otherwise, the detector's first result, with ``'label'`` removed,
          ``'box'`` renamed to ``'position'`` and ``'from_image_idx'`` set to ``i``.
        ``box_images[i]`` is ``images[i]`` (the un-resized original) cropped to
        that position, or ``None`` when nothing was detected.
    """
    outs = []
    for img_idx, original in enumerate(images):
        resized = resize_img(original, RESIZED_TO_BEFORE_DETECT)
        detections = DETECTOR.detect(np.array(resized))
        if detections:
            first = detections[0]
            first.pop('label')
            first['position'] = first.pop('box')
            first['from_image_idx'] = img_idx
        else:
            first = {'position': None, 'scores': 0.0}
        outs.append(first)

    box_images = []
    for info, original in zip(outs, images):
        position = info['position']
        if position is None:
            box_images.append(None)
            continue
        # ``position`` holds the box as fractions of the image size. Scale it to
        # absolute pixel coordinates of the original image before cropping.
        width, height = original.size
        pixel_box = (
            int(position[0] * width),
            int(position[1] * height),
            int(position[2] * width),
            int(position[3] * height),
        )
        box_images.append(original.crop(pixel_box))
    return outs, box_images
def load_chroma_db(db_dir, collection_name, model_name, device='cpu'):
    """Open a Chroma collection persisted under ``db_dir``.

    Args:
        db_dir: directory backing the ``chromadb.PersistentClient``.
        collection_name: name of the collection to fetch. ``get_collection``
            is used, so no collection is created here.
        model_name: model name given to ``ChromaEmbeddingFunction``.
        device: device given to ``ChromaEmbeddingFunction``.

    Returns:
        The collection, bound to that embedding function. Telemetry is disabled.
    """
    logger.info('Load chroma db from %s', db_dir)
    client_settings = Settings(anonymized_telemetry=False)
    client = chromadb.PersistentClient(path=db_dir, settings=client_settings)
    embed_fn = ChromaEmbeddingFunction(model_name, device)
    return client.get_collection(name=collection_name, embedding_function=embed_fn)
def retrieve(query_image: Image.Image, collection, top_k=20) -> List[Image.Image]:
    """Return the dataset images of the ``top_k`` nearest neighbours of ``query_image``.

    Args:
        query_image: the (cropped) coin image to search with. It is converted
            to a numpy array before querying.
        collection: Chroma collection, queried with ``query_images``.
        top_k: number of results requested from the collection.

    Returns:
        For each hit, in the order the collection returned them, the
        ``'image'`` field of the row in the module-level ``ds_dict`` that has
        the hit's id.
    """
    query_array = np.array(query_image)
    retrieved = collection.query(
        query_images=[query_array], include=['metadatas', 'distances'], n_results=top_k,
    )
    # A single query was sent, so results live in the first (only) inner list.
    hit_ids = retrieved['ids'][0]
    logger.info('retrieved ids: %s', hit_ids)
    logger.info('retrieved distances: %s', retrieved['distances'][0])
    # ``doc_id`` rather than ``id`` so the builtin is not shadowed.
    return [ds_dict[doc_id]['image'] for doc_id in hit_ids]
# Retrieval resources, built once at import time.
dataset = load_dataset(**total_config['dataset'])
# Index rows by their 'id' so ``retrieve`` can turn Chroma hit ids back into rows.
ds_dict = {row['id']: row for row in dataset}
prepare_chromadb()
# One persisted collection per embedding model compared in the UI.
cc_collection = load_chroma_db(**total_config['coin_clip_db'])
clip_collection = load_chroma_db(**total_config['clip_db'])
def _no_coin_updates():
    """Return the four Gradio updates used when there is no coin to search with.

    Hides the detected-coin image and both result galleries, and shows the
    "no coins detected" warning.
    """
    return [
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
    ]


def search(image_file: Image.Image):
    """Gradio callback: detect a coin in ``image_file`` and retrieve similar coins.

    Args:
        image_file: the uploaded image, or ``None`` when Submit is clicked
            before an image has been provided.

    Returns:
        Four ``gr.update`` values, in this order:
        1. the detected-coin image;
        2. the "no coins detected" warning;
        3. the Coin-CLIP results gallery;
        4. the CLIP results gallery.
        When there is no image or no coin is detected, only the warning is shown.
        Otherwise the first detected crop is queried against both collections
        (30 results each).
    """
    if image_file is None:
        # An empty image input arrives as None; calling ``.convert`` on it
        # used to raise an AttributeError.
        return _no_coin_updates()
    _, box_images = detect([image_file.convert('RGB')])
    # Keep only the crops where a coin was actually detected.
    box_images = [img for img in box_images if img is not None]
    if not box_images:
        return _no_coin_updates()
    box_image = box_images[0]
    cc_results = retrieve(box_image, cc_collection, top_k=30)
    clip_results = retrieve(box_image, clip_collection, top_k=30)
    return [
        gr.update(value=box_image, visible=True),
        gr.update(visible=False),
        gr.update(value=cc_results, visible=True),
        gr.update(value=clip_results, visible=True),
    ]
def main():
    """Build the Coin-CLIP vs. CLIP retrieval demo and launch it.

    The layout has two rows:
    - the query image with a Submit button, next to the detected-coin crop and
      a hidden "no coins detected" warning;
    - the Coin-CLIP and CLIP result galleries, side by side.
    Both the Submit button and the examples call ``search``. Examples are
    cached (``cache_examples=True``), and the queue is created with
    ``max_size=20``.
    """
    title = 'USA Coin Retrieval by'
    desc = (
        '<p style="text-align: center">Coin-CLIP: '
        '<a href="https://huggingface.co/breezedeus/coin-clip-vit-base-patch32" target="_blank">Model</a>, '
        '<a href="https://github.com/breezedeus/coin-clip" target="_blank">Github</a>; '
        'Author: <a href="https://www.breezedeus.com" target="_blank">Breezedeus</a> , '
        '<a href="https://github.com/breezedeus" target="_blank">Github</a> </p>'
    )
    examples = [
        'examples/c2.jpeg',
        'examples/c20.jpg',
        'examples/c21.jpg',
        'examples/c22.png',
        'examples/c1.jpg',
        'examples/c11.jpg',
        'examples/c3.png',
        'examples/c4.jpg',
        'examples/c5.jpeg',
        'examples/c6.jpeg',
        'examples/c7.jpg',
        'examples/c8.jpeg',
    ]

    def compact_column():
        # Every column in both rows is compact with scale=1.
        return gr.Column(variant='compact', scale=1)

    def results_gallery(label):
        # Result galleries start hidden; ``search`` reveals and fills them.
        return gr.Gallery(
            label=label, columns=3, height=2200, show_share_button=True, visible=False
        )

    with gr.Blocks() as demo:
        gr.Markdown(
            f'<h1 style="text-align: center; margin-bottom: 1rem;">{title} <a href="https://github.com/breezedeus/coin-clip" target="_blank">Coin-CLIP</a></h1>'
        )
        gr.Markdown(desc)

        with gr.Row(equal_height=False):
            with compact_column():
                gr.Markdown('### Image within a coin')
                query_input = gr.Image(
                    label='Coin Image to Search',
                    type="pil",
                    image_mode='RGB',
                    height=400,
                )
                submit_btn = gr.Button("Submit", variant="primary")
            with compact_column():
                gr.Markdown('### Detected Coin')
                detected_view = gr.Image(
                    label='Detected Coin',
                    type="pil",
                    interactive=False,
                    image_mode='RGB',
                    height=400,
                )
                no_coin_warning = gr.Markdown(
                    '**⚠️ Warning**: No coins detected in image', visible=False
                )

        with gr.Row(equal_height=False):
            with compact_column():
                gr.Markdown('### Results from Coin-CLIP')
                cc_gallery = results_gallery('Coin-CLIP Results')
            with compact_column():
                gr.Markdown('### Results from CLIP')
                clip_gallery = results_gallery('CLIP Results')

        # The order must match the list of updates returned by ``search``.
        result_components = (detected_view, no_coin_warning, cc_gallery, clip_gallery)
        submit_btn.click(
            search,
            inputs=[query_input],
            outputs=list(result_components),
        )
        gr.Examples(
            label='Examples',
            examples=examples,
            inputs=query_input,
            outputs=list(result_components),
            fn=search,
            examples_per_page=12,
            cache_examples=True,
        )

    demo.queue(max_size=20)
    demo.launch()


if __name__ == '__main__':
    main()