winfred2027 commited on
Commit
1499e69
1 Parent(s): 77e656b

Update demo_support/retrieval.py

Browse files
Files changed (1) hide show
  1. demo_support/retrieval.py +2 -1
demo_support/retrieval.py CHANGED
@@ -41,7 +41,8 @@ def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
41
  #for chunk in torch.split(feats, 10240):
42
  for chunk in feats:
43
  sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
44
- sims = torch.cat(sims)
 
45
  sims, idx = torch.sort(sims, descending=True)
46
  sim_mask = sims > sim_th
47
  sims = sims[sim_mask]
 
41
  #for chunk in torch.split(feats, 10240):
42
  for chunk in feats:
43
  sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
44
+ #sims = torch.cat(sims)
45
+ sims = torch.tensor(sims)
46
  sims, idx = torch.sort(sims, descending=True)
47
  sim_mask = sims > sim_th
48
  sims = sims[sim_mask]