import spaces
import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF


# Get the Hugging Face access token from the environment
token = os.environ.get("HF_TOKEN")


# Load the embedding model (Korean sentence embeddings)
ST = SentenceTransformer("jhgan/ko-sroberta-multitask")

# Extract text from a PDF
def extract_text_from_pdf(pdf_path):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    doc.close()  # release the file handle once all pages are read
    return text

# Path to the legal-document PDF and text extraction
pdf_path = "laws.pdf"  # Replace with the actual PDF path.
law_text = extract_text_from_pdf(pdf_path)

# Split the legal document text into sentences and embed them
law_sentences = law_text.split('\n')  # Adjust splitting based on your PDF structure
law_embeddings = ST.encode(law_sentences)
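# Note: splitting on '\n' often leaves blank or very short fragments from PDF
# extraction, which then get embedded and indexed for nothing. A minimal
# optional cleanup sketch (assumes a simple line-per-sentence layout):
#   law_sentences = [s.strip() for s in law_text.split('\n') if s.strip()]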

# Create a FAISS index and add the embeddings
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)
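# IndexFlatL2 ranks neighbors by Euclidean distance. If cosine similarity is
# preferred for sentence embeddings, one alternative sketch (not what this app
# uses) is to L2-normalize the vectors and switch to an inner-product index:
#   faiss.normalize_L2(law_embeddings)
#   index = faiss.IndexFlatIP(law_embeddings.shape[1])
#   index.add(law_embeddings)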

# Load the legal-consultation QA dataset from Hugging Face
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the question column and add it as a new column
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")
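# Both indexes are rebuilt from scratch on every startup. As an optional
# optimization, datasets can persist the FAISS index between runs (the file
# name here is illustrative):
#   data.save_faiss_index("question_embedding", "question_index.faiss")
#   data.load_faiss_index("question_embedding", "question_index.faiss")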

# Load the Gemma-2 chat model with 4-bit quantization
model_id = "google/gemma-2-27b-it"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token
)

SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.
You must answer in Korean."""

# Search the legal-document FAISS index
@spaces.GPU
def search_law(query, k=5):
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
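# Illustrative usage (values hypothetical):
#   search_law("๊ฐœ์ธ์ •๋ณด ๋ณดํ˜ธ") -> [("์ œ1์กฐ ...", 0.83), ...]  # (sentence, L2 distance) pairs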

# Search the legal QA dataset
@spaces.GPU
def search_qa(query, k=3):
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]

# Build the final prompt
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc[0]}\n"  # Assuming doc[0] contains the relevant text
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT
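# Shape of the assembled prompt (contents are whatever retrieval returned):
#   Question: <user question>
#
#   Legal Context:
#   <retrieved law sentence> x k
#
#   Legal QA:
#   <retrieved answer> x k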

# Chatbot response function (streams generated tokens)
@spaces.GPU
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # crude character-level cap to avoid GPU out-of-memory

    # Gemma's chat template has no system role, so fold the system prompt into the user turn
    messages = [{"role": "user", "content": SYS_PROMPT + "\n" + formatted_prompt}]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
        
# Gradio interface setup
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=["What are the regulations on data privacy?"],
    title=TITLE,
    description=DESCRIPTION,
)
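# Because talk is a generator, gr.ChatInterface streams each partial output
# yielded from TextIteratorStreamer straight to the chat window.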

# Launch the Gradio demo
demo.launch(debug=True)