# Captured from a Hugging Face Spaces page whose status read: "Runtime error".
from html import escape | |
import re | |
import torch | |
import streamlit as st | |
import pandas as pd, numpy as np | |
from transformers import CLIPProcessor, CLIPModel, FlavaModel, FlavaProcessor | |
from st_clickable_images import clickable_images | |
# Short checkpoint names; "flava-full" is served from "facebook/", the others
# from "openai/clip-" (see load()). The first entry is the default model.
MODEL_NAMES = [
    "flava-full",
    "vit-base-patch32",
    "vit-base-patch16",
    "vit-large-patch14",
    "vit-large-patch14-336",
]
# load() is invoked at module level, and Streamlit re-executes the whole script on
# every widget interaction; without a resource cache each click would reload five
# transformer checkpoints plus all embedding files. Prefer `st.cache_resource`
# (newer Streamlit) and fall back to the legacy `st.cache(allow_output_mutation=True)`.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def load():
    """Load metadata, every model/processor pair and normalised image embeddings.

    Returns:
        ``(models, processors, df, embeddings)`` where:

        * ``models`` / ``processors`` map each entry of ``MODEL_NAMES`` to its
          Hugging Face model and processor (FLAVA from ``facebook/``, CLIP
          variants from ``openai/clip-``);
        * ``df`` maps corpus index ``0`` to ``data.csv`` and ``1`` to ``data2.csv``;
        * ``embeddings[name][k]`` is the matrix loaded from ``embeddings-{name}.npy``
          (``k == 0``) or ``embeddings2-{name}.npy`` (``k == 1``) with every row
          scaled to unit L2 norm.
    """
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    models = {}
    processors = {}
    embeddings = {}
    for name in MODEL_NAMES:
        if "flava" in name:
            model_cls, processor_cls, prefix = FlavaModel, FlavaProcessor, "facebook/"
        else:
            model_cls, processor_cls, prefix = CLIPModel, CLIPProcessor, "openai/clip-"
        checkpoint = f"{prefix}{name}"
        model = model_cls.from_pretrained(checkpoint)
        model.eval()  # inference only
        models[name] = model
        processors[name] = processor_cls.from_pretrained(checkpoint)
        embeddings[name] = {}
        for k, stem in ((0, "embeddings"), (1, "embeddings2")):
            matrix = np.load(f"{stem}-{name}.npy")
            # Unit rows: dot products with unit-norm query embeddings are cosines.
            embeddings[name][k] = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    return models, processors, df, embeddings
# Module-level state read by compute_text_embeddings() and image_search().
models, processors, df, embeddings = load()

# Attribution appended to each image tooltip, keyed by corpus index
# (0 = data.csv / Unsplash, 1 = data2.csv / TMDB).
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
def compute_text_embeddings(list_of_strings, name):
    """Embed text queries with model ``name`` and return unit-norm rows.

    Args:
        list_of_strings: non-empty list of query strings.
        name: key into the module-level ``models`` / ``processors`` dicts.

    Returns:
        A NumPy array with one L2-normalised embedding per input string.
    """
    # truncation=True: CLIP's text encoder has a fixed 77-token context, so an
    # over-long query must be clipped by the tokenizer rather than crash the model.
    inputs = processors[name](
        text=list_of_strings, return_tensors="pt", padding=True, truncation=True
    )
    with torch.no_grad():
        features = models[name].get_text_features(**inputs)
    if "flava" in name:
        # FLAVA yields one vector per token here; keep the first token's vector
        # (presumably [CLS] — confirm against the FLAVA docs).
        features = features[:, 0, :]
    features = features.detach().numpy()
    return features / np.linalg.norm(features, axis=1, keepdims=True)
# Sub-query referring to a stored image, optionally followed by extra text.
_IMAGE_REFERENCE = re.compile(r"\[(Movies|Unsplash):(\d{1,5})\](.*)")


def _corpus_index(corpus):
    """Map a corpus label to its key in ``df`` / ``embeddings`` (``"Unsplash"`` -> 0, else 1)."""
    return 0 if corpus == "Unsplash" else 1


def _scaled_similarities(corpus_embeddings, query_embeddings):
    """Return corpus-vs-query similarities, rescaled independently per query column.

    Each column is shifted so its median is 0 and divided by its maximum, putting
    queries of different strength on a comparable scale. A column whose maximum is
    not positive is left unscaled rather than divided by zero.
    """
    scores = corpus_embeddings @ query_embeddings.T
    scores = scores - np.median(scores, axis=0)
    peak = np.max(scores, axis=0, keepdims=True)
    return scores / np.where(peak > 0, peak, 1)


def _positive_embeddings(positive_part, name):
    """Embed the ``;``-separated positive sub-queries; ``None`` when there are none.

    Blank sub-queries are ignored. An image reference whose index is outside the
    referenced corpus contributes no image embedding (only its trailing text).
    """
    parts = []
    for sub_query in positive_part.split(";"):
        # Strip so "text; [Unsplash:5]" is still recognised as an image reference.
        sub_query = sub_query.strip()
        if not sub_query:
            continue
        match = _IMAGE_REFERENCE.match(sub_query)
        if match is None:
            parts.append(compute_text_embeddings([sub_query], name))
            continue
        ref_corpus, idx, remainder = match.groups()
        ref_embeddings = embeddings[name][_corpus_index(ref_corpus)]
        idx, remainder = int(idx), remainder.strip()
        if idx < len(ref_embeddings):
            parts.append(ref_embeddings[idx : idx + 1, :])
        if remainder:
            parts.append(compute_text_embeddings([remainder], name))
    return np.concatenate(parts, axis=0) if parts else None


def image_search(query, corpus, name, n_results=24):
    """Rank the images of ``corpus`` against ``query`` using model ``name``.

    Query syntax:
      * ``;`` separates positive sub-queries; an image's score is the minimum of
        its scaled scores over them, so it must match all of them.
      * ``[Unsplash:123]`` / ``[Movies:45]`` (optionally followed by text) uses the
        stored embedding of that image, plus the text if any.
      * Text after ``"EXCLUDING "`` is split on ``;`` into negative queries; each
        image loses its largest positive scaled score among them.

    Args:
        query: raw query string.
        corpus: ``"Unsplash"`` or ``"Movies"``.
        name: model key from ``MODEL_NAMES``.
        n_results: maximum number of results.

    Returns:
        Up to ``n_results`` tuples ``(path, tooltip, row_index)``, best first; an
        empty list when the query contains no usable sub-query.
    """
    k = _corpus_index(corpus)
    corpus_embeddings = embeddings[name][k]

    positive_part, *negative_parts = query.split("EXCLUDING ")
    positive = _positive_embeddings(positive_part, name)
    negative_queries = [
        q.strip() for q in " ".join(negative_parts).split(";") if q.strip()
    ]
    if positive is None and not negative_queries:
        return []

    scores = np.zeros(len(corpus_embeddings))
    if positive is not None:
        scores = np.min(_scaled_similarities(corpus_embeddings, positive), axis=1)
    if negative_queries:
        negative = compute_text_embeddings(negative_queries, name)
        penalties = np.maximum(_scaled_similarities(corpus_embeddings, negative), 0)
        scores = scores - np.max(penalties, axis=1)

    top = np.argsort(scores)[::-1][:n_results]
    rows = df[k]
    return [
        (rows.iloc[i]["path"], rows.iloc[i]["tooltip"] + source[k], i) for i in top
    ]
# Markdown snippets rendered in the page and sidebar by main().
description = """
# FLAVA Semantic Image-Text Search
"""
instruction = """
### **Enter your query and hit enter**
**Things to try:** compare with other models or search for "a field in the countryside EXCLUDING green"
"""
credit = """
*Built with FAIR's [FLAVA](https://arxiv.org/abs/2112.04482) models, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
*Forked from, and inspired by, a similar app available [here](https://huggingface.co/spaces/vivien/clip/)*
"""
options = """
## Compare
Check results for a single model or compare two models by using the dropdown below:
"""
# "Try ..." lines are indented two spaces so Markdown nests them under their bullet.
howto = """
## Advanced Use
- Click on an image to use it as a query and find similar images
- Several queries, including one based on an image, can be combined (use "**;**" as a separator).
  - Try "a person walking on a grass field; red flowers".
- If the input includes "**EXCLUDING**", text following it will be used as a negative query.
  - Try "a field in the countryside which is green" and "a field in the countryside EXCLUDING green".
"""
# Flexbox layout for the clickable image grid.
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}
def _show_results(results, key, height):
    """Render ``(path, tooltip, index)`` results as a clickable image grid.

    Returns the position in ``results`` of the clicked image; main() treats any
    negative value as "nothing clicked".
    """
    return clickable_images(
        [path for path, _, _ in results],
        titles=[title for _, title, _ in results],
        div_style=div_style,
        img_style={"margin": "2px", "height": height},
        key=key,
    )


def _rerun():
    """Restart the script run on both current and older Streamlit releases."""
    # `st.experimental_rerun` is deprecated in favour of `st.rerun`.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def main():
    """Build the page: CSS, sidebar help, query box, model choice and result grids.

    Clicking a result replaces the query with ``[<corpus>:<row index>]`` and reruns.
    """
    st.markdown(
        """
<style>
    .block-container{
        max-width: 1200px;
    }
    div.row-widget.stRadio > div{
        flex-direction:row;
        display: flex;
        justify-content: center;
    }
    div.row-widget.stRadio > div > label{
        margin-left: 5px;
        margin-right: 5px;
    }
    .row-widget {
        margin-top: -25px;
    }
    section>div:first-child {
        padding-top: 30px;
    }
    div.reportview-container > section:first-child{
        max-width: 320px;
    }
    #MainMenu {
        visibility: hidden;
    }
    footer {
        visibility: hidden;
    }
</style>""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    st.sidebar.markdown(options)
    mode = st.sidebar.selectbox(
        "", ["Results for FLAVA full", "Comparison of 2 models"], index=0
    )
    st.sidebar.markdown(howto)
    st.sidebar.markdown(credit)

    _, c, _ = st.columns((1, 3, 1))
    c.markdown(instruction)
    # A previous click stores its image query in session state; otherwise use the demo query.
    default_query = st.session_state.get(
        "query", "a field in the countryside which is green"
    )
    query = c.text_input("", value=default_query)
    corpus = st.radio("", ["Unsplash", "Movies"])
    models_dict = {
        "FLAVA": "flava-full",
        "ViT-B/32 (quickest)": "vit-base-patch32",
        "ViT-B/16 (quick)": "vit-base-patch16",
        "ViT-L/14 (slow)": "vit-large-patch14",
        "ViT-L/14@336px (slowest)": "vit-large-patch14-336",
    }
    comparison = "Comparison" in mode
    if comparison:
        c1, c2 = st.columns((1, 1))
        name1 = models_dict[c1.selectbox("", models_dict.keys(), index=0)]
        name2 = models_dict[c2.selectbox("", models_dict.keys(), index=3)]
    else:
        name1 = MODEL_NAMES[0]

    if not query:
        return

    results1 = image_search(query, corpus, name1)
    results2 = []
    clicked2 = -1
    if comparison:
        with c1:
            clicked1 = _show_results(results1, query + corpus + name1 + "1", "150px")
        results2 = image_search(query, corpus, name2)
        with c2:
            clicked2 = _show_results(results2, query + corpus + name2 + "2", "150px")
    else:
        clicked1 = _show_results(results1, query + corpus + name1 + "1", "200px")

    if clicked1 >= 0:
        new_query = f"[{corpus}:{results1[clicked1][2]}]"
    elif clicked2 >= 0:
        new_query = f"[{corpus}:{results2[clicked2][2]}]"
    else:
        return
    # The component presumably keeps reporting its last click on every rerun for
    # the same key, so only rerun when the click actually changes the stored query;
    # otherwise clicking the image the query already refers to would loop forever.
    if new_query != st.session_state.get("query"):
        st.session_state["query"] = new_query
        _rerun()


if __name__ == "__main__":
    main()