Diangle commited on
Commit
cc6ed45
·
1 Parent(s): 04d9048

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -20
app.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  from transformers import AutoTokenizer, CLIPTextModelWithProjection
9
 
10
 
11
- TITLE="""<h1 style="font-size: 42px;" align="center">Video Retrieval</h1>"""
12
 
13
  DESCRIPTION="""This is a video retrieval demo using [Diangle/clip4clip-webvid](https://huggingface.co/Diangle/clip4clip-webvid)."""
14
  IMAGE='<div style="text-align: right;"><img src="https://huggingface.co/spaces/Diangle/Clip4Clip-webvid/resolve/main/Searchium.png" width="333" height="216"/>'
@@ -23,7 +23,6 @@ ft_visual_features_database = np.load(ft_visual_features_file)
23
  database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
24
  database_df = pd.read_csv(database_csv_path)
25
 
26
-
27
  class NearestNeighbors:
28
  """
29
  Class for NearestNeighbors.
@@ -56,28 +55,33 @@ class NearestNeighbors:
56
  sim, idx = self.index.search(q_data, self.n_neighbors)
57
  else:
58
  if self.metric == 'binary':
59
- print('binary search: ')
60
- bq_data = np.packbits((q_data > 0.0).astype(bool), axis=1)
61
- print(bq_data.shape, self.index.d)
62
  sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))
63
 
64
  if self.rerank_from > self.n_neighbors:
65
- rerank_data = self.o_data[idx[0]]
66
- rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
67
- rerank_search.fit(rerank_data)
68
- sim, re_idxs = rerank_search.kneighbors(q_data)
69
- idx = [idx[0][re_idxs[0]]]
70
-
 
 
 
 
 
 
71
  return sim, idx
72
 
 
73
  model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
74
- tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
75
 
76
  def search(search_sentence):
77
  inputs = tokenizer(text=search_sentence , return_tensors="pt", padding=True)
78
 
79
- outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)
80
- # Customized projection layer
81
  text_projection = model.state_dict()['text_projection.weight']
82
  text_embeds = outputs[1] @ text_projection
83
  final_output = text_embeds[torch.arange(text_embeds.shape[0]), inputs["input_ids"].argmax(dim=-1)]
@@ -89,8 +93,15 @@ def search(search_sentence):
89
 
90
  nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
91
  nn_search.fit(np.packbits((ft_visual_features_database > 0.0).astype(bool), axis=1), o_data=ft_visual_features_database)
92
- sims, idxs = nn_search.kneighbors(sequence_output)
93
- return database_df.iloc[idxs[0]]['contentUrl'].to_list()
 
 
 
 
 
 
 
94
 
95
 
96
  with gr.Blocks() as demo:
@@ -102,12 +113,13 @@ with gr.Blocks() as demo:
102
  with gr.Column():
103
  inp = gr.Textbox(placeholder="Write a sentence.")
104
  btn = gr.Button(value="Retrieve")
105
- ex = [["a woman waving to the camera"],["a basketball player performing a slam dunk"], ["how to bake a chocolate cake"], ["birds fly in the sky"]]
 
106
  gr.Examples(examples=ex,
107
- inputs=[inp],
108
- )
109
  with gr.Column():
110
- out = [gr.Video(format='mp4') for _ in range(5)]
111
  btn.click(search, inputs=inp, outputs=out)
112
 
113
  demo.launch()
 
8
  from transformers import AutoTokenizer, CLIPTextModelWithProjection
9
 
10
 
11
+ TITLE="""<h1 style="font-size: 64px;" align="center">Video Retrieval</h1>"""
12
 
13
  DESCRIPTION="""This is a video retrieval demo using [Diangle/clip4clip-webvid](https://huggingface.co/Diangle/clip4clip-webvid)."""
14
  IMAGE='<div style="text-align: right;"><img src="https://huggingface.co/spaces/Diangle/Clip4Clip-webvid/resolve/main/Searchium.png" width="333" height="216"/>'
 
23
  database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
24
  database_df = pd.read_csv(database_csv_path)
25
 
 
26
  class NearestNeighbors:
27
  """
28
  Class for NearestNeighbors.
 
55
  sim, idx = self.index.search(q_data, self.n_neighbors)
56
  else:
57
  if self.metric == 'binary':
58
+ print('This is binary search.')
59
+ bq_data = np.packbits((q_data > 0.0).astype(bool), axis=1)
 
60
  sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))
61
 
62
  if self.rerank_from > self.n_neighbors:
63
+ re_sims = np.zeros([len(q_data), self.n_neighbors], dtype=float)
64
+ re_idxs = np.zeros([len(q_data), self.n_neighbors], dtype=float)
65
+ for i, q in enumerate(q_data):
66
+ rerank_data = self.o_data[idx[i]]
67
+ rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
68
+ rerank_search.fit(rerank_data)
69
+ re_sim, re_idx = rerank_search.kneighbors(np.asarray([q]))
70
+ re_sims[i, :] = re_sim
71
+ re_idxs[i, :] = idx[i][re_idx]
72
+ idx = re_idxs
73
+ sim = re_sims
74
+
75
  return sim, idx
76
 
77
+
78
  model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
79
+ tokenizer = CLIPTokenizer.from_pretrained("Diangle/clip4clip-webvid")
80
 
81
  def search(search_sentence):
82
  inputs = tokenizer(text=search_sentence , return_tensors="pt", padding=True)
83
 
84
+ outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)
 
85
  text_projection = model.state_dict()['text_projection.weight']
86
  text_embeds = outputs[1] @ text_projection
87
  final_output = text_embeds[torch.arange(text_embeds.shape[0]), inputs["input_ids"].argmax(dim=-1)]
 
93
 
94
  nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
95
  nn_search.fit(np.packbits((ft_visual_features_database > 0.0).astype(bool), axis=1), o_data=ft_visual_features_database)
96
+ sims, idxs = nn_search.kneighbors(sequence_output)
97
+ # print(database_df.iloc[idxs[0]]['contentUrl'])
98
+ urls = database_df.iloc[idxs[0]]['contentUrl'].to_list()
99
+ AUTOPLAY_VIDEOS = []
100
+ for url in urls:
101
+ AUTOPLAY_VIDEOS.append("""<video controls muted autoplay>
102
+ <source src={} type="video/mp4">
103
+ </video>""".format(url))
104
+ return AUTOPLAY_VIDEOS
105
 
106
 
107
  with gr.Blocks() as demo:
 
113
  with gr.Column():
114
  inp = gr.Textbox(placeholder="Write a sentence.")
115
  btn = gr.Button(value="Retrieve")
116
+ ex = [["mind-blowing magic tricks"],["baking chocolate cake"],
117
+ ["birds fly in the sky"], ["natural wonders of the world"]]
118
  gr.Examples(examples=ex,
119
+ inputs=[inp]
120
+ )
121
  with gr.Column():
122
+ out = [gr.HTML() for _ in range(5)]
123
  btn.click(search, inputs=inp, outputs=out)
124
 
125
  demo.launch()