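"""Legal RAG chatbot.

Retrieves relevant statute passages (PyMuPDF + FAISS) and similar past
Q&A pairs (jihye-moon/LawQA-Ko), then streams an answer generated by
google/gemma-2-2b-it through a Gradio ChatInterface.
"""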
import spaces
import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz # PyMuPDF
# Get the Hugging Face token from the environment
token = os.environ.get("HF_TOKEN")

# Lazily initialized globals
ST = None
model = None
tokenizer = None
law_sentences = None
law_embeddings = None
index = None
data = None
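
# System prompt for the chat model. The original code references SYS_PROMPT in
# talk() without defining it, so this wording is an assumed placeholder.
SYS_PROMPT = (
    "You are a helpful legal assistant. Answer the user's question based on "
    "the provided legal context and past Q&A, and answer in Korean."
)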

# Lazy-load the embedding model
def load_embedding_model():
    global ST
    if ST is None:
        ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
    return ST

# Lazy-load the Gemma model and tokenizer
# (@spaces.GPU requests a ZeroGPU slot for up to 120 s per call on Spaces)
@spaces.GPU(duration=120)
def load_model():
    global model, tokenizer
    if model is None or tokenizer is None:
        model_id = "google/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=token
        )
    return model, tokenizer

# Lazy-load the statute PDF: extract its text and build sentence embeddings
def load_law_data():
    global law_sentences, law_embeddings, index
    if law_sentences is None or law_embeddings is None or index is None:
        pdf_path = "laws.pdf"  # Set this to the actual PDF path.
        doc = fitz.open(pdf_path)
        law_text = ""
        for page in doc:
            law_text += page.get_text()
        law_sentences = law_text.split('\n')  # Adjust the split to match the PDF's structure
        law_embeddings = load_embedding_model().encode(law_sentences)
        # Build an exact L2-distance FAISS index and add the embeddings
        index = faiss.IndexFlatL2(law_embeddings.shape[1])
        index.add(law_embeddings)

# Lazy-load the legal-consultation QA dataset from Hugging Face
def load_dataset_data():
    global data
    if data is None:
        dataset = load_dataset("jihye-moon/LawQA-Ko")
        data = dataset["train"]
        # Embed every question once, then index the embeddings with FAISS
        data = data.map(lambda x: {"question_embedding": load_embedding_model().encode(x["question"])}, batched=True)
        data.add_faiss_index(column="question_embedding")
    return data

# Search the statute text for passages similar to the query
def search_law(query, k=5):
    load_law_data()  # Lazy-load the PDF text and embeddings
    query_embedding = load_embedding_model().encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
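
# Each hit is a (sentence, L2 distance) pair. Illustrative call
# (sentences and distances are made-up values):
#   search_law("data privacy") -> [("Article 1 ...", 0.83), ("Article 2 ...", 0.91), ...]
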
# Search the legal QA dataset for answers to similar questions
def search_qa(query, k=3):
    dataset_data = load_dataset_data()
    scores, retrieved_examples = dataset_data.get_nearest_examples(
        "question_embedding", load_embedding_model().encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]

# Build the final prompt: question, then retrieved statute text, then QA answers
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc}\n"
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT

# Chatbot response function
@spaces.GPU(duration=120)
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)
    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # Truncate the prompt to avoid running out of GPU memory
    model, tokenizer = load_model()
    # Gemma's chat template rejects a "system" role, so fold the system
    # prompt into the user turn instead
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]
    # Tokenize the conversation and move it to the model's device
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Run generation on a background thread and stream partial text as it arrives
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Gradio interface setup
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio demo
demo.launch(debug=True, server_port=7860)