Video-Search / app.py
Diangle's picture
Update app.py
9aeabaa
raw
history blame
4.74 kB
import gradio as gr
import os
import numpy as np
import pandas as pd
from IPython import display
import faiss
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection
TITLE = """<h1 style="font-size: 42px;" align="center">Video Retrieval</h1>"""
DESCRIPTION = """This is a video retrieval demo using [Diangle/clip4clip-webvid](https://huggingface.co/Diangle/clip4clip-webvid)."""
# Logo shown under the description; the wrapping <div> must be closed so the
# rest of the page markup is not nested inside it.
IMAGE = '<div style="text-align: left;"><img src="https://huggingface.co/spaces/Diangle/Clip4Clip-webvid/resolve/main/Searchium.png" width="333" height="216"/></div>'

DATA_PATH = './data'
ft_visual_features_file = os.path.join(DATA_PATH, 'dataset_v1_visual_features_database.npy')
# Load the precomputed visual feature vectors of the video database
# (one row per video; search() binarises and re-ranks against these).
ft_visual_features_database = np.load(ft_visual_features_file)

database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
# Video metadata. search() indexes this frame positionally with feature-row
# indices, so its row order must match the feature file above.
database_df = pd.read_csv(database_csv_path)
class NearestNeighbors:
    """Exact k-nearest-neighbour search backed by faiss.

    Supported metrics:

    * ``'cosine'``: rows are L2-normalised and searched with an inner-product
      index (``faiss.IndexFlatIP``), so scores are cosine similarities
      (higher is closer).
    * ``'binary'``: ``fit`` expects bit-packed ``uint8`` rows (as produced by
      ``np.packbits``) and searches them with ``faiss.IndexBinaryFlat``
      (Hamming distance, lower is closer). Queries given to ``kneighbors`` are
      real-valued and are binarised by sign before the search. When
      ``rerank_from > n_neighbors``, the top ``rerank_from`` binary candidates
      of each query are re-ranked by exact cosine similarity against the
      original (unpacked) vectors passed to ``fit`` as ``o_data``.
    """

    def __init__(self, n_neighbors=10, metric='cosine', rerank_from=-1):
        """
        Args:
            n_neighbors: number of neighbours returned per query.
            metric: 'cosine' or 'binary'.
            rerank_from: used when metric != 'cosine'; if it is larger than
                n_neighbors, that many binary candidates are fetched and a
                cosine re-rank is performed on them.
        """
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.rerank_from = rerank_from

    def normalize(self, a):
        """Return *a* with every row scaled to unit L2 norm.

        All-zero rows are left as zeros instead of becoming NaN.
        """
        norms = np.linalg.norm(a, axis=1, keepdims=True)
        return a / np.where(norms == 0, 1, norms)

    def fit(self, data, o_data=None):
        """Build the faiss index over *data*.

        Args:
            data: (n, d) float array for 'cosine'; (n, n_bits // 8) packed
                uint8 array for 'binary'.
            o_data: 'binary' only. These are the unpacked vectors used for the
                cosine re-rank; defaults to *data*.

        Raises:
            ValueError: if ``self.metric`` is neither 'cosine' nor 'binary'.
        """
        if self.metric == 'cosine':
            data = np.ascontiguousarray(self.normalize(data), dtype=np.float32)
            self.index = faiss.IndexFlatIP(data.shape[1])
        elif self.metric == 'binary':
            self.o_data = data if o_data is None else o_data
            data = np.ascontiguousarray(data, dtype=np.uint8)
            # IndexBinaryFlat takes its dimension in bits; each packed byte holds 8.
            self.index = faiss.IndexBinaryFlat(data.shape[1] * 8)
        else:
            raise ValueError(f"unsupported metric: {self.metric!r}")
        self.index.add(data)

    def kneighbors(self, q_data):
        """Find the nearest fitted rows for every query row.

        Args:
            q_data: (nq, d) float array of query vectors. Queries are never
                pre-packed, even for the 'binary' metric.

        Returns:
            ``(sim, idx)``, each of shape (nq, n_neighbors). ``sim`` holds
            cosine similarities for 'cosine' and re-ranked 'binary' searches,
            and Hamming distances for a plain 'binary' search. ``idx`` holds
            row indices into the fitted data; -1 marks a missing hit.
        """
        if self.metric == 'cosine':
            q_data = np.ascontiguousarray(self.normalize(q_data), dtype=np.float32)
            return self.index.search(q_data, self.n_neighbors)
        if self.metric != 'binary':
            raise ValueError(f"unsupported metric: {self.metric!r}")

        # Binarise each query by sign and pack it the same way as the database.
        bq_data = np.packbits(q_data > 0.0, axis=1)
        sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))
        if self.rerank_from <= self.n_neighbors:
            return sim, idx

        # Re-rank every query's own binary candidates by exact cosine similarity.
        sims, idxs = [], []
        for query, candidates in zip(q_data, idx):
            # faiss pads with -1 when fewer candidates exist; those are not rows.
            candidates = candidates[candidates >= 0]
            rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
            rerank_search.fit(self.o_data[candidates])
            q_sim, q_idx = rerank_search.kneighbors(query[np.newaxis, :])
            # Map positions within the candidate list back to database rows,
            # keeping -1 for padded (missing) hits.
            ranked = np.full_like(q_idx[0], -1)
            hit = q_idx[0] >= 0
            ranked[hit] = candidates[q_idx[0][hit]]
            sims.append(q_sim[0])
            idxs.append(ranked)
        return np.stack(sims), np.stack(idxs)
# Text side of the retrieval model: the tokenizer and the CLIP text encoder
# (with projection) are both loaded from the same Hub checkpoint.
tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
# Binary (+ cosine re-rank) index over the whole video database. It is built
# on the first query and then reused, since the database does not change while
# the app runs. A concurrent first call may build it twice; that is harmless.
_nn_search = None


def _get_nn_search():
    """Return the shared video index, building it on first use."""
    global _nn_search
    if _nn_search is None:
        index = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
        # Hamming search runs over sign bits; the float features are kept for
        # the cosine re-rank.
        index.fit(np.packbits(ft_visual_features_database > 0.0, axis=1),
                  o_data=ft_visual_features_database)
        _nn_search = index
    return _nn_search


def search(search_sentence):
    """Retrieve the videos that best match a free-text query.

    Args:
        search_sentence: the sentence typed by the user.

    Returns:
        The ``contentUrl`` values of the 5 best-matching rows of
        ``database_df``, best match first.
    """
    inputs = tokenizer(text=search_sentence, return_tensors="pt", padding=True)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        return_dict=False)
        # Customized projection layer: the projection weight is applied as
        # x @ W (not x @ W.T as nn.Linear would). Presumably outputs[1] is the
        # per-token last hidden state and this mirrors the original CLIP4Clip
        # projection; confirm before changing.
        text_projection = model.text_projection.weight
        text_embeds = outputs[1] @ text_projection
        # Pool at the position of the largest token id in each row. This is
        # presumably the end-of-text token in the CLIP vocabulary.
        final_output = text_embeds[torch.arange(text_embeds.shape[0]),
                                   inputs["input_ids"].argmax(dim=-1)]
        # Unit-normalise. A second normalisation afterwards would be a no-op.
        final_output = final_output / final_output.norm(dim=-1, keepdim=True)
    sequence_output = final_output.cpu().numpy()

    _, idxs = _get_nn_search().kneighbors(sequence_output)
    return database_df.iloc[idxs[0]]['contentUrl'].to_list()
# Example queries offered under the input box (one input value per example).
ex = [
    ["a woman waving to the camera"],
    ["a basketball player performing a slam dunk"],
    ["how to bake a chocolate cake"],
    ["birds fly in the sky"],
]

with gr.Blocks() as demo:
    # Page header: title, description and logo, then a short instruction.
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)
    gr.HTML(IMAGE)
    gr.Markdown("Retrieval of top 5 videos relevant to the input sentence: ")

    with gr.Row():
        # Left column: query box, submit button and clickable examples.
        with gr.Column():
            inp = gr.Textbox(placeholder="Write a sentence.")
            btn = gr.Button(value="Retrieve")
            gr.Examples(examples=ex, inputs=[inp])
        # Right column: one video player per URL returned by search().
        with gr.Column():
            out = []
            for _ in range(5):
                out.append(gr.Video(format='mp4'))

    btn.click(search, inputs=inp, outputs=out)

demo.launch()