import torch
import numpy as np
from transformers import FeatureExtractionPipeline
from scipy.spatial.distance import cdist


class DemoRetrieverQAPipeline(FeatureExtractionPipeline):
    """Demo retriever-style QA pipeline.

    Embeds a question and a list of candidate contexts with the underlying
    feature-extraction model, then returns the context whose embedding is
    nearest (L2 distance) to the question embedding.

    Expected input to ``__call__``/``preprocess`` is a dict with keys
    ``'question'`` (str) and ``'contexts'`` (list of str).
    """

    def preprocess(self, inputs, **tokenize_kwargs):
        """Stash the question and contexts, then tokenize the question.

        Args:
            inputs: dict with ``'question'`` and ``'contexts'`` keys.
            tokenize_kwargs: forwarded to the parent tokenizer call.

        Returns:
            The tokenized model inputs for the question.
        """
        # NOTE: the original code called ``inputs.to(device)`` here, but
        # ``inputs`` is a plain dict of strings — that call can never
        # succeed. Device placement of tensors is handled by the Pipeline
        # base class, so it is simply dropped.
        self.query = inputs['question']
        self.contexts = inputs['contexts']
        # Forward tokenize_kwargs instead of silently discarding them.
        return super().preprocess(self.query, **tokenize_kwargs)

    def _infer(self, inputs, return_tensors=False):
        """Tokenize + run the model on ``inputs`` and reduce to embeddings.

        Args:
            inputs: a string (or batch of strings) to embed.
            return_tensors: if True, keep the raw framework tensor instead
                of converting to nested Python lists.

        Returns:
            A list of single-element lists, one per input row, holding the
            reduced embedding value ``ii[0][-1]`` for each output row.
        """
        model_inputs = self.tokenizer(inputs, return_tensors=self.framework)
        model_outputs = self.model(**model_inputs)
        # Bug fix: these branches must be mutually exclusive. The original
        # used independent ``if`` statements, so the ``return_tensors``
        # result was always clobbered by the framework-specific branch.
        if return_tensors:
            outputs = model_outputs[0]
        elif self.framework == "pt":
            outputs = model_outputs[0].tolist()
        elif self.framework == "tf":
            outputs = model_outputs[0].numpy().tolist()
        else:
            # Previously this fell through with ``outputs`` unbound
            # (NameError); fail with a clear message instead.
            raise ValueError(f"Unsupported framework: {self.framework!r}")
        # Reduce each output row to a 1-element vector. Preserved from the
        # original; presumably picks one scalar feature per input —
        # TODO(review): confirm this is the intended pooling.
        return [[ii[0][-1]] for ii in outputs]

    def postprocess(self, model_outputs, return_tensors=False):
        """Rank contexts by L2 distance to the query embedding.

        Args:
            model_outputs: unused here; the query/contexts captured in
                ``preprocess`` are re-embedded via ``_infer``.
            return_tensors: forwarded to ``_infer``.

        Returns:
            ``{"context": <nearest context string>}``.
        """
        emb_contexts = np.concatenate(
            [self._infer(context, return_tensors) for context in self.contexts],
            axis=0,
        )
        emb_queries = np.concatenate([self._infer(self.query, return_tensors)], axis=0)
        # Important: take l2 distance!
        dist = cdist(emb_queries, emb_contexts, 'euclidean')
        # Nearest context for each query (the original comment claimed
        # "top 5" but only the single closest context was ever taken).
        nearest = dist.argsort(axis=-1)[:, :1]
        top_contexts = [[self.contexts[qq] for qq in row] for row in nearest]
        return {"context": top_contexts[0][0]}