Spaces: Runtime error

add application files

- app.py +0 -0
- config.py +22 -0
- model/__init__.py +0 -0
- model/chat.py +23 -0
- model/controller.py +18 -0
- model/llm/llm.py +108 -0
- model/processor/case_crawler.py +113 -0
- model/processor/database_Chunker.ipynb +0 -0
- model/processor/law_provider.py +61 -0
- model/processor/pre_process.ipynb +0 -0
- model/processor/retrieval_rag_nlp_project.ipynb:Zone.Identifier +0 -0
- model/propmt/__init__.py +0 -0
- model/propmt/prompt_handler.py +16 -0
- model/rag/__init__.py +0 -0
- model/rag/rag_handler.py +102 -0
- requirements.txt +24 -0
app.py
ADDED
File without changes
config.py
ADDED
@@ -0,0 +1,22 @@

gpt_3_5 = "gpt-3.5-turbo-instruct"
gpt_mini = "gpt-4o-mini"

aval_ai = {
    "model": gpt_3_5,
    "base_url": "https://api.avalai.ir/v1",
}

GILAS_CONFIG = {
    "api_key": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwMzg5OTQ0NjgsImp0aSI6IjExNDg4MzAyMTE3NDA0MzY2ODc0NiIsImlhdCI6MTcyMzYzNDQ2OCwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyMzYzNDQ2OCwic3ViIjoiMTE0ODgzMDIxMTc0MDQzNjY4NzQ2In0.8hbh59BmwBcAfoH9nEB98_5BIuxzwUUb8fpHSKF1S_Q",
    "model": "gpt-4o-mini",
    "base_url": "https://api.gilas.io/v1",
}

OPENAI_CONFIG = {
    "model": gpt_mini,
}

LLM_CONFIG = aval_ai
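Since python-dotenv is already pinned in requirements.txt, a safer variant of this config would read the Gilas key from the environment instead of committing it. A minimal sketch, assuming a GILAS_API_KEY variable in a local .env file (the variable name is an assumption, not part of the original config):

import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

GILAS_CONFIG = {
    # GILAS_API_KEY is an assumed variable name
    "api_key": os.environ["GILAS_API_KEY"],
    "model": "gpt-4o-mini",
    "base_url": "https://api.gilas.io/v1",
}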
model/__init__.py
ADDED
File without changes
model/chat.py
ADDED
@@ -0,0 +1,23 @@
from model.propmt.prompt_handler import *
from model.llm.llm import *
from model.rag.rag_handler import *
from config import *

class Chat:
    def __init__(self, chat_id, rag_handler) -> None:
        self.chat_id = chat_id
        self.message_history = []
        self.response_history = []
        self.prompt_handler = Prompt()
        self.llm = LLM_API_Call("gilas")
        self.rag_handler = rag_handler

    def response(self, message: str) -> str:
        self.message_history.append(message)

        info_list = self.rag_handler.get_information(message)
        prompt = self.prompt_handler.get_prompt(message, info_list)
        response = self.llm.get_LLM_response(prompt=prompt)

        self.response_history.append(response)
        return response
model/controller.py
ADDED
@@ -0,0 +1,18 @@
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))

from model.chat import *

class Controller:
    def __init__(self) -> None:
        self.chat_dic = {}
        self.rag_handler = RAG()

    def handle_message(self,
                       chat_id: int,
                       message: str) -> str:
        if chat_id not in self.chat_dic:
            self.chat_dic[chat_id] = Chat(chat_id=chat_id, rag_handler=self.rag_handler)
        chat = self.chat_dic[chat_id]
        return chat.response(message)
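app.py is added empty in this commit, and gradio is pinned in requirements.txt, so Controller is presumably meant to sit behind a Gradio UI. A minimal sketch of that wiring (an assumption, not the author's app.py):

import gradio as gr
from model.controller import Controller

controller = Controller()

def respond(message, history):
    # A single shared chat id; per-user ids would need session state
    return controller.handle_message(chat_id=0, message=message)

gr.ChatInterface(respond).launch()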
model/llm/llm.py
ADDED
@@ -0,0 +1,108 @@
from langchain_openai import OpenAI
import openai
import sys
import os
import requests


sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))

from config import *


class LLM_API_Call:

    def __init__(self, type) -> None:
        if type == "openai":
            # The config dict must define "api_key" for this branch
            # (the original read LLM_CONFIG[""])
            self.llm = OpenAI_API_Call(api_key=LLM_CONFIG["api_key"],
                                       model=LLM_CONFIG["model"])
        elif type == "gilas":
            self.llm = Gilas_API_Call(api_key=GILAS_CONFIG["api_key"],
                                      model=GILAS_CONFIG["model"],
                                      base_url=GILAS_CONFIG["base_url"])
        else:
            self.llm = OpenAI(**LLM_CONFIG)

    def get_LLM_response(self, prompt: str) -> str:
        return self.llm.invoke(prompt)


class OpenAI_API_Call:

    def __init__(self, api_key, model="gpt-4"):
        self.api_key = api_key
        openai.api_key = api_key
        self.model = model
        self.conversation = []

    def add_message(self, role, content):
        self.conversation.append({"role": role, "content": content})

    def get_response(self):
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=self.conversation
        )
        return response['choices'][0]['message']['content']

    def invoke(self, user_input):
        self.add_message("user", user_input)
        response = self.get_response()
        self.add_message("assistant", response)
        return response


class Gilas_API_Call:
    def __init__(self, api_key, base_url, model="gpt-4o-mini"):
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        self.conversation = []

    def add_message(self, role, content):
        self.conversation.append({"role": role, "content": content})

    def get_response(self):
        data = {
            "model": self.model,
            "messages": self.conversation
        }

        response = requests.post(
            url=f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=data
        )

        # print(f"Response status code: {response.status_code}")
        # print(f"Response content: {response.text}")

        if response.status_code == 200:
            try:
                return response.json()['choices'][0]['message']['content']
            except (KeyError, IndexError, ValueError) as e:
                raise Exception(f"Unexpected API response format: {e}")
        else:
            raise Exception(f"Gilas API call failed: {response.status_code} - {response.text}")

    def invoke(self, user_input):
        self.add_message("user", user_input)
        response = self.get_response()
        self.add_message("assistant", response)
        return response


# test = LLM_API_Call(type="gilas")
# print(test.get_LLM_response("سلام"))
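Both wrapper classes keep the full message list in self.conversation, so repeated invoke calls behave as a stateful chat. A hypothetical round-trip, assuming a valid key in config.py (the question text is illustrative):

llm = LLM_API_Call("gilas")
print(llm.get_LLM_response(prompt="What does the rental law say about eviction?"))
# The reply above is stored as an "assistant" turn, so this follow-up
# is answered with the earlier exchange still in context:
print(llm.get_LLM_response(prompt="Summarize that in one sentence."))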
model/processor/case_crawler.py
ADDED
@@ -0,0 +1,113 @@
import requests
from bs4 import BeautifulSoup
import os
import warnings
from tqdm import tqdm

class Crawler:
    # Used to separate individual votes when they are concatenated into one string
    vote_splitter = " |split| "

    def __init__(self, base_url: str, list_url: str,
                 base_vote_url: str, models_path: str, result_path: str):
        if base_url == "":
            self.base_url = "https://ara.jri.ac.ir/"
        else:
            self.base_url = base_url

        if list_url == "":
            self.list_url = "https://ara.jri.ac.ir/Judge/Index"
        else:
            self.list_url = list_url

        if base_vote_url == "":
            self.base_vote_url = "https://ara.jri.ac.ir/Judge/Text/"
        else:
            self.base_vote_url = base_vote_url

        if models_path == "":
            self.models_path = "Models/"
        else:
            self.models_path = models_path
        self.pos_model_path = os.path.join(self.models_path, "postagger.model")
        self.chunker_path = os.path.join(self.models_path, "chunker.model")

        if result_path == "":
            self.result_path = "Resource/"
        else:
            self.result_path = result_path

        self.merges_vote_path = os.path.join(self.result_path, 'merged_vote.txt')
        self.clean_vote_path = os.path.join(self.result_path, 'clean_vote.txt')
        self.clean_vote_path_csv = os.path.join(self.result_path, 'clean_vote.csv')
        self.selected_vote_path = os.path.join(self.result_path, 'selected_vote.txt')
        self.law_list_path = os.path.join(self.result_path, 'law_list.txt')
        self.law_clean_list_path = os.path.join(self.result_path, 'law_clean_list.txt')
        self.vote_stop_path = os.path.join(self.result_path, "vote_stopwords.txt")
        self.law_stop_path = os.path.join(self.result_path, "law_stopwords.txt")

    @staticmethod
    def check_valid_vote(html_soup: BeautifulSoup) -> bool:
        # Extract the title to detect invalid votes
        h1_element = html_soup.find('h1', class_='Title3D')
        if h1_element is None:
            return False
        span_text = h1_element.find('span').text  # Text within the <span> tag
        full_text = h1_element.text  # Full text within the <h1> element
        text_after_span = full_text.split(span_text)[-1].strip()  # Text after the </span> tag
        return len(text_after_span) > 0

    @staticmethod
    def html_data_extractor(html_soup: BeautifulSoup, vote_splitter: str) -> str:
        vote_text = html_soup.find('div', id='treeText', class_='BackText')
        title = html_soup.find('h1', class_='Title3D')
        info = html_soup.find('td', valign="top", class_="font-size-small")
        # vote_splitter marks the boundary between votes in the output file
        vote_df = str(title) + str(info) + str(vote_text) + vote_splitter
        return vote_df

    def vote_crawler(self, start: int, end: int, separator: int):
        counter = 0  # Counts valid votes crawled so far
        result_list = []
        warnings.filterwarnings("ignore")
        # Request each vote page in the given id range
        for i in tqdm(range(start, end)):
            # Flush every `separator` records to a .txt file
            if (counter % separator == 0 and counter > 0) or i == end - 1:
                with open(os.path.join(self.result_path, f'vote{i}.txt'), "w", encoding='utf-8') as text_file:
                    text_file.write(''.join(result_list))
                result_list = []
            url = self.base_vote_url + f"{i}"
            response = requests.get(url, verify=False)
            # Force UTF-8 so the Persian text decodes correctly
            response.encoding = 'utf-8'
            resp_text = response.text
            html_soup = BeautifulSoup(resp_text, 'html.parser')
            if response.ok and self.check_valid_vote(html_soup):
                counter += 1
                vote_df = self.html_data_extractor(html_soup, self.vote_splitter)
                result_list.append(vote_df)

    def merge_out_txt(self) -> None:
        # Merge the per-batch vote .txt files into a single merged_vote.txt
        with open(self.merges_vote_path, 'w', encoding='utf-8') as outfile:
            for filename in os.listdir(self.result_path):
                if filename.startswith("vote") and filename.endswith('.txt'):  # Only merge vote .txt files
                    with open(os.path.join(self.result_path, filename), 'r', encoding='utf-8') as infile:
                        outfile.write(infile.read())

if __name__ == "__main__":
    models_path = input("Enter the models path (default = Models/): ")
    result_path = input("Enter the result path (default = Resource/): ")
    base_url = input("Enter the base URL (default = https://ara.jri.ac.ir/): ")
    list_url = input("Enter the list URL (default = https://ara.jri.ac.ir/Judge/Index): ")
    base_vote_url = input("Enter the base vote URL (default = https://ara.jri.ac.ir/Judge/Text/): ")

    crawler_instance = Crawler(models_path=models_path, result_path=result_path, base_url=base_url, list_url=list_url, base_vote_url=base_vote_url)
    start = int(input("Enter the start value for vote crawling: "))
    end = int(input("Enter the end value for vote crawling: "))
    separator = int(input("Enter the separator value for vote crawling: "))

    crawler_instance.vote_crawler(start=start, end=end, separator=separator)
    crawler_instance.merge_out_txt()
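A non-interactive equivalent of the __main__ block, passing empty strings to take every default; the vote id range and batch size below are illustrative only:

crawler = Crawler(base_url="", list_url="", base_vote_url="",
                  models_path="", result_path="")
# Fetch vote pages 1000-1099 from https://ara.jri.ac.ir/Judge/Text/<id>,
# flushing a vote<i>.txt batch to Resource/ every 25 valid votes
crawler.vote_crawler(start=1000, end=1100, separator=25)
crawler.merge_out_txt()  # concatenate the batches into merged_vote.txt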
model/processor/database_Chunker.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
model/processor/law_provider.py
ADDED
@@ -0,0 +1,61 @@
import pandas as pd
import re

class LawTxetPreProcessor():

    def __init__(self, law_texts: list) -> None:
        self._law_texets = law_texts
        self._law_name_df = pd.DataFrame(columns=["law_index", "law_name"])
        self._madeh_df = pd.DataFrame(columns=["law_index", "madeh_index", "madeh_text"])
        self._is_df = False

    def build_df(self):
        title_list = []
        madeh_list = []
        madeh_index = []
        law_index = []
        counter = 0
        for text in self._law_texets:
            title = self.title_extractor(text)
            title_list.append(title)
            # The Iranian Constitution labels its provisions "اصل" (principle)
            # rather than "ماده" (article), so it needs the alternate pattern
            temp_madeh_list = self.madeh_extractor(text, title == "قانون اساسی جمهوری اسلامی ایران")
            law_index.extend([counter for i in temp_madeh_list])
            madeh_index.extend([i + 1 for i in range(len(temp_madeh_list))])
            madeh_list.extend(temp_madeh_list)
            counter += 1
        law_index_list = [i for i in range(counter)]
        self._madeh_df = pd.DataFrame({"law_index": law_index,
                                       "madeh_index": madeh_index,
                                       "madeh_text": madeh_list})
        self._law_name_df = pd.DataFrame({"law_index": law_index_list,
                                          "law_name": title_list})
        self._is_df = True

    def title_extractor(self, law_text: str) -> str:
        first_newline_index = law_text.find('\n')
        return law_text[:first_newline_index]

    def madeh_extractor(self, law_text: str, is_asl: bool = False) -> list:
        result = []
        pattern = r"(^.{0,1}اصل )" if is_asl else r"(^.{0,1}ماده)"
        removed_regex = r"❯.*\n"
        notvalid_pattern = r"(^.{0,1}ماده.*مکرر\n)"
        cleaned_text = re.sub(removed_regex, "", law_text)
        matches = re.finditer(pattern, cleaned_text, flags=re.MULTILINE)
        not_valid_matches = re.finditer(notvalid_pattern, cleaned_text, flags=re.MULTILINE)
        indices = [match.start() for match in matches]
        not_valid_indices = [match.start() for match in not_valid_matches]
        valid_indices = [item for item in indices if item not in not_valid_indices]
        for i in range(len(valid_indices)):
            start = valid_indices[i]
            if i != len(valid_indices) - 1:
                end = valid_indices[i + 1]
                result.append(cleaned_text[start:end])
            else:
                result.append(cleaned_text[start:])
        return result

    def get_df(self) -> pd.DataFrame:
        if not self._is_df:
            self.build_df()
        return self._law_name_df, self._madeh_df
model/processor/pre_process.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
model/processor/retrieval_rag_nlp_project.ipynb:Zone.Identifier
ADDED
Binary file (27 Bytes).
model/propmt/__init__.py
ADDED
File without changes
model/propmt/prompt_handler.py
ADDED
@@ -0,0 +1,16 @@
from typing import List

class Prompt:

    def get_prompt(self, message: str, info_list: List) -> str:
        prompt = f"As a user, I want to ask you the following legal question:\n{message}\n\n"

        if info_list:
            prompt += "Here are some relevant legal cases and information you should consider:\n"
            for i, info in enumerate(info_list):
                prompt += f"case {i+1}:\n{info['title']}\n{info['text']}\n"

        prompt += ("\nBased on the provided information, please respond in Persian (Farsi) with a concise legal analysis. "
                   "Ensure that your response is as summarized and clear as possible. (one paragraph)")

        return prompt
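For reference, a hypothetical call with one retrieved case shows the prompt shape the LLM receives; the dict keys match what rag_handler.py returns, while the question and case content are illustrative:

p = Prompt()
info_list = [{"title": "Appeal ruling 123", "text": "…case body…"}]
print(p.get_prompt("Is a verbal lease agreement enforceable?", info_list))
# -> the question, then "case 1:" with title and text, then the instruction
#    to answer in Persian in one concise paragraph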
model/rag/__init__.py
ADDED
File without changes
model/rag/rag_handler.py
ADDED
@@ -0,0 +1,102 @@
from typing import List
import chromadb
from transformers import AutoTokenizer, AutoModel
from chromadb.config import Settings
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
from hazm import *


class RAG:
    def __init__(self,
                 model_name: str = "HooshvareLab/bert-base-parsbert-uncased",
                 collection_name: str = "legal_cases",
                 persist_directory: str = "chromadb_collections/",
                 top_k: int = 2
                 ) -> None:

        self.cases_df = pd.read_csv('processed_cases.csv')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.normalizer = Normalizer()
        self.top_k = top_k

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.client = chromadb.PersistentClient(path=persist_directory)

        self.collection = self.client.get_collection(name=collection_name)

    def query_pre_process(self, query: str) -> str:
        return self.normalizer.normalize(query)

    def embed_single_text(self, text: str) -> np.ndarray:
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean-pool the last hidden state into a single sentence embedding
        return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

    def extract_case_title_from_df(self, case_id: str) -> str:
        case_id_int = int(case_id.split("_")[1])

        try:
            case_title = self.cases_df.loc[case_id_int, 'title']
            return case_title
        except KeyError:
            return "Case ID not found in DataFrame."

    def extract_case_text_from_df(self, case_id: str) -> str:
        case_id_int = int(case_id.split("_")[1])

        try:
            case_text = self.cases_df.loc[case_id_int, 'text']
            return case_text
        except KeyError:
            return "Case ID not found in DataFrame."

    def retrieve_relevant_cases(self, query_text: str) -> List[str]:
        normalized_query_text = self.query_pre_process(query_text)

        query_embedding = self.embed_single_text(normalized_query_text)
        query_embedding_list = query_embedding.tolist()

        results = self.collection.query(
            query_embeddings=[query_embedding_list],
            n_results=self.top_k
        )

        retrieved_cases = []
        for i in range(len(results['metadatas'][0])):
            case_id = results['ids'][0][i]
            case_text = self.extract_case_text_from_df(case_id)
            case_title = self.extract_case_title_from_df(case_id)
            retrieved_cases.append({
                "text": case_text,
                "title": case_title
            })

        return retrieved_cases

    def get_information(self, query: str) -> List[str]:
        return self.retrieve_relevant_cases(query)


# NOTE: the stub below redefines RAG and shadows the full implementation
# above, so any code importing this module gets a no-op retriever.
class RAG:

    def __init__(self) -> None:
        pass

    def get_information(self, query: str) -> List[str]:
        return []

    def query_pre_process(self, query: str):
        return query
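If the stub redefinition at the end were removed, the intended retrieval flow would look like the sketch below; it assumes processed_cases.csv and a pre-built "legal_cases" Chroma collection exist, and the query string is illustrative:

rag = RAG(top_k=2)
for case in rag.get_information("sample tenancy dispute query"):
    print(case["title"])
# Each item is a {"title", "text"} dict looked up from processed_cases.csv
# by the id stored in the Chroma collection (e.g. "case_42" -> row 42)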
requirements.txt
ADDED
@@ -0,0 +1,24 @@
# dataset
datasets
pandas
numpy
indexed_gzip
# json
matrix-nio[e2e]
opsdroid
python-dotenv

BeautifulSoup4
requests
tqdm

hazm
spacy

rank_bm25
openai
gradio

langchain_openai
sentence-transformers
chromadb