# Spaces: Runtime error
import streamlit as st | |
import pickle | |
import pandas as pd | |
import torch | |
from PIL import Image | |
import numpy as np | |
from main import predict_caption, CLIPModel , get_text_embeddings | |
# Page-level stylesheet: makes the <body> background transparent.
# unsafe_allow_html is passed so Streamlit renders the raw <style> tag instead
# of escaping it as text.
_PAGE_CSS = """
    <style>
    body {
        background-color: transparent;
    }
    </style>
    """
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
device = torch.device("cpu")


@st.cache_resource
def _load_resources():
    """Load the caption table, CLIP weights and saved text embeddings once.

    Streamlit re-executes this whole script on every widget interaction
    (file upload, button click). Caching the result keeps the DataFrame,
    model and embeddings in memory across reruns instead of re-reading
    ``testing_df.csv``, ``weights.pt`` and ``saved_text_embeddings.pt`` from
    disk each time.

    Returns:
        tuple: ``(testing_df, model, text_embeddings)``, all on ``device``
        where applicable.
    """
    caption_df = pd.read_csv("testing_df.csv")

    clip_model = CLIPModel().to(device)
    clip_model.load_state_dict(torch.load("weights.pt", map_location=device))
    # Inference only: switch any dropout / batch-norm layers CLIPModel may
    # contain to evaluation behaviour.
    clip_model.eval()

    embeddings = torch.load("saved_text_embeddings.pt", map_location=device)
    return caption_df, clip_model, embeddings


# Module-level names kept so the rest of the script (and any importer) can use them.
testing_df, model, text_embeddings = _load_resources()
def show_predicted_caption(image):
    """Return the first caption ``predict_caption`` produces for ``image``.

    The image is matched against ``testing_df["caption"]`` using the
    module-level ``model`` and ``text_embeddings``.

    Args:
        image: The image to caption. In this script it is the NumPy array
            built from the uploaded PIL image.

    Returns:
        Element ``[0]`` of whatever ``predict_caption`` returns.
        NOTE(review): presumably the best-ranked match; confirm the ordering
        in ``main.predict_caption``.
    """
    # Pure inference: don't build an autograd graph, so no intermediate
    # tensors are retained for a backward pass that never happens.
    with torch.no_grad():
        matches = predict_caption(
            image, model, text_embeddings, testing_df["caption"]
        )
    return matches[0]
# --- UI: upload an image, then caption it on button click ---------------------
st.title("Medical Image Captioning")
st.write("Upload an image to get a caption:")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open only parses the header; load() decodes the pixel data so a
        # corrupt or truncated upload is reported here instead of surfacing as
        # a traceback in the middle of caption generation.
        image.load()
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    # NOTE(review): newer Streamlit releases deprecate use_column_width;
    # confirm the pinned Streamlit version before switching parameters.
    st.image(image, caption="Uploaded Image", use_column_width=True)
    st.write("")
    if st.button("Generate Caption"):
        with st.spinner("Generating caption..."):
            # Grayscale ("L"), palette ("P") and RGBA uploads would otherwise
            # become 2-D or 4-channel arrays; normalise the model input to
            # 3-channel RGB (the displayed image above is left untouched).
            # NOTE(review): assumes predict_caption expects (H, W, 3) RGB
            # input; confirm against main.predict_caption.
            image_np = np.array(image.convert("RGB"))
            caption = show_predicted_caption(image_np)
            st.success(f"Caption: {caption}")