# test-demo-retriever-qa / bloom_retriever.py
# Uploaded by houyu0930 in commit 69c6116 (verified): "Upload DemoRetrieverQAPipeline"
# Page metadata: raw | history | blame | No virus | 1.57 kB
import torch
import numpy as np
from transformers import FeatureExtractionPipeline
from scipy.spatial.distance import cdist
class DemoRetrieverQAPipeline(FeatureExtractionPipeline):
    """Feature-extraction pipeline that returns the context nearest a question.

    The pipeline is called with a mapping holding a ``'question'`` string and a
    ``'contexts'`` sequence of strings. Each text is embedded by ``_infer`` (the
    model's first output taken at the last token position), the contexts are
    ranked by Euclidean distance to the question embedding, and the nearest one
    is returned as ``{"context": ...}``.

    NOTE(review): the question, contexts and tokenizer kwargs are stashed on
    ``self`` between ``preprocess`` and ``postprocess``, so one instance must
    not serve concurrent or batched calls.
    """

    def preprocess(self, inputs, **tokenize_kwargs):
        """Stash the question and contexts, then tokenize the question.

        Args:
            inputs: mapping with a ``'question'`` string and a ``'contexts'``
                sequence of candidate strings.
            **tokenize_kwargs: tokenizer options (e.g. ``truncation``); forwarded
                to the base ``preprocess`` and reused by ``_infer``.

        Returns:
            The tokenized question, as produced by the base pipeline.

        Raises:
            ValueError: if ``contexts`` is empty.
        """
        # Only string fields are read here, so nothing needs moving to a device
        # (a plain dict has no ``.to``). The base pipeline places the tokenized
        # tensors on ``self.device`` before the forward pass.
        self.query = inputs['question']
        self.contexts = list(inputs['contexts'])
        if not self.contexts:
            raise ValueError("DemoRetrieverQAPipeline needs at least one context")
        # Remembered so contexts are tokenized with the same options as the question.
        self._tokenize_kwargs = tokenize_kwargs
        return super().preprocess(self.query, **tokenize_kwargs)

    def _infer(self, inputs, return_tensors=False):
        """Embed ``inputs`` as the model's first output at the last token.

        Args:
            inputs: text to embed (a single string in this pipeline).
            return_tensors: kept for interface compatibility and ignored; the
                result is always a nested Python list.

        Returns:
            One vector per batch row, i.e. a ``(batch, dim)`` nested list.
            NOTE(review): assumes ``model(...)[0]`` is ``(batch, seq_len, dim)``
            (``last_hidden_state`` for a base model) - confirm for the model used.
        """
        tokenize_kwargs = getattr(self, "_tokenize_kwargs", {})
        model_inputs = self.tokenizer(
            inputs, return_tensors=self.framework, **tokenize_kwargs
        )
        if self.framework == "pt":
            # Tokenizer output is created on the CPU; the model may be elsewhere.
            model_inputs = model_inputs.to(self.device)
            with torch.no_grad():  # inference only: no autograd graph needed
                hidden = self.model(**model_inputs)[0].tolist()
        else:  # "tf"
            hidden = self.model(**model_inputs)[0].numpy().tolist()
        # Each ``row`` is (seq_len, dim); ``row[-1]`` is the last token's vector.
        # (``row[0][-1]`` would be a single scalar, not an embedding.)
        # A single unpadded string is passed, so the last position is a real token.
        return [row[-1] for row in hidden]

    def postprocess(self, model_outputs, return_tensors=False):
        """Return the stashed context closest to the stashed question.

        Args:
            model_outputs: forward-pass output for the question. Unused: the
                question is re-embedded with ``_infer`` so it is encoded exactly
                like the contexts (this costs one extra forward pass).
            return_tensors: forwarded to ``_infer``, which ignores it.

        Returns:
            ``{"context": best}``, where ``best`` is the context whose embedding
            has the smallest Euclidean distance to the question embedding.
        """
        # (n_contexts, dim): each ``_infer`` call yields a (1, dim) list.
        emb_contexts = np.concatenate(
            [self._infer(context, return_tensors) for context in self.contexts],
            axis=0,
        )
        emb_query = np.asarray(self._infer(self.query, return_tensors))  # (1, dim)
        # Important: rank by L2 (Euclidean) distance.
        dist = cdist(emb_query, emb_contexts, 'euclidean')  # (1, n_contexts)
        # Only the single nearest context is needed, so no full sort.
        return {"context": self.contexts[int(np.argmin(dist[0]))]}