Update app.py
app.py CHANGED
@@ -11,65 +11,75 @@ import fitz # PyMuPDF
 # Get the Hugging Face token from the environment variable
 token = os.environ.get("HF_TOKEN")
 
-#
-ST =
-
-
+# Lazy-loading variables
+ST = None
+model = None
+tokenizer = None
 law_sentences = None
 law_embeddings = None
 index = None
-
+data = None
+
+# Lazy-load the embedding model
+def load_embedding_model():
+    global ST
+    if ST is None:
+        ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+    return ST
+
+# Lazy-load the language model and tokenizer
+def load_model():
+    global model, tokenizer
+    if model is None or tokenizer is None:
+        model_id = "google/gemma-2-2b-it"
+        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            token=token
+        )
+    return model, tokenizer
+
+# Lazy-load the PDF text extraction and embeddings
 def load_law_data():
     global law_sentences, law_embeddings, index
     if law_sentences is None or law_embeddings is None or index is None:
-        # Extract text from the PDF
         pdf_path = "laws.pdf"  # Enter the actual PDF path here.
         doc = fitz.open(pdf_path)
         law_text = ""
         for page in doc:
             law_text += page.get_text()
 
-        # Split the text into sentences and embed them
         law_sentences = law_text.split('\n')  # Adjust the splitting to match the PDF structure
-        law_embeddings =
+        law_embeddings = load_embedding_model().encode(law_sentences)
 
         # Create the FAISS index and add the embeddings
         index = faiss.IndexFlatL2(law_embeddings.shape[1])
         index.add(law_embeddings)
 
-# Load the legal consultation dataset from Hugging Face
-
-data
-
-
-data =
-data.
-
-
-model_id = "google/gemma-2-2b-it"
-tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.bfloat16,  # Use bfloat16 without quantization
-    device_map="auto",
-    token=token
-)
-
-SYS_PROMPT = """You are an assistant for answering legal questions.
-... (rest omitted; the existing SYS_PROMPT is kept unchanged) ...
-"""
+# Load the legal consultation dataset from Hugging Face (lazy loading)
+def load_dataset_data():
+    global data
+    if data is None:
+        dataset = load_dataset("jihye-moon/LawQA-Ko")
+        data = dataset["train"]
+        data = data.map(lambda x: {"question_embedding": load_embedding_model().encode(x["question"])}, batched=True)
+        data.add_faiss_index(column="question_embedding")
+    return data
 
 # Legal document search function
 def search_law(query, k=5):
     load_law_data()  # Lazy-load the PDF text and embeddings
-    query_embedding =
+    query_embedding = load_embedding_model().encode([query])
     D, I = index.search(query_embedding, k)
     return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
 
 # Legal consultation data search function
 def search_qa(query, k=3):
-
-
+    dataset_data = load_dataset_data()
+    scores, retrieved_examples = dataset_data.get_nearest_examples(
+        "question_embedding", load_embedding_model().encode(query), k=k
     )
     return [retrieved_examples["answer"][i] for i in range(k)]
 
@@ -92,6 +102,8 @@ def talk(prompt, history):
     formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
     formatted_prompt = formatted_prompt[:2000]  # Limit the prompt to avoid running out of GPU memory
 
+    model, tokenizer = load_model()
+
     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
 
     # Instruct the model to generate a response
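For reference, a minimal sketch of how the lazy-loading helpers introduced in this commit fit together at query time. It assumes search_law, search_qa, and load_model from app.py are available in the same module; the sample question, the prompt assembly, and the generation settings are illustrative only and do not reproduce the app's actual talk() flow.

# Illustrative smoke test only -- not part of app.py.
# search_law, search_qa, and load_model are the functions defined above;
# the question text and generation settings are made up for this example.
import torch

question = "When must a landlord return the rental deposit after the lease ends?"

law_hits = search_law(question, k=5)   # first call triggers load_law_data() and load_embedding_model()
qa_hits = search_qa(question, k=3)     # first call triggers load_dataset_data()

model, tokenizer = load_model()        # Gemma weights are only loaded on first use

context = "\n".join(sentence for sentence, _score in law_hits) + "\n" + "\n".join(qa_hits)
messages = [{"role": "user", "content": f"{context}\n\nQuestion: {question}"}]

input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=256)

print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))

Because every heavy resource is created on first use, the Space can finish starting up before the model weights, the PDF embeddings, and the dataset index are actually built; the first user request then pays the one-time loading cost.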
|