"""Gradio app for chatting with the contents of a Japanese PDF document.

Workflow exposed by the UI:
  1. Upload a PDF; its text is extracted, chunked and indexed in FAISS.
  2. Pick an LLM and build a conversational retrieval chain over the index.
  3. Ask questions; answers come from the chain.
"""

import os
import time
import warnings

import gradio as gr
import pdfplumber
import torch
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEndpoint
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BertJapaneseTokenizer,
    BertModel,
    BitsAndBytesConfig,
    pipeline,
)

# Silence the Pydantic protected-namespace warning.
warnings.filterwarnings(
    "ignore",
    message=r"Field \"model_name\" in HuggingFaceInferenceAPIEmbeddings has conflict with protected namespace"
)

load_dotenv()

# Full model identifiers, and the short names shown in the UI radio group.
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "rinna/llama-3-youko-8b",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


def extract_text_from_pdf(file_path):
    """Return the text of every page of the PDF at *file_path*, space-joined.

    Pages for which no text can be extracted contribute an empty string
    instead of breaking the join.
    """
    with pdfplumber.open(file_path) as pdf:
        pages = [page.extract_text() or "" for page in pdf.pages]
    return " ".join(pages)


# Tokenizer and model initialisation.
# NOTE(review): neither object is referenced anywhere else in this file;
# confirm whether they are still needed before removing them.
tokenizer_bert = BertJapaneseTokenizer.from_pretrained(
    'cl-tohoku/bert-base-japanese',
    clean_up_tokenization_spaces=True
)
model_bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese')


def split_text_simple(text, chunk_size=1024):
    """Split *text* into consecutive chunks of at most *chunk_size* characters."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


def create_db(splits):
    """Embed the text chunks in *splits* and return a FAISS vector store."""
    embeddings = HuggingFaceEmbeddings(
        model_name='sonoisa/sentence-bert-base-ja-mean-tokens'
    )
    return FAISS.from_texts(splits, embeddings)


def _build_local_llm(llm_model, temperature, max_tokens):
    """Load *llm_model* locally and wrap it as a LangChain LLM."""
    if torch.cuda.is_available():
        # With a GPU available, load the model 4-bit quantized.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )
        model = AutoModelForCausalLM.from_pretrained(
            llm_model,
            device_map="auto",
            quantization_config=quantization_config
        )
    else:
        # On CPU, load the model without quantization.
        model = AutoModelForCausalLM.from_pretrained(
            llm_model,
            device_map={"": "cpu"},
            torch_dtype=torch.float32
        )
    tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_tokens,
        temperature=temperature
    )
    return HuggingFacePipeline(pipeline=pipe)


def _build_llm(llm_model, temperature, max_tokens, top_k):
    """Return a LangChain LLM for *llm_model*, local or endpoint-backed."""
    name = llm_model.lower()
    if "rinna" in name:
        # Local model.
        return _build_local_llm(llm_model, temperature, max_tokens)
    if "meta-llama" in name or "mistralai" in name:
        # Hosted endpoint model.
        return HuggingFaceEndpoint(
            endpoint_url=f"{llm_model}",
            huggingfacehub_api_token=os.getenv("HF_TOKEN"),
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k
        )
    # Other model families are not supported (add here as needed).
    raise Exception(f"Unsupported model: {llm_model}")


def initialize_llmchain(
    llm_model, temperature, max_tokens, top_k, vector_db, retries=5, delay=5
):
    """Build a ConversationalRetrievalChain over *vector_db* using *llm_model*.

    Args:
        llm_model: Model identifier; the name decides local vs. endpoint use.
        temperature: Sampling temperature passed to the LLM.
        max_tokens: Maximum number of newly generated tokens.
        top_k: Top-k value (only passed to endpoint models).
        vector_db: Vector store whose retriever feeds the chain.
        retries: How many attempts to make when authentication fails.
        delay: Seconds to sleep between attempts.

    Returns:
        The initialised chain.

    Raises:
        Exception: On any non-authentication failure (message prefixed with
            "Error initializing QA chain"), or once all retries are used up.
    """
    for _ in range(retries):
        try:
            llm = _build_llm(llm_model, temperature, max_tokens, top_k)
            memory = ConversationBufferMemory(
                memory_key="chat_history",
                output_key='answer',
                return_messages=True
            )
            return ConversationalRetrievalChain.from_llm(
                llm,
                retriever=vector_db.as_retriever(),
                memory=memory,
                return_source_documents=True,
                verbose=False
            )
        except Exception as e:
            # Only Hugging Face Hub authentication failures are retried;
            # everything else is reported immediately.
            if "Could not authenticate with huggingface_hub" not in str(e):
                raise Exception(
                    f"Error initializing QA chain: {str(e)}"
                ) from e
            time.sleep(delay)
    raise Exception(f"Failed to initialize after {retries} attempts")


def process_pdf(file):
    """Turn an uploaded PDF into a vector database.

    Returns:
        A ``(vector_db, status_message)`` tuple; ``vector_db`` is ``None``
        whenever processing did not succeed.
    """
    if file is None:
        return None, "Please upload a PDF file."
    try:
        # NOTE(review): the upload is assumed to be either a path string or
        # an object exposing the path as `.name` - confirm against the
        # installed Gradio version.
        file_path = getattr(file, "name", file)
        text = extract_text_from_pdf(file_path)
        if not text.strip():
            return None, "Error processing PDF: no extractable text found."
        splits = split_text_simple(text)
        vdb = create_db(splits)
        return vdb, "PDF processed and vector database created."
    except Exception as e:
        return None, f"Error processing PDF: {str(e)}"


def initialize_qa_chain(
    llm_index, temperature, max_tokens, top_k, vector_db
):
    """Create the QA chain for the LLM selected in the UI.

    Returns:
        A ``(chain, status_message)`` tuple; ``chain`` is ``None`` whenever
        initialisation did not succeed.
    """
    if vector_db is None:
        return None, "Please process a PDF first."
    if llm_index is None:
        # No radio option chosen yet; indexing list_llm would fail.
        return None, "Please select an LLM model."
    try:
        llm_name = list_llm[llm_index]
        chain = initialize_llmchain(
            llm_name, temperature, max_tokens, top_k, vector_db
        )
        return chain, "QA Chatbot initialized with selected LLM."
    except Exception as e:
        return None, f"Error initializing QA chain: {str(e)}"


def update_chat(msg, history, chain):
    """Append the user's question and the chain's answer to the chat history.

    Each history entry is a ``(user_message, bot_message)`` pair, which is
    the pair layout the Chatbot component renders.

    Returns:
        A new history list; the one passed in is not modified.
    """
    # The history is None after the chat has been cleared.
    history = list(history or [])
    if chain is None:
        return history + [(msg, "Please initialize the QA Chatbot first.")]
    try:
        response = chain({"question": msg, "chat_history": history})
        return history + [(msg, response['answer'])]
    except Exception as e:
        return history + [(msg, f"Error: {str(e)}")]


def demo():
    """Build and return the Gradio Blocks application."""
    with gr.Blocks() as app:
        vector_db = gr.State(value=None)
        qa_chain = gr.State(value=None)

        with gr.Tab("Step 1 - Upload and Process"):
            with gr.Row():
                document = gr.File(
                    label="Upload your Japanese PDF document",
                    file_types=[".pdf"]
                )
            with gr.Row():
                process_btn = gr.Button("Process PDF")
                process_output = gr.Textbox(label="Processing Output")

        with gr.Tab("Step 2 - Initialize QA Chatbot"):
            with gr.Row():
                llm_btn = gr.Radio(
                    list_llm_simple, label="Select LLM Model", type="index"
                )
                llm_temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.1,
                    label="Temperature", value=0.7
                )
                max_tokens = gr.Slider(
                    minimum=128, maximum=2048, step=128,
                    label="Max Tokens", value=1024
                )
                top_k = gr.Slider(
                    minimum=1, maximum=10, step=1, label="Top K", value=3
                )
            with gr.Row():
                init_qa_btn = gr.Button("Initialize QA Chatbot")
                init_output = gr.Textbox(label="Initialization Output")

        with gr.Tab("Step 3 - Chat with your Document"):
            chatbot = gr.Chatbot()
            message = gr.Textbox(label="Ask a question")
            with gr.Row():
                send_btn = gr.Button("Send")
                clear_chat_btn = gr.Button("Clear Chat")
                reset_all_btn = gr.Button("Reset All")

        process_btn.click(
            process_pdf,
            inputs=[document],
            outputs=[vector_db, process_output]
        )
        init_qa_btn.click(
            initialize_qa_chain,
            inputs=[llm_btn, llm_temperature, max_tokens, top_k, vector_db],
            outputs=[qa_chain, init_output]
        )
        send_btn.click(
            update_chat,
            inputs=[message, chatbot, qa_chain],
            outputs=[chatbot]
        )
        # "Clear Chat": clears only the chat history.
        clear_chat_btn.click(
            lambda: None,
            outputs=[chatbot]
        )
        # "Reset All": clears the chat history, the PDF data and the chatbot.
        reset_all_btn.click(
            lambda: (None, None, None),
            outputs=[chatbot, vector_db, qa_chain]
        )
    return app


if __name__ == "__main__":
    demo().launch()