# RadiXGPT_ / app.py
# Uploaded by Singularity666 in commit f9f1b17 ("Upload 7 files"); original size 1.3 kB.
import streamlit as st
import pickle
import pandas as pd
import torch
from PIL import Image
import numpy as np
from main import predict_caption, CLIPModel , get_text_embeddings
# Inject CSS that overrides the page body's background colour to transparent.
st.markdown(
    """
<style>
body {
background-color: transparent;
}
</style>
""",
    unsafe_allow_html=True,
)

# All inference runs on the CPU; every checkpoint below is mapped onto this device.
device = torch.device("cpu")

# Table whose "caption" column is handed to predict_caption (presumably as the
# pool of candidate captions — TODO confirm against main.predict_caption).
testing_df = pd.read_csv("testing_df.csv")

model = CLIPModel().to(device)
model.load_state_dict(torch.load("weights.pt", map_location=device))
# The model is only used for inference here: disable training-time behaviour
# (dropout, batch-norm statistic updates). Idempotent, so harmless if
# predict_caption also switches to eval mode.
model.eval()

# Precomputed text embeddings; presumably aligned row-for-row with
# testing_df["caption"] (TODO confirm how saved_text_embeddings.pt was built).
# NOTE(review): Streamlit re-executes this script on every interaction, so these
# files are reloaded each time; consider st.cache_resource if the installed
# Streamlit version provides it.
text_embeddings = torch.load("saved_text_embeddings.pt", map_location=device)
def show_predicted_caption(image):
    """Return the first result of ``predict_caption`` for *image*.

    Args:
        image: The uploaded image as a NumPy array, forwarded unchanged to
            ``predict_caption`` together with the module-level ``model``,
            ``text_embeddings`` and ``testing_df["caption"]``.

    Returns:
        Element ``[0]`` of whatever ``predict_caption`` returns (presumably the
        best-matching caption — TODO confirm against main.predict_caption).
    """
    results = predict_caption(image, model, text_embeddings, testing_df["caption"])
    best = results[0]
    return best
st.title("Medical Image Captioning")
st.write("Upload an image to get a caption:")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot decode; report it instead of crashing the page with a traceback.
        st.error("Could not read the uploaded file as an image.")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)
        st.write("")
        if st.button("Generate Caption"):
            with st.spinner("Generating caption..."):
                # Normalise to 3-channel RGB: grayscale uploads would otherwise
                # yield a 2-D array and RGBA PNGs a 4-channel one. Assumes
                # predict_caption expects an HxWx3 array — TODO confirm in main.py.
                image_np = np.array(image.convert("RGB"))
                caption = show_predicted_caption(image_np)
            st.success(f"Caption: {caption}")