File size: 2,428 Bytes
80dc74c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select

from app_lib.utils import SUPPORTED_MODELS


def _validate_class_name(class_name):
    if class_name is None:
        return (False, "Class name cannot be empty.")
    if class_name.strip() == "":
        return (False, "Class name cannot be empty.")
    return (True, None)


def _validate_concepts(concepts):
    if len(concepts) < 3:
        return (False, "You must provide at least 3 concepts")
    if len(concepts) > 10:
        return (False, "Maximum 10 concepts allowed")
    return (True, None)


def get_model_name():
    """Render a model picker and return the entry of ``SUPPORTED_MODELS`` the user chose."""
    chosen_model = st.selectbox(
        "Choose a model to test",
        help="Name of the vision-language model to test the predictions of.",
        options=SUPPORTED_MODELS,
    )
    return chosen_model


def get_image():
    """Let the user upload an image or pick a bundled example; return it opened with PIL.

    Both widgets are rendered in the sidebar. An uploaded file takes precedence:
    the example gallery is only rendered when nothing has been uploaded.
    """
    # NOTE(review): all four gallery entries point at the same file; this
    # looks like a placeholder for distinct example images — confirm.
    example_paths = ["assets/ace.jpg"] * 4

    with st.sidebar:
        upload = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if upload:
            source = upload
        else:
            source = image_select(label="or select one", images=example_paths)
    return Image.open(source)


def get_class_name():
    """Render the class-name text input and validate what the user typed.

    Returns:
        A ``(class_name, class_ready, class_error)`` tuple: the raw text from
        the input, whether it passed ``_validate_class_name``, and the
        validation error message (``None`` when valid).
    """
    raw_name = st.text_input(
        "Class to test",
        value="cat",
        help="Name of the class to build the zero-shot CLIP classifier with.",
    )
    is_valid, message = _validate_class_name(raw_name)
    return raw_name, is_valid, message


def get_concepts():
    """Render the concepts text area and return the parsed, validated concepts.

    The text is split on newlines. Each line is stripped of surrounding
    whitespace, blank lines are dropped, and duplicates are removed while
    keeping the order in which the user first wrote each concept.

    Returns:
        A ``(concepts, concepts_ready, concepts_error)`` tuple: the list of
        unique concepts, whether it passed ``_validate_concepts``, and the
        validation error message (``None`` when valid).
    """
    raw_text = st.text_area(
        "Concepts to test (max 10)",
        help="List of concepts to test the predictions of the model with. Write one concept per line.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
    )
    stripped = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys de-duplicates while preserving first-seen order. A plain
    # set() would return the concepts in arbitrary order (str hashing is
    # randomized per process), so the order shown/used downstream would not
    # follow the user's input and could change between server restarts.
    concepts = list(dict.fromkeys(line for line in stripped if line))

    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error


def get_cardinality(concepts, concepts_ready):
    """Render a slider for the size of the conditioning set and return its value.

    The slider runs from 1 up to ``len(concepts) - 1``, with the upper bound
    clamped to at least 2 (presumably so the slider always has
    ``min_value < max_value`` even with too few concepts). The slider is
    disabled while ``concepts_ready`` is falsy.
    """
    upper_bound = len(concepts) - 1
    if upper_bound < 2:
        upper_bound = 2
    return st.slider(
        "Size of conditioning set",
        min_value=1,
        max_value=upper_bound,
        value=1,
        step=1,
        help="The number of concepts to condition model predictions on.",
        disabled=not concepts_ready,
    )