File size: 2,206 Bytes
d1c1a86
 
 
 
 
 
8fa75cc
d1c1a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms

from open_clip import create_model, get_tokenizer
from templates import openai_imagenet_template

# Architecture identifier; the script passes it to both create_model() and
# get_tokenizer() so the model and tokenizer stay in agreement.
model_str = "ViT-B-16"

# Checkpoint location handed to create_model() as the `pretrained` argument.
# NOTE(review): this is an absolute path on a specific filesystem — the demo
# will only start on a machine where this file exists.
pretrained = "/fs/ess/PAS2136/foundation_model/model/10m/2023_09_22-21_14_04-model_ViT-B-16-lr_0.0001-b_4096-j_8-p_amp/checkpoints/epoch_99.pt"

# Image pipeline: convert the incoming image to a tensor, then standardise
# each of the three channels with fixed mean / std values.
# NOTE(review): these constants look like the usual CLIP normalisation
# statistics — confirm they match what the checkpoint was trained with.
preprocess_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),  # per-channel mean
        (0.26862954, 0.26130258, 0.27577711),  # per-channel std
    ),
])


@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one text embedding per class by averaging over prompt templates.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of callables; each is called with a class name
            and must return the prompt text for that class.

    Returns:
        A tensor made by stacking the per-class vectors along ``dim=1``, so
        each class occupies one column.

    Relies on the module-level ``model`` and ``tokenizer`` having been set
    before the first call.
    """

    def embed_class(name):
        # Render every template for this class and tokenize them as a batch.
        prompts = tokenizer([render(name) for render in templates])
        # Unit-normalise each prompt embedding before averaging so every
        # template contributes equally to the class centroid.
        # assumes encode_text returns (n_templates, dim) — TODO confirm
        unit_embeddings = F.normalize(model.encode_text(prompts), dim=-1)
        centroid = unit_embeddings.mean(dim=0)
        # The mean of unit vectors is generally not unit length; rescale it.
        return centroid / centroid.norm()

    return torch.stack([embed_class(name) for name in classnames], dim=1)


@torch.no_grad()
def predict(img, cls_str: str) -> dict[str, float]:
    """Score an image against a user-supplied, newline-separated class list.

    Args:
        img: The image as delivered by the Gradio image input; it is passed
            unchanged to ``preprocess_img``.
        cls_str: Candidate class names, one per line. Blank lines and
            surrounding whitespace are ignored.

    Returns:
        A mapping from class name to softmax probability. An empty mapping
        is returned when ``cls_str`` contains no usable class names.

    Relies on the module-level ``model`` (and, through ``get_txt_features``,
    ``tokenizer``) having been set before the first call.
    """
    stripped = (line.strip() for line in (cls_str or "").split("\n"))
    classes = [name for name in stripped if name]
    if not classes:
        # Nothing to classify against; bail out before touching the model
        # (stacking zero text embeddings would raise).
        return {}

    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img)

    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Scale the cosine similarities by the model's learned temperature before
    # the softmax; raw similarities lie in [-1, 1] and would otherwise give
    # an almost uniform distribution.
    logits = model.logit_scale.exp() * img_features @ txt_features
    # Drop only the batch axis. A bare squeeze() would also remove the class
    # axis when a single class is given, turning tolist() into a float.
    logits = logits.squeeze(0)
    probs = F.softmax(logits, dim=0).tolist()
    return dict(zip(classes, probs))


if __name__ == "__main__":
    # Script entry point: load the model and tokenizer into module globals
    # (predict() and get_txt_features() read them), then serve the demo UI.
    print("Starting.")
    model = create_model(model_str, pretrained, output_dict=True)
    # The demo only runs inference, so switch to eval mode explicitly rather
    # than relying on whatever mode create_model() leaves the model in.
    model.eval()
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(model_str)

    demo = gr.Interface(
        fn=predict,
        inputs=[
            # NOTE(review): `shape` presumably makes Gradio resize/crop the
            # upload to 224x224 before predict() sees it — confirm against
            # the installed Gradio version.
            gr.Image(shape=(224, 224)),
            # One candidate class per line; predict() splits on newlines.
            gr.Textbox(
                placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
    )

    demo.launch()