EnverLee commited on
Commit
afec61d
·
verified ·
1 Parent(s): 4149567

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -61
app.py CHANGED
@@ -11,72 +11,65 @@ import fitz # PyMuPDF
11
  # 환경 변수에서 Hugging Face 토큰 가져오기
12
  token = os.environ.get("HF_TOKEN")
13
 
14
- # 임베딩 모델 Lazy Loading
15
- ST = None
16
- def load_embedding_model():
17
- global ST
18
- if ST is None:
19
- ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
20
- return ST
21
-
22
- # LLaMA 모델 및 토크나이저 Lazy Loading
23
- model = None
24
- tokenizer = None
25
- def load_model():
26
- global model, tokenizer
27
- if model is None or tokenizer is None:
28
- model_id = "google/gemma-2-2b-it"
29
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
30
- model = AutoModelForCausalLM.from_pretrained(
31
- model_id,
32
- torch_dtype=torch.bfloat16,
33
- device_map="auto",
34
- token=token
35
- )
36
- return model, tokenizer
37
-
38
- # PDF에서 텍스트 추출
39
- def extract_text_from_pdf(pdf_path):
40
- doc = fitz.open(pdf_path)
41
- text = ""
42
- for page in doc:
43
- text += page.get_text()
44
- return text
45
-
46
- # 법률 문서 PDF 경로 지정 및 텍스트 추출
47
- pdf_path = "laws.pdf" # 여기에 실제 PDF 경로를 입력하세요.
48
- law_text = extract_text_from_pdf(pdf_path)
49
-
50
- # 법률 문서 텍스트를 문장 단위로 나누고 임베딩
51
- law_sentences = law_text.split('\n') # PDF 구조에 따라 분할을 조정
52
- law_embeddings = load_embedding_model().encode(law_sentences)
53
-
54
- # FAISS 인덱스 생성 및 임베딩 추가
55
- index = faiss.IndexFlatL2(law_embeddings.shape[1])
56
- index.add(law_embeddings)
57
-
58
- # Hugging Face에서 법률 상담 데이터셋 로드 (Lazy Loading)
59
- data = None
60
- def load_dataset_data():
61
- global data
62
- if data is None:
63
- dataset = load_dataset("jihye-moon/LawQA-Ko")
64
- data = dataset["train"]
65
- data = data.map(lambda x: {"question_embedding": load_embedding_model().encode(x["question"])}, batched=True)
66
- data.add_faiss_index(column="question_embedding")
67
- return data
68
 
69
  # 법률 문서 검색 함수
70
  def search_law(query, k=5):
71
- query_embedding = load_embedding_model().encode([query])
 
72
  D, I = index.search(query_embedding, k)
73
  return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
74
 
75
  # 법률 상담 데이터 검색 함수
76
  def search_qa(query, k=3):
77
- dataset_data = load_dataset_data()
78
- scores, retrieved_examples = dataset_data.get_nearest_examples(
79
- "question_embedding", load_embedding_model().encode(query), k=k
80
  )
81
  return [retrieved_examples["answer"][i] for i in range(k)]
82
 
@@ -99,8 +92,6 @@ def talk(prompt, history):
99
  formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
100
  formatted_prompt = formatted_prompt[:2000] # GPU 메모리 부족을 피하기 위해 프롬프트 제한
101
 
102
- model, tokenizer = load_model()
103
-
104
  messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
105
 
106
  # 모델에게 생성 지시
@@ -155,4 +146,3 @@ demo = gr.ChatInterface(
155
 
156
  # Gradio 데모 실행
157
  demo.launch(debug=True, server_port=7860)
158
-
 
11
  # 환경 변수에서 Hugging Face 토큰 가져오기
12
  token = os.environ.get("HF_TOKEN")
13
 
14
+ # 임베딩 모델 로드
15
+ ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
16
+
17
+ # Lazy Loading PDF 텍스트와 임베딩
18
+ law_sentences = None
19
+ law_embeddings = None
20
+ index = None
21
+
22
+ def load_law_data():
23
+ global law_sentences, law_embeddings, index
24
+ if law_sentences is None or law_embeddings is None or index is None:
25
+ # PDF에서 텍스트 추출
26
+ pdf_path = "laws.pdf" # 여기에 실제 PDF 경로를 입력하세요.
27
+ doc = fitz.open(pdf_path)
28
+ law_text = ""
29
+ for page in doc:
30
+ law_text += page.get_text()
31
+
32
+ # 텍스트를 문장 단위로 나누고 임베딩
33
+ law_sentences = law_text.split('\n') # PDF 구조에 따라 분할을 조정
34
+ law_embeddings = ST.encode(law_sentences)
35
+
36
+ # FAISS 인덱스 생성 및 임베딩 추가
37
+ index = faiss.IndexFlatL2(law_embeddings.shape[1])
38
+ index.add(law_embeddings)
39
+
40
+ # Hugging Face에서 법률 상담 데이터셋 로드
41
+ dataset = load_dataset("jihye-moon/LawQA-Ko")
42
+ data = dataset["train"]
43
+
44
+ # 질문 컬럼을 임베딩하여 새로운 컬럼에 추가
45
+ data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
46
+ data.add_faiss_index(column="question_embedding")
47
+
48
+ # LLaMA 모델 설정 (양자화 없이)
49
+ model_id = "google/gemma-2-2b-it"
50
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
51
+ model = AutoModelForCausalLM.from_pretrained(
52
+ model_id,
53
+ torch_dtype=torch.bfloat16, # 양자화 없이 bfloat16 사용
54
+ device_map="auto",
55
+ token=token
56
+ )
57
+
58
+ SYS_PROMPT = """You are an assistant for answering legal questions.
59
+ ... (이하 생략, 기존 SYS_PROMPT 그대로 유지) ...
60
+ """
 
 
 
 
 
 
 
61
 
62
  # 법률 문서 검색 함수
63
  def search_law(query, k=5):
64
+ load_law_data() # PDF 텍스트와 임베딩 Lazy Loading
65
+ query_embedding = ST.encode([query])
66
  D, I = index.search(query_embedding, k)
67
  return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
68
 
69
  # 법률 상담 데이터 검색 함수
70
  def search_qa(query, k=3):
71
+ scores, retrieved_examples = data.get_nearest_examples(
72
+ "question_embedding", ST.encode(query), k=k
 
73
  )
74
  return [retrieved_examples["answer"][i] for i in range(k)]
75
 
 
92
  formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
93
  formatted_prompt = formatted_prompt[:2000] # GPU 메모리 부족을 피하기 위해 프롬프트 제한
94
 
 
 
95
  messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
96
 
97
  # 모델에게 생성 지시
 
146
 
147
  # Gradio 데모 실행
148
  demo.launch(debug=True, server_port=7860)