import streamlit as st
import pickle
import pandas as pd
import torch
from PIL import Image
import numpy as np
from main import predict_caption, CLIPModel, get_text_embeddings

# Placeholder for custom page CSS (content omitted).
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

device = torch.device("cpu")

# Load the test captions, the trained CLIP weights, and the precomputed text embeddings.
testing_df = pd.read_csv("testing_df.csv")
model = CLIPModel().to(device)
model.load_state_dict(torch.load("weights.pt", map_location=torch.device("cpu")))
text_embeddings = torch.load("saved_text_embeddings.pt", map_location=device)


def show_predicted_caption(image):
    # Return the top-ranked caption for the given image.
    matches = predict_caption(
        image, model, text_embeddings, testing_df["caption"]
    )[0]
    return matches


st.title("Medical Image Captioning")
st.write("Upload an image to get a caption:")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)
    st.write("")

    if st.button("Generate Caption"):
        with st.spinner("Generating caption..."):
            image_np = np.array(image)
            caption = show_predicted_caption(image_np)
            st.success(f"Caption: {caption}")
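
# Usage note (assuming this script is saved as app.py and that main.py,
# testing_df.csv, weights.pt, and saved_text_embeddings.pt sit alongside it):
#
#   streamlit run app.py
#
# Streamlit then serves the app locally (by default at http://localhost:8501),
# where an image can be uploaded and a caption generated.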