import os
from threading import Thread

import spaces
import gradio as gr
import torch
import faiss
import fitz  # PyMuPDF, used to extract text from the laws PDF
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from sentence_transformers import SentenceTransformer

# Hugging Face token (a Space secret) for downloading the gated Gemma weights.
token = os.environ.get("HF_TOKEN")

# Lazily initialized globals, shared across requests.
ST = None              # sentence-embedding model
model = None           # chat model
tokenizer = None       # chat tokenizer
law_sentences = None   # statute text, one entry per PDF line
law_embeddings = None  # embeddings of law_sentences
index = None           # FAISS index over law_embeddings
data = None            # LawQA-Ko dataset with a FAISS index over its questions
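
# System prompt prepended to every request; this exact wording is an
# assumption, so tune it as needed.
SYS_PROMPT = (
    "You are a legal assistant. Answer the question using the provided legal "
    "context and QA examples, and say so when the context is insufficient."
)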

def load_embedding_model():
    """Load and cache the sentence-embedding model used for all retrieval."""
    global ST
    if ST is None:
        ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
    return ST


@spaces.GPU(duration=120)
def load_model():
    """Load and cache the Gemma chat model and its tokenizer."""
    global model, tokenizer
    if model is None or tokenizer is None:
        model_id = "google/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=token,
        )
    return model, tokenizer

def load_law_data():
    """Read laws.pdf, embed it line by line, and build a FAISS L2 index."""
    global law_sentences, law_embeddings, index
    if law_sentences is None or law_embeddings is None or index is None:
        pdf_path = "laws.pdf"
        doc = fitz.open(pdf_path)
        law_text = ""
        for page in doc:
            law_text += page.get_text()

        # One entry per PDF line: crude chunking, but it sets the retrieval granularity.
        law_sentences = law_text.split("\n")
        law_embeddings = load_embedding_model().encode(law_sentences)

        index = faiss.IndexFlatL2(law_embeddings.shape[1])
        index.add(law_embeddings)


def load_dataset_data():
    """Load the LawQA-Ko dataset and index its question embeddings with FAISS."""
    global data
    if data is None:
        dataset = load_dataset("jihye-moon/LawQA-Ko")
        data = dataset["train"]
        data = data.map(
            lambda x: {"question_embedding": load_embedding_model().encode(x["question"])},
            batched=True,
        )
        data.add_faiss_index(column="question_embedding")
    return data

def search_law(query, k=5):
    """Return the k statute lines nearest to the query, with their L2 distances."""
    load_law_data()
    query_embedding = load_embedding_model().encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]


def search_qa(query, k=3):
    """Return the answers of the k QA pairs whose questions best match the query."""
    dataset_data = load_dataset_data()
    scores, retrieved_examples = dataset_data.get_nearest_examples(
        "question_embedding", load_embedding_model().encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]


def format_prompt(prompt, law_docs, qa_docs):
    """Assemble the question, retrieved statute lines, and QA answers into one prompt."""
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc}\n"  # each doc is already a plain string, not a (text, score) tuple
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT
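
# Hypothetical smoke test for the retrieval path alone (not part of the app
# flow; assumes laws.pdf sits next to this script):
#
#     print(search_law("data privacy", k=2))
#     print(search_qa("data privacy", k=1))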

@spaces.GPU(duration=120)
def talk(prompt, history):
    """Retrieve supporting documents for the prompt, then stream a generated answer."""
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    # Keep only the statute text; search_law also returns distances.
    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # crude character cap on context size

    model, tokenizer = load_model()

    # Gemma's chat template rejects a "system" role, so fold the system prompt
    # into the user turn instead.
    messages = [{"role": "user", "content": f"{SYS_PROMPT}\n\n{formatted_prompt}"}]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run generation on a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""
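
# Note: likeable and bubble_full_width below are Gradio 4.x Chatbot options;
# more recent Gradio releases deprecate them, so drop them when upgrading.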

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",  # built-in theme names are lowercase
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch(debug=True, server_port=7860)