Spaces:
Runtime error
Runtime error
import streamlit as st | |
import io | |
import sys | |
import time | |
import json | |
sys.path.append("./virtex/") | |
from model import * | |
def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Predict a subreddit and caption for the current image and render them.

    Reads the module-level ``virtexModel`` and ``image_dict``; the result is
    written to the Streamlit page rather than returned.

    Args:
        sub_prompt: Subreddit to condition on, or ``None`` when the user did
            not pick one.
        cap_prompt: Caption prefix to condition on; ``""`` means no prefix.
    """
    with st.spinner("Generating Caption"):
        # Compare by value. The previous `is not ""` tested object identity
        # against a literal, which is implementation-dependent and emits a
        # SyntaxWarning on Python 3.8+.
        if sub_prompt is None and cap_prompt != "":
            st.write("Without a specified subreddit we default to /r/pics")
        subreddit, caption = virtexModel.predict(
            image_dict, sub_prompt=sub_prompt, prompt=cap_prompt
        )
        # NOTE(review): the model output is interpolated into HTML that is
        # rendered with unsafe_allow_html=True — confirm it cannot carry
        # unwanted markup.
        st.markdown(
            f"""
            <style>
            red{{
            color:#c62828
            }}
            mono{{
            font-family: "Inconsolata";
            }}
            </style>
            ### <red> r/{subreddit} </red> {caption}
            """,
            unsafe_allow_html=True,
        )
# Main-page title.
st.title("Image Captioning Demo from RedCaps")

# Introductory blurb shown in the sidebar.
st.sidebar.markdown(
    """
    ### Image Captioning Model from VirTex trained on RedCaps
    Use this page to caption your own images or try out some of our samples.
    You can also generate captions as if they are from specific subreddits,
    as if they start with a particular prompt, or even both.
    Share your results on twitter with #redcaps or with a friend.
    """
)

# Build the model, the image loader, the sample list and the list of
# selectable subreddits while a spinner is displayed.
with st.spinner("Loading Model"):
    (
        virtexModel,
        imageLoader,
        sample_images,
        valid_subs,
    ) = create_objects()
# staggered = st.sidebar.checkbox("Iteratively Generate Captions")
# if staggered:
#     pass
# else:
# --- Sample image -----------------------------------------------------------
# select_idx stays None unless the user asks for a random sample on this run.
select_idx = None
st.sidebar.title("Select a sample image")
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)

# Fall back to the first sample when no random index was drawn.
if select_idx is None:
    sample_image = sample_images[0]
else:
    sample_image = sample_images[select_idx]

# --- Uploaded image ---------------------------------------------------------
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if submitted and uploaded_file is not None:
        # Decode the upload from an in-memory copy of its bytes.
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        # Original note: reset the index to help rewrite the cached image.
        select_idx = None
# class OnChange():
#     def __init__(self, idx):
#         self.idx = idx
#     def __call__(self):
#         st.write(f"the idx is: {self.idx}")
#         st.write(f"the sample_image is {sample_image}")
# sample_image = st.sidebar.selectbox(
#     "",
#     sample_images,
#     index = 0 if select_idx is None else select_idx,
#     on_change=OnChange(0 if select_idx is None else select_idx)
# )
# Subreddit picker; the choices come from valid_subs.
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    label="Type below to condition on a subreddit. Select None for a predicted subreddit",
    options=valid_subs,
)

# Optional text the generated caption should start with.
st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input(
    label="Write the start of your caption below",
    value="",
)

# The button's return value is unused; it is bound to `_` and never read.
_ = st.sidebar.button("Regenerate Caption")
# advanced = st.sidebar.checkbox("Advanced Options")
# if advanced:
#     nuc_size = st.sidebar.slider("")
if submitted and uploaded_image is None:
    # The form was submitted without any file attached.
    st.write("Please select a file to upload")
else:
    image_file = sample_image

    # Choose the image to caption, in priority order:
    #   1. a freshly uploaded image,
    #   2. the image stored in the session when no random sample was drawn
    #      on this run,
    #   3. the selected sample image.
    if uploaded_image is not None:
        image = uploaded_image
    elif select_idx is None and "image" in st.session_state:
        image = st.session_state["image"]
    else:
        image = Image.open(image_file)

    # Store the RGB-converted image in the session for later reruns.
    image = image.convert("RGB")
    st.session_state["image"] = image

    # image_dict is read as a module-level name by gen_show_caption.
    image_dict = imageLoader.transform(image)
    show_image = imageLoader.show_resize(image)

    # Draw the image, then redraw the same element with a caption label.
    show = st.image(show_image)
    show.image(show_image, "Your Image")

    gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
# from model import *
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()
# for s in sample_images:
#     subreddit, caption = v.predict(il.load(s))
#     print("=====================")
#     print(subreddit)
#     print(caption)