File size: 3,063 Bytes
6ee47c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c8b727
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from typing import List
import chromadb
from transformers import AutoTokenizer, AutoModel
from chromadb.config import Settings
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
from hazm import *


class RAG:
    """Retrieval-augmented generation helper over a persisted ChromaDB
    collection of legal cases.

    Embeds Persian queries with a ParsBERT encoder (mean-pooled last
    hidden state) and returns the texts/titles of the ``top_k`` most
    similar cases, looked up in a local CSV of processed cases.
    """

    def __init__(self,
                 model_name: str = "HooshvareLab/bert-base-parsbert-uncased",
                 collection_name: str = "legal_cases",
                 persist_directory: str = "chromadb_collections/",
                 top_k: int = 2
                 ) -> None:
        """Load the case table, the encoder, and the ChromaDB collection.

        Args:
            model_name: Hugging Face model id for both tokenizer and encoder.
            collection_name: Name of the *existing* ChromaDB collection.
            persist_directory: On-disk location of the ChromaDB store.
            top_k: Number of nearest cases to return per query.
        """
        # Cases are looked up by integer row index (see
        # _extract_case_field_from_df), so the CSV row order must match the
        # ids stored in the collection ("case_<row>") — TODO confirm against
        # the indexing script.
        self.cases_df = pd.read_csv('processed_cases.csv')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.normalizer = Normalizer()
        self.top_k = top_k

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.client = chromadb.PersistentClient(path=persist_directory)
        # get_collection (not get_or_create): a missing collection fails
        # loudly at construction time rather than at query time.
        self.collection = self.client.get_collection(name=collection_name)

    def query_pre_process(self, query: str) -> str:
        """Normalize a raw Persian query with hazm before embedding."""
        return self.normalizer.normalize(query)

    def embed_single_text(self, text: str) -> np.ndarray:
        """Embed ``text`` as the mean-pooled last hidden state (1-D vector)."""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # NOTE(review): this averages over *all* token positions, padding
        # included — fine for a single text, but an attention-mask-weighted
        # mean would be needed if this were ever batched.
        return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

    def _extract_case_field_from_df(self, case_id: str, field: str) -> str:
        """Return column ``field`` for a ChromaDB id like ``'case_42'``.

        The numeric suffix of ``case_id`` is used as the DataFrame row
        label. Returns an explanatory string (never raises) when the id is
        malformed or the row is missing, preserving the original
        error-as-string contract of the public extractors.
        """
        try:
            # Parsing inside the try: the original parsed outside it, so a
            # malformed id ('garbage', 'case_x') raised IndexError/ValueError
            # instead of returning the documented error string.
            case_id_int = int(case_id.split("_")[1])
            return self.cases_df.loc[case_id_int, field]
        except (KeyError, IndexError, ValueError):
            return "Case ID not found in DataFrame."

    def extract_case_title_from_df(self, case_id: str) -> str:
        """Return the case title for ``case_id`` (e.g. ``'case_7'``)."""
        return self._extract_case_field_from_df(case_id, 'title')

    def extract_case_text_from_df(self, case_id: str) -> str:
        """Return the full case text for ``case_id`` (e.g. ``'case_7'``)."""
        return self._extract_case_field_from_df(case_id, 'text')

    def retrieve_relevant_cases(self, query_text: str) -> List[dict]:
        """Return the ``top_k`` cases most similar to ``query_text``.

        Each element is ``{'text': ..., 'title': ...}``. (The original
        annotation claimed ``List[str]``, but the method has always
        returned a list of dicts — the annotation is corrected here.)
        """
        normalized_query_text = self.query_pre_process(query_text)

        query_embedding = self.embed_single_text(normalized_query_text)

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=self.top_k
        )

        # results['ids'][0] holds the ids for the (single) query embedding.
        return [
            {
                "text": self.extract_case_text_from_df(case_id),
                "title": self.extract_case_title_from_df(case_id),
            }
            for case_id in results['ids'][0]
        ]

    def get_information(self, query: str) -> List[dict]:
        """Public entry point: retrieve the relevant cases for ``query``."""
        return self.retrieve_relevant_cases(query)