# Hugging Face Spaces page header (status at capture time: Runtime error)
import faiss  # Ensure faiss is available
import gradio as gr
import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer
# Load the tokenizer, retriever, and model once at import time so that every
# request handled by the UI reuses the same objects.
MODEL_NAME = "facebook/rag-sequence-nq"

tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): use_dummy_dataset=True presumably loads a small placeholder
# index rather than the full retrieval corpus, so generated answers may not be
# grounded in real passages -- confirm this is intended outside of a demo.
retriever = RagRetriever.from_pretrained(MODEL_NAME, use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)
# Define prediction function
def predict(input_text: str) -> str:
    """Generate a RAG response for one piece of user-supplied text.

    Args:
        input_text: The question or note typed into the UI textbox.

    Returns:
        The first decoded generation with special tokens removed, or an
        empty string when ``input_text`` is empty or only whitespace.
    """
    # Nothing to retrieve or generate for blank input; skip the model call.
    if not input_text or not input_text.strip():
        return ""

    # Tokenize input. The textbox accepts multi-line notes, so truncate to the
    # tokenizer's maximum length instead of failing on over-long input.
    input_ids = tokenizer(
        [input_text], return_tensors="pt", truncation=True
    ).input_ids

    # Generate response and decode the single sequence in the batch.
    outputs = model.generate(input_ids)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Add example texts. Each inner list is one example row holding the value for
# the interface's single text input.
examples = [
    ["Patient admitted with a history of heart failure and requires detailed follow-up on cardiovascular treatment."],
    ["What are the complications of diabetes mellitus that need to be monitored in this patient?"],
    ["Describe the appropriate treatment for acute respiratory distress syndrome in a critical care setting."],
    ["Explain the signs and symptoms that indicate a neurological emergency in a stroke patient."],
    ["What are the best practices for managing an infectious disease outbreak in a hospital setting?"],
]
# Create Gradio interface: one multi-line text input wired to predict(), with
# a plain-text output and the clickable example prompts.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Enter your medical question or clinical notes here...",
    ),
    outputs="text",
    examples=examples,
    title="MIMIC-IV RAG Implementation",
    # Adjacent literals are concatenated into a single description string.
    description=(
        "Use RAG (Retrieval-Augmented Generation) to generate responses or "
        "provide additional information based on clinical notes and medical "
        "questions. This model helps in generating relevant information based "
        "on existing medical literature."
    ),
)

# Start the app when the script is executed.
iface.launch()