import gradio as gr
import os
import time
import pdfplumber
from dotenv import load_dotenv
import torch
from transformers import (
    BertJapaneseTokenizer,
    BertModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig
)
from langchain_community.vectorstores import FAISS  # corrected import path
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFacePipeline  # corrected import path
from langchain_community.embeddings import HuggingFaceEmbeddings  # corrected import path
from langchain_huggingface import HuggingFaceEndpoint

# Suppress the Pydantic protected-namespace warning
import warnings
warnings.filterwarnings(
    "ignore",
    message=r"Field \"model_name\" in HuggingFaceInferenceAPIEmbeddings has conflict with protected namespace"
)

load_dotenv()
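# NOTE: the HuggingFaceEndpoint branch below reads HF_TOKEN from the environment,
# so a valid Hugging Face API token is assumed to be defined in .env.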
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "rinna/llama-3-youko-8b",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
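# Short display names for the radio buttons; the radio returns an index
# (type="index"), which is mapped back to the full repo id via list_llm[llm_index].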
# Extract text from a Japanese PDF
def extract_text_from_pdf(file_path):
    with pdfplumber.open(file_path) as pdf:
        # extract_text() returns None for pages with no extractable text
        pages = [page.extract_text() or "" for page in pdf.pages]
    return " ".join(pages)
# モデルとトークナイザの初期化 | |
tokenizer_bert = BertJapaneseTokenizer.from_pretrained( | |
'cl-tohoku/bert-base-japanese', | |
clean_up_tokenization_spaces=True | |
) | |
model_bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese') | |
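# NOTE: tokenizer_bert and model_bert are loaded here but not referenced elsewhere
# in this file; the vector store below uses sentence-bert embeddings instead.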
def split_text_simple(text, chunk_size=1024):
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
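# Fixed-size character chunks with no overlap, e.g. a 2,500-character text
# yields chunks of 1024, 1024, and 452 characters.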
def create_db(splits):
    embeddings = HuggingFaceEmbeddings(
        model_name='sonoisa/sentence-bert-base-ja-mean-tokens'
    )
    vectordb = FAISS.from_texts(splits, embeddings)
    return vectordb
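# Illustrative sanity check (not part of the app flow):
#   db = create_db(["これはサンプル文です。", "別の文です。"])
#   db.similarity_search("サンプル", k=1)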
def initialize_llmchain(
    llm_model,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    retries=5,
    delay=5
):
    attempt = 0
    while attempt < retries:
        try:
            # Local model (loaded in-process)
            if "rinna" in llm_model.lower():
                # Auto-detect the device
                if torch.cuda.is_available():
                    device_map = "auto"
                    torch_dtype = torch.float16
                    # Use 4-bit quantization when a GPU is available
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4"
                    )
                    model = AutoModelForCausalLM.from_pretrained(
                        llm_model,
                        device_map=device_map,
                        quantization_config=quantization_config
                    )
                else:
                    device_map = {"": "cpu"}
                    torch_dtype = torch.float32
                    # On CPU, load the model without quantization
                    model = AutoModelForCausalLM.from_pretrained(
                        llm_model,
                        device_map=device_map,
                        torch_dtype=torch_dtype
                    )
                tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False)
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    max_new_tokens=max_tokens,
                    do_sample=True,  # temperature is ignored by transformers unless sampling is enabled
                    temperature=temperature
                )
                llm = HuggingFacePipeline(pipeline=pipe)
            # Hosted model served via the Hugging Face Inference API
            elif "meta-llama" in llm_model.lower() or "mistralai" in llm_model.lower():
                # Pass generation parameters directly to the endpoint
                llm = HuggingFaceEndpoint(
                    endpoint_url=f"https://api-inference.huggingface.co/models/{llm_model}",
                    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
                    temperature=temperature,
                    max_new_tokens=max_tokens,
                    top_k=top_k
                )
            else:
                # Other models (add handling here as needed)
                raise Exception(f"Unsupported model: {llm_model}")

            # Common setup for both branches
            memory = ConversationBufferMemory(
                memory_key="chat_history",
                output_key='answer',
                return_messages=True
            )
            retriever = vector_db.as_retriever()
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm,
                retriever=retriever,
                memory=memory,
                return_source_documents=True,
                verbose=False
            )
            return qa_chain
        except Exception as e:
            if "Could not authenticate with huggingface_hub" in str(e):
                time.sleep(delay)
                attempt += 1
            else:
                raise Exception(f"Error initializing QA chain: {str(e)}")
    raise Exception(f"Failed to initialize after {retries} attempts")
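# Illustrative call (assumes `vdb` is a FAISS index built by create_db):
#   qa = initialize_llmchain("rinna/llama-3-youko-8b", 0.7, 1024, 3, vdb)
#   qa({"question": "この文書を要約してください", "chat_history": []})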
def process_pdf(file):
    try:
        if file is None:
            return None, "Please upload a PDF file."
        text = extract_text_from_pdf(file.name)
        splits = split_text_simple(text)
        vdb = create_db(splits)
        return vdb, "PDF processed and vector database created."
    except Exception as e:
        return None, f"Error processing PDF: {str(e)}"

def initialize_qa_chain(
    llm_index,
    temperature,
    max_tokens,
    top_k,
    vector_db
):
    try:
        if vector_db is None:
            return None, "Please process a PDF first."
        llm_name = list_llm[llm_index]
        chain = initialize_llmchain(
            llm_name,
            temperature,
            max_tokens,
            top_k,
            vector_db
        )
        return chain, "QA Chatbot initialized with selected LLM."
    except Exception as e:
        return None, f"Error initializing QA chain: {str(e)}"
def update_chat(msg, history, chain):
    try:
        if chain is None:
            # gr.Chatbot expects a list of (user_message, assistant_message) pairs
            return history + [(msg, "Please initialize the QA Chatbot first.")]
        response = chain({"question": msg, "chat_history": history})
        return history + [(msg, response['answer'])]
    except Exception as e:
        return history + [(msg, f"Error: {str(e)}")]
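# Gradio UI: three tabs (upload/process, chatbot initialization, chat) sharing
# vector_db and qa_chain across callbacks via gr.State.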
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State(value=None)
        qa_chain = gr.State(value=None)

        with gr.Tab("Step 1 - Upload and Process"):
            with gr.Row():
                document = gr.File(label="Upload your Japanese PDF document", file_types=["pdf"])
            with gr.Row():
                process_btn = gr.Button("Process PDF")
                process_output = gr.Textbox(label="Processing Output")

        with gr.Tab("Step 2 - Initialize QA Chatbot"):
            with gr.Row():
                llm_btn = gr.Radio(list_llm_simple, label="Select LLM Model", type="index")
                llm_temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Temperature", value=0.7)
                max_tokens = gr.Slider(minimum=128, maximum=2048, step=128, label="Max Tokens", value=1024)
                top_k = gr.Slider(minimum=1, maximum=10, step=1, label="Top K", value=3)
            with gr.Row():
                init_qa_btn = gr.Button("Initialize QA Chatbot")
                init_output = gr.Textbox(label="Initialization Output")

        with gr.Tab("Step 3 - Chat with your Document"):
            chatbot = gr.Chatbot()
            message = gr.Textbox(label="Ask a question")
            with gr.Row():
                send_btn = gr.Button("Send")
                clear_chat_btn = gr.Button("Clear Chat")
                reset_all_btn = gr.Button("Reset All")

        process_btn.click(
            process_pdf,
            inputs=[document],
            outputs=[vector_db, process_output]
        )
        init_qa_btn.click(
            initialize_qa_chain,
            inputs=[llm_btn, llm_temperature, max_tokens, top_k, vector_db],
            outputs=[qa_chain, init_output]
        )
        send_btn.click(
            update_chat,
            inputs=[message, chatbot, qa_chain],
            outputs=[chatbot]
        )
        # Clear Chat button: clears only the chat history
        clear_chat_btn.click(
            lambda: None,
            outputs=[chatbot]
        )
        # Reset All button: clears the chat history, PDF data, and chatbot state
        reset_all_btn.click(
            lambda: (None, None, None),
            outputs=[chatbot, vector_db, qa_chain]
        )
    return demo

if __name__ == "__main__":
    demo().launch()