import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF

# Read the Hugging Face token from the environment
token = os.environ.get("HF_TOKEN")

# Load the sentence-embedding model
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Lazily loaded PDF text, embeddings, and FAISS index
law_sentences = None
law_embeddings = None
index = None

def load_law_data():
    global law_sentences, law_embeddings, index
    if law_sentences is None or law_embeddings is None or index is None:
        # PDF์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ
        pdf_path = "laws.pdf"  # ์—ฌ๊ธฐ์— ์‹ค์ œ PDF ๊ฒฝ๋กœ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”.
        doc = fitz.open(pdf_path)
        law_text = ""
        for page in doc:
            law_text += page.get_text()
        
        # Split the text into sentence-sized chunks and embed them
        law_sentences = law_text.split('\n')  # Line-based split; adjust to the PDF's structure
        law_embeddings = ST.encode(law_sentences)

        # Build a FAISS index and add the embeddings
        index = faiss.IndexFlatL2(law_embeddings.shape[1])
        index.add(law_embeddings)

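# A minimal sanity check for the lazy loader (illustrative only; assumes
# laws.pdf sits next to this script, and the probe query is hypothetical):
#
#   load_law_data()
#   assert index.ntotal == len(law_sentences)
#   D, I = index.search(ST.encode(["contract termination"]), 5)
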
# Load the legal-consultation QA dataset from Hugging Face
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the question column and store it in a new column
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")

# Gemma model setup (no quantization)
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # bfloat16 instead of quantization
    device_map="auto",
    token=token
)

SYS_PROMPT = """You are an assistant for answering legal questions.
... (rest omitted; the original SYS_PROMPT continues unchanged) ...
"""

# Legal document search function
def search_law(query, k=5):
    load_law_data()  # Lazily load the PDF text and embeddings
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
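
# Example usage of search_law (hypothetical query; returns (sentence, L2
# distance) pairs, where a smaller distance means a closer match):
#
#   for sentence, dist in search_law("notice period for lease termination", k=3):
#       print(f"{dist:.3f}  {sentence}")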

# ๋ฒ•๋ฅ  ์ƒ๋‹ด ๋ฐ์ดํ„ฐ ๊ฒ€์ƒ‰ ํ•จ์ˆ˜
def search_qa(query, k=3):
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]
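
# Example usage of search_qa (hypothetical query; returns k answer strings
# drawn from the LawQA-Ko training split):
#
#   answers = search_qa("How can I contest an unfair dismissal?", k=3)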

# Assemble the final prompt
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc[0]}\n"
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT
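
# The assembled prompt has roughly this shape (illustrative):
#
#   Question: <user question>
#
#   Legal Context:
#   <retrieved law sentence 1>
#   ...
#
#   Legal QA:
#   <retrieved answer 1>
#   ...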

# Chatbot response function
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # Truncate the prompt to avoid exhausting GPU memory

    # Gemma's chat template rejects a "system" role, so fold the system
    # prompt into the user turn instead of passing it as a separate message.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]

    # ๋ชจ๋ธ์—๊ฒŒ ์ƒ์„ฑ ์ง€์‹œ
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Gradio interface setup
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio demo
demo.launch(debug=True, server_port=7860)
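
# Note: on Hugging Face Spaces the platform manages the port itself;
# server_port=7860 mainly matters for local runs (e.g. python app.py).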