Update app.py
Browse files
app.py
CHANGED
@@ -11,72 +11,65 @@ import fitz # PyMuPDF
|
|
11 |
# 환경 변수에서 Hugging Face 토큰 가져오기
|
12 |
token = os.environ.get("HF_TOKEN")
|
13 |
|
14 |
-
# 임베딩 모델
|
15 |
-
ST =
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
global data
|
62 |
-
if data is None:
|
63 |
-
dataset = load_dataset("jihye-moon/LawQA-Ko")
|
64 |
-
data = dataset["train"]
|
65 |
-
data = data.map(lambda x: {"question_embedding": load_embedding_model().encode(x["question"])}, batched=True)
|
66 |
-
data.add_faiss_index(column="question_embedding")
|
67 |
-
return data
|
68 |
|
69 |
# 법률 문서 검색 함수
|
70 |
def search_law(query, k=5):
|
71 |
-
|
|
|
72 |
D, I = index.search(query_embedding, k)
|
73 |
return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
|
74 |
|
75 |
# 법률 상담 데이터 검색 함수
|
76 |
def search_qa(query, k=3):
|
77 |
-
|
78 |
-
|
79 |
-
"question_embedding", load_embedding_model().encode(query), k=k
|
80 |
)
|
81 |
return [retrieved_examples["answer"][i] for i in range(k)]
|
82 |
|
@@ -99,8 +92,6 @@ def talk(prompt, history):
|
|
99 |
formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
|
100 |
formatted_prompt = formatted_prompt[:2000] # GPU 메모리 부족을 피하기 위해 프롬프트 제한
|
101 |
|
102 |
-
model, tokenizer = load_model()
|
103 |
-
|
104 |
messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
|
105 |
|
106 |
# 모델에게 생성 지시
|
@@ -155,4 +146,3 @@ demo = gr.ChatInterface(
|
|
155 |
|
156 |
# Gradio 데모 실행
|
157 |
demo.launch(debug=True, server_port=7860)
|
158 |
-
|
|
|
11 |
# 환경 변수에서 Hugging Face 토큰 가져오기
|
12 |
token = os.environ.get("HF_TOKEN")
|
13 |
|
14 |
+
# 임베딩 모델 로드
|
15 |
+
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
|
16 |
+
|
17 |
+
# Lazy Loading PDF 텍스트와 임베딩
|
18 |
+
law_sentences = None
|
19 |
+
law_embeddings = None
|
20 |
+
index = None
|
21 |
+
|
22 |
+
def load_law_data():
|
23 |
+
global law_sentences, law_embeddings, index
|
24 |
+
if law_sentences is None or law_embeddings is None or index is None:
|
25 |
+
# PDF에서 텍스트 추출
|
26 |
+
pdf_path = "laws.pdf" # 여기에 실제 PDF 경로를 입력하세요.
|
27 |
+
doc = fitz.open(pdf_path)
|
28 |
+
law_text = ""
|
29 |
+
for page in doc:
|
30 |
+
law_text += page.get_text()
|
31 |
+
|
32 |
+
# 텍스트를 문장 단위로 나누고 임베딩
|
33 |
+
law_sentences = law_text.split('\n') # PDF 구조에 따라 분할을 조정
|
34 |
+
law_embeddings = ST.encode(law_sentences)
|
35 |
+
|
36 |
+
# FAISS 인덱스 생성 및 임베딩 추가
|
37 |
+
index = faiss.IndexFlatL2(law_embeddings.shape[1])
|
38 |
+
index.add(law_embeddings)
|
39 |
+
|
40 |
+
# Hugging Face에서 법률 상담 데이터셋 로드
|
41 |
+
dataset = load_dataset("jihye-moon/LawQA-Ko")
|
42 |
+
data = dataset["train"]
|
43 |
+
|
44 |
+
# 질문 컬럼을 임베딩하여 새로운 컬럼에 추가
|
45 |
+
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
|
46 |
+
data.add_faiss_index(column="question_embedding")
|
47 |
+
|
48 |
+
# LLaMA 모델 설정 (양자화 없이)
|
49 |
+
model_id = "google/gemma-2-2b-it"
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
|
51 |
+
model = AutoModelForCausalLM.from_pretrained(
|
52 |
+
model_id,
|
53 |
+
torch_dtype=torch.bfloat16, # 양자화 없이 bfloat16 사용
|
54 |
+
device_map="auto",
|
55 |
+
token=token
|
56 |
+
)
|
57 |
+
|
58 |
+
SYS_PROMPT = """You are an assistant for answering legal questions.
|
59 |
+
... (이하 생략, 기존 SYS_PROMPT 그대로 유지) ...
|
60 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
# 법률 문서 검색 함수
|
63 |
def search_law(query, k=5):
|
64 |
+
load_law_data() # PDF 텍스트와 임베딩 Lazy Loading
|
65 |
+
query_embedding = ST.encode([query])
|
66 |
D, I = index.search(query_embedding, k)
|
67 |
return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
|
68 |
|
69 |
# 법률 상담 데이터 검색 함수
|
70 |
def search_qa(query, k=3):
|
71 |
+
scores, retrieved_examples = data.get_nearest_examples(
|
72 |
+
"question_embedding", ST.encode(query), k=k
|
|
|
73 |
)
|
74 |
return [retrieved_examples["answer"][i] for i in range(k)]
|
75 |
|
|
|
92 |
formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
|
93 |
formatted_prompt = formatted_prompt[:2000] # GPU 메모리 부족을 피하기 위해 프롬프트 제한
|
94 |
|
|
|
|
|
95 |
messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
|
96 |
|
97 |
# 모델에게 생성 지시
|
|
|
146 |
|
147 |
# Gradio 데모 실행
|
148 |
demo.launch(debug=True, server_port=7860)
|
|