import os import re from hashlib import blake2b from tempfile import NamedTemporaryFile import dotenv from grobid_quantities.quantities import QuantitiesAPI from langchain.memory import ConversationBufferWindowMemory # from langchain_community.callbacks import PromptLayerCallbackHandler from langchain_community.chat_models import ChatOpenAI from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint from langchain_huggingface import HuggingFaceEmbeddings from langchain_openai import OpenAIEmbeddings from streamlit_pdf_viewer import pdf_viewer from document_qa.ner_client_generic import NERClientGeneric dotenv.load_dotenv(override=True) import streamlit as st from document_qa.document_qa_engine import DocumentQAEngine, DataStorage from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations OPENAI_MODELS = ['gpt-3.5-turbo', "gpt-4", "gpt-4-1106-preview"] OPENAI_EMBEDDINGS = [ 'text-embedding-ada-002', 'text-embedding-3-large', 'openai-text-embedding-3-small' ] OPEN_MODELS = { 'Mistral-Nemo-Instruct-2407': 'mistralai/Mistral-Nemo-Instruct-2407', 'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.3', 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct" } DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)' OPEN_EMBEDDINGS = { DEFAULT_OPEN_EMBEDDING_NAME: 'all-MiniLM-L6-v2', 'SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral', 'SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R', 'NV-Embed': 'nvidia/NV-Embed-v1', 'e5-mistral-7b-instruct': 'intfloat/e5-mistral-7b-instruct', 'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5' } if 'rqa' not in st.session_state: st.session_state['rqa'] = {} if 'model' not in st.session_state: st.session_state['model'] = None if 'api_keys' not in st.session_state: st.session_state['api_keys'] = {} if 'doc_id' not in st.session_state: st.session_state['doc_id'] = None if 'loaded_embeddings' not in st.session_state: st.session_state['loaded_embeddings'] = None if 'hash' not in st.session_state: st.session_state['hash'] = None if 'git_rev' not in st.session_state: st.session_state['git_rev'] = "unknown" if os.path.exists("revision.txt"): with open("revision.txt", 'r') as fr: from_file = fr.read() st.session_state['git_rev'] = from_file if len(from_file) > 0 else "unknown" if "messages" not in st.session_state: st.session_state.messages = [] if 'ner_processing' not in st.session_state: st.session_state['ner_processing'] = False if 'uploaded' not in st.session_state: st.session_state['uploaded'] = False if 'memory' not in st.session_state: st.session_state['memory'] = None if 'binary' not in st.session_state: st.session_state['binary'] = None if 'annotations' not in st.session_state: st.session_state['annotations'] = None if 'should_show_annotations' not in st.session_state: st.session_state['should_show_annotations'] = True if 'pdf' not in st.session_state: st.session_state['pdf'] = None if 'embeddings' not in st.session_state: st.session_state['embeddings'] = None if 'scroll_to_first_annotation' not in st.session_state: st.session_state['scroll_to_first_annotation'] = False st.set_page_config( page_title="Scientific Document Insights Q/A", page_icon="📝", initial_sidebar_state="expanded", layout="wide", menu_items={ 'Get Help': 'https://github.com/lfoppiano/document-qa', 'Report a bug': "https://github.com/lfoppiano/document-qa/issues", 'About': "Upload a scientific article in PDF, ask questions, get insights." } ) st.markdown( """ """, unsafe_allow_html=True ) def new_file(): st.session_state['loaded_embeddings'] = None st.session_state['doc_id'] = None st.session_state['uploaded'] = True if st.session_state['memory']: st.session_state['memory'].clear() def clear_memory(): st.session_state['memory'].clear() # @st.cache_resource def init_qa(model, embeddings_name=None, api_key=None): ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])]) if model in OPENAI_MODELS: if embeddings_name is None: embeddings_name = 'text-embedding-ada-002' st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if api_key: chat = ChatOpenAI(model_name=model, temperature=0, openai_api_key=api_key, frequency_penalty=0.1) if embeddings_name not in OPENAI_EMBEDDINGS: st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.") st.stop() return embeddings = OpenAIEmbeddings(model=embeddings_name, openai_api_key=api_key) else: chat = ChatOpenAI(model_name=model, temperature=0, frequency_penalty=0.1) embeddings = OpenAIEmbeddings(model=embeddings_name) elif model in OPEN_MODELS: if embeddings_name is None: embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME chat = HuggingFaceEndpoint( repo_id=OPEN_MODELS[model], temperature=0.01, max_new_tokens=4092, model_kwargs={"max_length": 8192}, # callbacks=[PromptLayerCallbackHandler(pl_tags=[model, "document-qa"])] ) embeddings = HuggingFaceEmbeddings( model_name=OPEN_EMBEDDINGS[embeddings_name]) # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None else: st.error("The model was not loaded properly. Try reloading. ") st.stop() return storage = DataStorage(embeddings) return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory']) @st.cache_resource def init_ner(): quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True) materials_client = NERClientGeneric(ping=True) config_materials = { 'grobid': { "server": os.environ['GROBID_MATERIALS_URL'], 'sleep_time': 5, 'timeout': 60, 'url_mapping': { 'processText_disable_linking': "/service/process/text?disableLinking=True", # 'processText_disable_linking': "/service/process/text" } } } materials_client.set_config(config_materials) gqa = GrobidAggregationProcessor(grobid_quantities_client=quantities_client, grobid_superconductors_client=materials_client) return gqa gqa = init_ner() def get_file_hash(fname): hash_md5 = blake2b() with open(fname, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest() def play_old_messages(container): if st.session_state['messages']: for message in st.session_state['messages']: if message['role'] == 'user': container.chat_message("user").markdown(message['content']) elif message['role'] == 'assistant': if mode == "LLM": container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True) else: container.chat_message("assistant").write(message['content']) # is_api_key_provided = st.session_state['api_key'] with st.sidebar: st.title("📝 Document Q/A") st.markdown("Upload a scientific article in PDF, ask questions, get insights.") st.markdown( ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ") st.divider() st.session_state['model'] = model = st.selectbox( "Model:", options=OPENAI_MODELS + list(OPEN_MODELS.keys()), index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index( os.environ["DEFAULT_MODEL"]) if "DEFAULT_MODEL" in os.environ and os.environ["DEFAULT_MODEL"] else 0, placeholder="Select model", help="Select the LLM model:", disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] ) embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS st.session_state['embeddings'] = embedding_name = st.selectbox( "Embeddings:", options=embedding_choices, index=0, placeholder="Select embedding", help="Select the Embedding function:", disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] ) if (model in OPEN_MODELS) and model not in st.session_state['api_keys']: if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ: api_key = st.text_input('Huggingface API Key', type="password") st.markdown("Get it [here](https://huggingface.co/docs/hub/security-tokens)") else: api_key = os.environ['HUGGINGFACEHUB_API_TOKEN'] if api_key: # st.session_state['api_key'] = is_api_key_provided = True if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']: with st.spinner("Preparing environment"): st.session_state['api_keys'][model] = api_key # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ: # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key st.session_state['rqa'][model] = init_qa(model, embedding_name) elif model in OPENAI_MODELS and model not in st.session_state['api_keys']: if 'OPENAI_API_KEY' not in os.environ: api_key = st.text_input('OpenAI API Key', type="password") st.markdown("Get it [here](https://platform.openai.com/account/api-keys)") else: api_key = os.environ['OPENAI_API_KEY'] if api_key: if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']: with st.spinner("Preparing environment"): st.session_state['api_keys'][model] = api_key if 'OPENAI_API_KEY' not in os.environ: st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'], api_key) else: st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings']) # else: # is_api_key_provided = st.session_state['api_key'] # st.button( # 'Reset chat memory.', # key="reset-memory-button", # on_click=clear_memory, # help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.", # disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None) left_column, right_column = st.columns([5, 4]) right_column = right_column.container(border=True) left_column = left_column.container(border=True) with right_column: uploaded_file = st.file_uploader( "Upload a scientific article", type=("pdf"), on_change=new_file, disabled=st.session_state['model'] is not None and st.session_state['model'] not in st.session_state['api_keys'], help="The full-text is extracted using Grobid." ) placeholder = st.empty() messages = st.container(height=300) question = st.chat_input( "Ask something about the article", # placeholder="Can you give me a short summary?", disabled=not uploaded_file ) query_modes = { "llm": "LLM Q/A", "embeddings": "Embeddings", "question_coefficient": "Question coefficient" } with st.sidebar: st.header("Settings") mode = st.radio( "Query mode", ("llm", "embeddings", "question_coefficient"), disabled=not uploaded_file, index=0, horizontal=True, format_func=lambda x: query_modes[x], help="LLM will respond the question, Embedding will show the " "relevant paragraphs to the question in the paper. " "Question coefficient attempt to estimate how effective the question will be answered." ) st.session_state['scroll_to_first_annotation'] = st.checkbox( "Scroll to context", help='The PDF viewer will automatically scroll to the first relevant passage in the document.' ) st.session_state['ner_processing'] = st.checkbox( "Identify materials and properties.", help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.' ) # Add a checkbox for showing annotations # st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True) # st.session_state['should_show_annotations'] = st.checkbox("Show annotations", value=True) chunk_size = st.slider("Text chunks size", -1, 2000, value=-1, help="Size of chunks in which split the document. -1: use paragraphs, > 0 paragraphs are aggregated.", disabled=uploaded_file is not None) if chunk_size == -1: context_size = st.slider("Context size (paragraphs)", 3, 20, value=10, help="Number of paragraphs to consider when answering a question", disabled=not uploaded_file) else: context_size = st.slider("Context size (chunks)", 3, 10, value=4, help="Number of chunks to consider when answering a question", disabled=not uploaded_file) st.divider() st.header("Documentation") st.markdown("https://github.com/lfoppiano/document-qa") st.markdown( """Upload a scientific article as PDF document. Once the spinner stops, you can proceed to ask your questions.""") if st.session_state['git_rev'] != "unknown": st.markdown("**Revision number**: [" + st.session_state[ 'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")") if uploaded_file and not st.session_state.loaded_embeddings: if model not in st.session_state['api_keys']: st.error("Before uploading a document, you must enter the API key. ") st.stop() with left_column: with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'): binary = uploaded_file.getvalue() tmp_file = NamedTemporaryFile() tmp_file.write(bytearray(binary)) st.session_state['binary'] = binary st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name, chunk_size=chunk_size, perc_overlap=0.1) st.session_state['loaded_embeddings'] = True st.session_state.messages = [] def rgb_to_hex(rgb): return "#{:02x}{:02x}{:02x}".format(*rgb) def generate_color_gradient(num_elements): # Define warm and cold colors in RGB format warm_color = (255, 165, 0) # Orange cold_color = (0, 0, 255) # Blue # Generate a linear gradient of colors color_gradient = [ rgb_to_hex(tuple(int(warm * (1 - i / num_elements) + cold * (i / num_elements)) for warm, cold in zip(warm_color, cold_color))) for i in range(num_elements) ] return color_gradient with right_column: if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id: st.session_state.messages.append({"role": "user", "mode": mode, "content": question}) for message in st.session_state.messages: # with messages.chat_message(message["role"]): if message['mode'] == "llm": messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) elif message['mode'] == "embeddings": messages.chat_message(message["role"]).write(message["content"]) elif message['mode'] == "question_coefficient": messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) if model not in st.session_state['rqa']: st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `") st.stop() text_response = None if mode == "embeddings": with placeholder: with st.spinner("Fetching the relevant context..."): text_response, coordinates = st.session_state['rqa'][model].query_storage( question, st.session_state.doc_id, context_size=context_size ) elif mode == "llm": with placeholder: with st.spinner("Generating LLM response..."): _, text_response, coordinates = st.session_state['rqa'][model].query_document( question, st.session_state.doc_id, context_size=context_size ) elif mode == "question_coefficient": with st.spinner("Estimate question/context relevancy..."): text_response, coordinates = st.session_state['rqa'][model].analyse_query( question, st.session_state.doc_id, context_size=context_size ) annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc] for coord_doc in coordinates] gradients = generate_color_gradient(len(annotations)) for i, color in enumerate(gradients): for annotation in annotations[i]: annotation['color'] = color st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in annotation_doc] if not text_response: st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.") if mode == "llm": if st.session_state['ner_processing']: with st.spinner("Processing NER on LLM response..."): entities = gqa.process_single_text(text_response) decorated_text = decorate_text_with_annotations(text_response.strip(), entities) decorated_text = decorated_text.replace('class="label material"', 'style="color:green"') decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text) text_response = decorated_text messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True) else: messages.chat_message("assistant").write(text_response) st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response}) elif st.session_state.loaded_embeddings and st.session_state.doc_id: play_old_messages(messages) with left_column: if st.session_state['binary']: with st.container(height=600): pdf_viewer( input=st.session_state['binary'], annotation_outline_size=2, annotations=st.session_state['annotations'], render_text=True, scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state['scroll_to_first_annotation']) else None )