Legal_RAG_Chatbot/model/rag/rag_handler.py
mojtabaa4's picture
add application files
1c8b727
raw
history blame contribute delete
No virus
3.06 kB
from typing import List
import chromadb
from transformers import AutoTokenizer, AutoModel
from chromadb.config import Settings
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
from hazm import *
class RAG:
    """Retrieve stored legal cases that are closest to a free-text query.

    A query is normalized, embedded with a Hugging Face encoder and matched
    against a persisted ChromaDB collection. The IDs returned by the
    collection are then resolved to a title and a text through a DataFrame
    loaded from a CSV file.
    """

    # Sentinel string returned by the lookup methods when a case ID cannot be
    # resolved to a row/column of ``cases_df``.
    _NOT_FOUND_MESSAGE = "Case ID not found in DataFrame."

    def __init__(self,
                 model_name: str = "HooshvareLab/bert-base-parsbert-uncased",
                 collection_name: str = "legal_cases",
                 persist_directory: str = "chromadb_collections/",
                 top_k: int = 2,
                 cases_csv_path: str = "processed_cases.csv",
                 ) -> None:
        """Load the case table, the encoder and the vector collection.

        Args:
            model_name: Hugging Face model identifier used for both the
                tokenizer and the encoder.
            collection_name: Name of the ChromaDB collection to query. It is
                fetched with ``get_collection``, not created here.
            persist_directory: Directory of the persistent ChromaDB client.
            top_k: Number of results requested per query.
            cases_csv_path: CSV file with the case rows; the lookup methods
                read its ``title`` and ``text`` columns.
        """
        self.cases_df = pd.read_csv(cases_csv_path)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.normalizer = Normalizer()
        self.top_k = top_k

        # Run the encoder on the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self.client.get_collection(name=collection_name)

    def query_pre_process(self, query: str) -> str:
        """Return ``query`` after passing it through the text normalizer."""
        return self.normalizer.normalize(query)

    def embed_single_text(self, text: str) -> np.ndarray:
        """Embed ``text`` as the mean of the encoder's last hidden states.

        The input is truncated to at most 512 tokens. The pooling is left
        exactly as it was so that query vectors stay comparable with the
        vectors already stored in the collection.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )
        # Tokenizer output must live on the same device as the model.
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Average over the token axis, drop the batch axis, hand back NumPy.
        return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

    def _lookup_case_field(self, case_id: str, column: str) -> str:
        """Return ``column`` of the row that ``case_id`` refers to.

        The integer found after the first underscore of ``case_id`` is used
        as a ``.loc`` row label on ``cases_df``. A malformed ID (no
        underscore, non-numeric suffix) or an unknown row/column yields the
        not-found message instead of raising, so one bad ID cannot abort a
        whole retrieval.
        """
        try:
            row_label = int(case_id.split("_")[1])
            return self.cases_df.loc[row_label, column]
        except (IndexError, ValueError, KeyError):
            return self._NOT_FOUND_MESSAGE

    def extract_case_title_from_df(self, case_id: str) -> str:
        """Return the ``title`` of ``case_id``, or the not-found message."""
        return self._lookup_case_field(case_id, 'title')

    def extract_case_text_from_df(self, case_id: str) -> str:
        """Return the ``text`` of ``case_id``, or the not-found message."""
        return self._lookup_case_field(case_id, 'text')

    def retrieve_relevant_cases(self, query_text: str) -> List[dict]:
        """Return the ``top_k`` nearest cases for ``query_text``.

        Returns:
            One dict per hit, in the order the collection returned them,
            each with the keys ``"text"`` and ``"title"``.
        """
        normalized_query_text = self.query_pre_process(query_text)
        query_embedding = self.embed_single_text(normalized_query_text)

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=self.top_k,
        )

        # A single query was sent, so only the first result list is relevant.
        # Iterate the IDs themselves: they are the only field used below.
        return [
            {
                "text": self.extract_case_text_from_df(case_id),
                "title": self.extract_case_title_from_df(case_id),
            }
            for case_id in results['ids'][0]
        ]

    def get_information(self, query: str) -> List[dict]:
        """Return the cases relevant to ``query``.

        Thin alias of ``retrieve_relevant_cases``.
        """
        return self.retrieve_relevant_cases(query)