Spaces:
Runtime error
Runtime error
File size: 2,428 Bytes
80dc74c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select
from app_lib.utils import SUPPORTED_MODELS
def _validate_class_name(class_name):
if class_name is None:
return (False, "Class name cannot be empty.")
if class_name.strip() == "":
return (False, "Class name cannot be empty.")
return (True, None)
def _validate_concepts(concepts):
if len(concepts) < 3:
return (False, "You must provide at least 3 concepts")
if len(concepts) > 10:
return (False, "Maximum 10 concepts allowed")
return (True, None)
def get_model_name():
    """Render a model selector and return the chosen entry of ``SUPPORTED_MODELS``."""
    chosen_model = st.selectbox(
        "Choose a model to test",
        options=SUPPORTED_MODELS,
        help="Name of the vision-language model to test the predictions of.",
    )
    return chosen_model
def get_image():
    """Let the user upload an image or pick an example in the sidebar.

    An uploaded file takes precedence. Only when nothing is uploaded is the
    example picker rendered, and its selection is used instead.

    Returns:
        The chosen image opened with ``PIL.Image.open``.
    """
    # NOTE(review): all four examples point at the same asset; this looks
    # like a placeholder list — confirm the intended example images.
    example_paths = ["assets/ace.jpg"] * 4
    with st.sidebar:
        uploaded = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if uploaded:
            source = uploaded
        else:
            source = image_select(label="or select one", images=example_paths)
    return Image.open(source)
def get_class_name():
    """Prompt for the class used to build the zero-shot classifier and validate it.

    Returns:
        A ``(class_name, class_ready, class_error)`` tuple, where the last two
        items come from ``_validate_class_name``.
    """
    typed_name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
    )
    is_valid, error_message = _validate_class_name(typed_name)
    return typed_name, is_valid, error_message
def get_concepts():
    """Prompt for the list of concepts (one per line) and validate it.

    Surrounding whitespace is stripped from each line, and blank lines are
    dropped. Duplicates are removed, keeping the first occurrence, so the
    concepts stay in the order the user typed them.

    Returns:
        A ``(concepts, concepts_ready, concepts_error)`` tuple, where
        ``concepts`` is the cleaned list of strings and the other two items
        come from ``_validate_concepts``.
    """
    raw_text = st.text_area(
        "Concepts to test (max 10)",
        help="List of concepts to test the predictions of the model with. Write one concept per line.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
    )
    stripped = (line.strip() for line in raw_text.splitlines())
    # dict.fromkeys de-duplicates while preserving insertion order. A set would
    # return the concepts in string-hash order, which differs from the user's
    # order and changes between interpreter runs (hash randomization).
    concepts = list(dict.fromkeys(concept for concept in stripped if concept))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_cardinality(concepts, concepts_ready):
    """Render a slider for the size of the conditioning set and return its value.

    The slider goes from 1 up to ``len(concepts) - 1``, with the upper bound
    never below 2. It is disabled while the concepts are not valid.
    """
    # Clamp the upper bound to at least 2 so it always exceeds min_value=1,
    # presumably to keep the slider range non-degenerate for short lists.
    slider_upper = max(2, len(concepts) - 1)
    return st.slider(
        "Size of conditioning set",
        help="The number of concepts to condition model predictions on.",
        min_value=1,
        max_value=slider_upper,
        value=1,
        step=1,
        disabled=not concepts_ready,
    )
|