zamborg commited on
Commit
1cabc83
1 Parent(s): 250dc27

testing caption_prompting

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -5,7 +5,7 @@ import time
5
  import json
6
  sys.path.append("./virtex/")
7
 
8
- def gen_show_caption(sub_prompt=None):
9
  with st.spinner("Generating Caption"):
10
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt)
11
  st.header("Predicted Caption:\n\n")
@@ -35,6 +35,21 @@ sample_image = st.sidebar.selectbox(
35
  sample_images
36
  )
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  if st.sidebar.button("Random Sample Image"):
39
  random_image = get_rand_img(sample_images)
40
  sample_image = None
@@ -46,11 +61,6 @@ with st.sidebar.form("file-uploader-form", clear_on_submit=True):
46
  if uploaded_file is not None and submitted:
47
  uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
48
 
49
- st.sidebar.title("Select a Subreddit")
50
- sub = st.sidebar.selectbox(
51
- "Select None for a Predicted Subreddit",
52
- valid_subs
53
- )
54
 
55
 
56
  if uploaded_image is None and submitted:
@@ -68,7 +78,7 @@ else:
68
  show = st.image(image)
69
  show.image(image, "Your Image")
70
 
71
- gen_show_caption(sub)
72
 
73
  image.close()
74
 
 
5
  import json
6
  sys.path.append("./virtex/")
7
 
8
+ def gen_show_caption(sub_prompt=None, cap_prompt = None):
9
  with st.spinner("Generating Caption"):
10
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt)
11
  st.header("Predicted Caption:\n\n")
 
35
  sample_images
36
  )
37
 
38
+ st.sidebar.title("Select a Subreddit")
39
+ sub = st.sidebar.selectbox(
40
+ "Select None for a Predicted Subreddit",
41
+ valid_subs
42
+ )
43
+
44
+ st.sidebar.title("Write a Custom Prompt")
45
+ cap_prompt = st.text_input(
46
+ "Leave this blank for an unbiased caption",
47
+ value="This is a photo of"
48
+ )
49
+ st.write("=====================")
50
+ st.write(cap_prompt)
51
+ st.write("=====================")
52
+
53
  if st.sidebar.button("Random Sample Image"):
54
  random_image = get_rand_img(sample_images)
55
  sample_image = None
 
61
  if uploaded_file is not None and submitted:
62
  uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
63
 
 
 
 
 
 
64
 
65
 
66
  if uploaded_image is None and submitted:
 
78
  show = st.image(image)
79
  show.image(image, "Your Image")
80
 
81
+ gen_show_caption(sub, cap_prompt)
82
 
83
  image.close()
84