import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from open_clip import create_model, get_tokenizer

from templates import openai_imagenet_template

model_str = "ViT-B-16"
pretrained = "/fs/ess/PAS2136/foundation_model/model/10m/2023_09_22-21_14_04-model_ViT-B-16-lr_0.0001-b_4096-j_8-p_amp/checkpoints/epoch_99.pt"

# Image -> normalized float tensor (C, H, W) using the OpenAI CLIP mean/std.
# There is no Resize step here: resizing is expected to be done by the
# gr.Image(shape=(224, 224)) input component below.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build a zero-shot text classifier, one unit-length embedding per class.

    For each class name, every template is filled in, tokenized, and encoded
    by the text tower. The per-prompt embeddings are L2-normalized, averaged,
    and the mean is re-normalized to unit length.

    Args:
        classnames: Sequence of class-name strings (must be non-empty;
            ``torch.stack`` rejects an empty list).
        templates: Sequence of callables mapping a class name to a prompt.

    Returns:
        Tensor with one column per class, stacked along dim 1. Presumably
        shaped (embed_dim, len(classnames)), assuming ``encode_text`` returns
        (num_prompts, embed_dim) -- TODO confirm.

    Reads the module-level ``model`` and ``tokenizer``, which are only bound
    in the ``__main__`` block.
    """
    all_features = []
    for classname in classnames:
        txts = [template(classname) for template in templates]
        txts = tokenizer(txts)
        txt_features = model.encode_text(txts)
        # Prompt ensembling: average normalized embeddings, then renormalize.
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    return torch.stack(all_features, dim=1)


@torch.no_grad()
def predict(img, cls_str: str) -> dict[str, float]:
    """Zero-shot classify ``img`` against newline-separated class names.

    Args:
        img: Image from the Gradio ``gr.Image`` component (``None`` when the
            user has not uploaded anything).
        cls_str: Class names, one per line; blank lines and surrounding
            whitespace are ignored.

    Returns:
        Mapping from class name to softmax probability.

    Raises:
        gr.Error: If no image was provided or no class names were entered,
            so the user sees a readable message instead of a stack trace.

    Reads the module-level ``model`` (bound in the ``__main__`` block).
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
    if not classes:
        raise gr.Error("Please enter at least one class name (one per line).")

    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Cosine similarities lie in [-1, 1]; without CLIP's learned temperature
    # the softmax is nearly uniform, so scale them as CLIP does in training.
    logit_scale = model.logit_scale.exp()
    # squeeze(0), not squeeze(): with a single class, squeeze() would yield a
    # 0-d tensor, .tolist() would return a bare float, and zip() would fail.
    logits = (logit_scale * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).tolist()
    return dict(zip(classes, probs))


if __name__ == "__main__":
    print("Starting.")
    # These become module globals used by get_txt_features() and predict().
    model = create_model(model_str, pretrained, output_dict=True)
    print("Created model.")
    model = torch.compile(model)
    print("Compiled model.")
    tokenizer = get_tokenizer(model_str)

    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
    )

    demo.launch()