import gradio as gr
from datasets import load_dataset
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF

# Get the Hugging Face token from the environment
token = os.environ.get("HF_TOKEN")

# Load the embedding model
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# PDF text and embeddings (lazy-loaded)
law_sentences = None
law_embeddings = None
index = None


def load_law_data():
    global law_sentences, law_embeddings, index
    if law_sentences is None or law_embeddings is None or index is None:
        # Extract text from the PDF
        pdf_path = "laws.pdf"  # Set this to the actual PDF path.
        doc = fitz.open(pdf_path)
        law_text = ""
        for page in doc:
            law_text += page.get_text()

        # Split the text into sentences and embed them
        law_sentences = law_text.split("\n")  # Adjust the splitting to match the PDF structure
        law_embeddings = ST.encode(law_sentences)

        # Build a FAISS index and add the embeddings
        index = faiss.IndexFlatL2(law_embeddings.shape[1])
        index.add(law_embeddings)


# Load the legal consultation dataset from Hugging Face
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the question column and store it in a new column
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")

# Load the Gemma model (bfloat16, no quantization)
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # use bfloat16 instead of quantization
    device_map="auto",
    token=token,
)

SYS_PROMPT = """You are an assistant for answering legal questions.
... (omitted; keep the original SYS_PROMPT as-is) ...
"""


# Legal document search
def search_law(query, k=5):
    load_law_data()  # lazy-load the PDF text and embeddings
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]


# Legal QA search
def search_qa(query, k=3):
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]


# Build the final prompt
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc[0]}\n"
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT


# Chatbot response function
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)
    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # truncate to avoid running out of GPU memory

    # Gemma's chat template does not accept a "system" role, so the system
    # prompt is prepended to the user turn instead.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]

    # Generate a streamed response
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Gradio interface setup
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme=gr.themes.Soft(),
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio demo
demo.launch(debug=True, server_port=7860)