import gradio as gr
import os
import numpy as np
import pandas as pd
import faiss
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection
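
# Gradio demo for text-to-video retrieval: a sentence is embedded with the
# Diangle/clip4clip-webvid text encoder and matched against precomputed
# CLIP4Clip visual features via a fast binary search plus a cosine rerank.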


TITLE = """<h1 style="font-size: 42px;" align="center">Video Retrieval</h1>"""

DESCRIPTION = """This is a video retrieval demo using [Diangle/clip4clip-webvid](https://huggingface.co/Diangle/clip4clip-webvid)."""
IMAGE = '<div style="text-align: left;"><img src="https://huggingface.co/spaces/Diangle/Clip4Clip-webvid/resolve/main/Searchium.png"/></div>'

DATA_PATH = './data'

# Precomputed CLIP4Clip visual features for the video database.
ft_visual_features_file = os.path.join(DATA_PATH, 'dataset_v1_visual_features_database.npy')
ft_visual_features_database = np.load(ft_visual_features_file)

# Metadata for each video in the database, including its 'contentUrl'.
database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
database_df = pd.read_csv(database_csv_path)


class NearestNeighbors:
    """
    Nearest-neighbor search over a FAISS index.

    metric = 'cosine' / 'binary'.
    If metric == 'binary' and rerank_from > n_neighbors, the top
    `rerank_from` binary hits are reranked by cosine similarity.
    """
    def __init__(self, n_neighbors=10, metric='cosine', rerank_from=-1):
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.rerank_from = rerank_from

    def normalize(self, a):
        # L2-normalize rows so that inner product equals cosine similarity.
        return a / np.linalg.norm(a, axis=1, keepdims=True)

    def fit(self, data, o_data=None):
        if self.metric == 'cosine':
            data = self.normalize(data)
            self.index = faiss.IndexFlatIP(data.shape[1])
        elif self.metric == 'binary':
            # Keep the original float features around for the cosine rerank.
            self.o_data = data if o_data is None else o_data
            # `data` is assumed to be already bit-packed (np.packbits),
            # so the index dimension is 8 bits per byte.
            self.index = faiss.IndexBinaryFlat(data.shape[1] * 8)
        self.index.add(np.ascontiguousarray(data))

    def kneighbors(self, q_data):
        if self.metric == 'cosine':
            q_data = self.normalize(q_data)
            sim, idx = self.index.search(q_data, self.n_neighbors)
        elif self.metric == 'binary':
            # Binarize and pack the query, then search by Hamming distance.
            bq_data = np.packbits((q_data > 0.0).astype(bool), axis=1)
            sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))

            if self.rerank_from > self.n_neighbors:
                # Rerank the binary candidates by cosine similarity on the
                # original float features; `sim` becomes cosine scores.
                rerank_data = self.o_data[idx[0]]
                rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
                rerank_search.fit(rerank_data)
                sim, re_idxs = rerank_search.kneighbors(q_data)
                idx = [idx[0][re_idxs[0]]]

        return sim, idx
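
# Example usage of NearestNeighbors (illustrative only; random vectors stand
# in for real features here):
#   db = np.random.randn(1000, 512).astype('float32')
#   nn = NearestNeighbors(n_neighbors=5, metric='cosine')
#   nn.fit(db)
#   sims, idxs = nn.kneighbors(np.random.randn(1, 512).astype('float32'))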
    
# Text encoder and tokenizer for the CLIP4Clip model.
model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
    
def search(search_sentence):
    inputs = tokenizer(text=search_sentence, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)

    # Customized projection layer: apply the stored projection weight to the
    # last hidden state, then pool at the EOS token (argmax of input_ids,
    # since EOS has the highest token id in CLIP's vocabulary).
    text_projection = model.state_dict()['text_projection.weight']
    text_embeds = outputs[1] @ text_projection
    final_output = text_embeds[torch.arange(text_embeds.shape[0]), inputs["input_ids"].argmax(dim=-1)]

    # L2-normalize the query embedding.
    final_output = final_output / final_output.norm(dim=-1, keepdim=True)
    sequence_output = final_output.detach().cpu().numpy()

    # Two-stage retrieval: binary (Hamming) search over bit-packed features,
    # then a cosine rerank of the top 100 candidates.
    nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
    nn_search.fit(np.packbits((ft_visual_features_database > 0.0).astype(bool), axis=1), o_data=ft_visual_features_database)
    sims, idxs = nn_search.kneighbors(sequence_output)
    return database_df.iloc[idxs[0]]['contentUrl'].to_list()
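
# Example: search("a woman waving to the camera") should return a list of 5
# video URLs (the database's 'contentUrl' values), best match first.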


with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)
    gr.HTML(IMAGE)
    gr.Markdown("Retrieval of the top 5 videos relevant to the input sentence:")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Write a sentence.")
            btn = gr.Button(value="Retrieve")
            ex = [["a woman waving to the camera"],
                  ["a basketball player performing a slam dunk"],
                  ["how to bake a chocolate cake"],
                  ["birds fly in the sky"]]
            gr.Examples(examples=ex, inputs=[inp])
        with gr.Column():
            out = [gr.Video(format='mp4') for _ in range(5)]
    btn.click(search, inputs=inp, outputs=out)

demo.launch()