"""Gradio demo: RAG (Retrieval-Augmented Generation) over clinical-note style queries.

Loads facebook/rag-sequence-nq with a dummy retrieval dataset and exposes a
single-textbox web UI. NOTE(review): `use_dummy_dataset=True` means retrieval
is against a tiny placeholder index, not MIMIC-IV — swap in a real index for
production use.
"""

import gradio as gr
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss  # noqa: F401 — imported to fail fast if faiss is missing (required by RagRetriever)

# Load the tokenizer, retriever, and model once at import time.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", use_dummy_dataset=True
)
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)


def predict(input_text):
    """Generate a RAG response for a single medical question / clinical note.

    Args:
        input_text: Free-text question or clinical note from the UI textbox.

    Returns:
        The model's generated answer as a plain string.
    """
    # Tokenize a one-element batch; the retriever/generator expect batched input.
    input_ids = tokenizer([input_text], return_tensors="pt").input_ids
    # Inference only — disable autograd to save memory and time.
    with torch.no_grad():
        outputs = model.generate(input_ids)
    # Decode the single generated sequence back to text.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


# Example prompts shown beneath the input box.
examples = [
    ["Patient admitted with a history of heart failure and requires detailed follow-up on cardiovascular treatment."],
    ["What are the complications of diabetes mellitus that need to be monitored in this patient?"],
    ["Describe the appropriate treatment for acute respiratory distress syndrome in a critical care setting."],
    ["Explain the signs and symptoms that indicate a neurological emergency in a stroke patient."],
    ["What are the best practices for managing an infectious disease outbreak in a hospital setting?"],
]

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=10, placeholder="Enter your medical question or clinical notes here..."),
    outputs="text",
    examples=examples,
    title="MIMIC-IV RAG Implementation",
    description="Use RAG (Retrieval-Augmented Generation) to generate responses or provide additional information based on clinical notes and medical questions. This model helps in generating relevant information based on existing medical literature.",
)

if __name__ == "__main__":
    iface.launch()