import os
import os.path
import torch
from dotenv import load_dotenv
from datasets import DatasetDict
from dataclasses import dataclass
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizerFast,
    LongformerModel,
    LongformerTokenizer
)
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from src.retrievers.base_retriever import RetrieveType, Retriever
from src.utils.log import logger
from src.utils.preprocessing import remove_formulas
from src.utils.timing import timeit

load_dotenv()


@dataclass
class FaissRetrieverOptions:
    ctx_encoder: PreTrainedModel
    ctx_tokenizer: PreTrainedTokenizerFast
    q_encoder: PreTrainedModel
    q_tokenizer: PreTrainedTokenizerFast
    embedding_path: str
    lm: str

    @staticmethod
    def dpr(embedding_path: str) -> "FaissRetrieverOptions":
        return FaissRetrieverOptions(
            ctx_encoder=DPRContextEncoder.from_pretrained(
                "facebook/dpr-ctx_encoder-single-nq-base"
            ),
            ctx_tokenizer=DPRContextEncoderTokenizerFast.from_pretrained(
                "facebook/dpr-ctx_encoder-single-nq-base"
            ),
            q_encoder=DPRQuestionEncoder.from_pretrained(
                "facebook/dpr-question_encoder-single-nq-base"
            ),
            q_tokenizer=DPRQuestionEncoderTokenizerFast.from_pretrained(
                "facebook/dpr-question_encoder-single-nq-base"
            ),
            embedding_path=embedding_path,
            lm="dpr"
        )

    @staticmethod
    def longformer(embedding_path: str) -> "FaissRetrieverOptions":
        # Longformer uses the same encoder and tokenizer for both
        # questions and contexts
        encoder = LongformerModel.from_pretrained(
            "valhalla/longformer-base-4096-finetuned-squadv1"
        )
        tokenizer = LongformerTokenizer.from_pretrained(
            "valhalla/longformer-base-4096-finetuned-squadv1"
        )
        return FaissRetrieverOptions(
            ctx_encoder=encoder,
            ctx_tokenizer=tokenizer,
            q_encoder=encoder,
            q_tokenizer=tokenizer,
            embedding_path=embedding_path,
            lm="longformer"
        )


class FaissRetriever(Retriever):
    """Retrieves relevant documents for a given query using a FAISS index.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, paragraphs: DatasetDict,
                 options: FaissRetrieverOptions) -> None:
        # Inference only: disable gradient tracking globally
        torch.set_grad_enabled(False)

        self.lm = options.lm

        # Context encoding and tokenization
        self.ctx_encoder = options.ctx_encoder
        self.ctx_tokenizer = options.ctx_tokenizer

        # Question encoding and tokenization
        self.q_encoder = options.q_encoder
        self.q_tokenizer = options.q_tokenizer

        self.paragraphs = paragraphs
        self.embedding_path = options.embedding_path

        self.index = self._init_index()

    def _embed_question(self, q):
        match self.lm:
            case "dpr":
                # The DPR question encoder returns a pooled embedding per
                # input; take the vector for the single question
                tok = self.q_tokenizer(
                    q, return_tensors="pt", truncation=True, padding=True)
                return self.q_encoder(**tok)[0][0].numpy()
            case "longformer":
                # Use the hidden state of the first (<s>) token as the
                # question embedding
                tok = self.q_tokenizer(q, return_tensors="pt")
                return self.q_encoder(**tok).last_hidden_state[0][0].numpy()

    def _embed_context(self, row):
        # Embed a single dataset row; the vector is stored in the
        # "embeddings" column used by the FAISS index
        p = row["text"]

        match self.lm:
            case "dpr":
                tok = self.ctx_tokenizer(
                    p, return_tensors="pt", truncation=True, padding=True)
                enc = self.ctx_encoder(**tok)[0][0].numpy()
                return {"embeddings": enc}
            case "longformer":
                tok = self.ctx_tokenizer(p, return_tensors="pt")
                enc = self.ctx_encoder(**tok).last_hidden_state[0][0].numpy()
                return {"embeddings": enc}

    def _init_index(
            self,
            force_new_embedding: bool = False):
        ds = self.paragraphs["train"]
        ds = ds.map(remove_formulas)

        if not force_new_embedding and os.path.exists(self.embedding_path):
            # Reuse a previously saved FAISS index
            ds.load_faiss_index(
                'embeddings', self.embedding_path)  # type: ignore
            return ds
        else:
            # Compute embeddings and build a FAISS index over them
            index = ds.map(self._embed_context)  # type: ignore
            index.add_faiss_index(column="embeddings")

            # Save the FAISS index so it can be reloaded on later runs
            os.makedirs("./src/models/", exist_ok=True)
            index.save_faiss_index(
                "embeddings", self.embedding_path)
            return index

    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        question_embedding = self._embed_question(query)
        scores, results = self.index.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )

        return scores, results
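

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. The data
    # file and embedding path below are hypothetical placeholders; the "train"
    # split is assumed to contain a "text" column, as _embed_context expects.
    from datasets import load_dataset

    paragraphs = load_dataset(
        "json", data_files={"train": "data/paragraphs.json"})
    options = FaissRetrieverOptions.dpr("./src/models/dpr_embeddings.faiss")
    retriever = FaissRetriever(paragraphs, options)

    scores, results = retriever.retrieve("What is dense passage retrieval?", k=5)
    for score, text in zip(scores, results["text"]):
        logger.info(f"{score:.4f}: {text[:80]}")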