"""Zero-shot ImageNet evaluation for a CLIP model with an EVA vision tower.

Builds a text-derived classifier from the ImageNet class names and prompt
templates, then reports top-1 / top-5 accuracy on the ImageNet ``val`` split.
"""
import os
import argparse

from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

import eva_vit_model
from eva_vit_model import CLIP
from open_clip.tokenizer import tokenize
from imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template


def _autocast(device):
    """Return a CUDA autocast context that is only active on CUDA devices.

    Entering ``torch.cuda.amp.autocast()`` while running on CPU emits a
    warning on every use, so mixed precision is switched off explicitly when
    ``device`` is not a CUDA device.
    """
    return torch.cuda.amp.autocast(enabled=torch.device(device).type == 'cuda')


def main(args):
    """Load the model and checkpoint, run the evaluation and print accuracy.

    Args:
        args: parsed command-line namespace providing ``model``,
            ``ckpt_path``, ``imagenet_path``, ``image_size``, ``batch_size``
            and ``workers``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if torch.cuda.is_available():
        # Favour throughput over bit-exact reproducibility for evaluation.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.allow_tf32 = True

    print(f"creating model: {args.model}")
    model = CLIP(vision_model=args.model)

    print(f"loading checkpoint from {args.ckpt_path}")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(args.ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
    model.to(device)

    def _convert_image_to_rgb(image):
        return image.convert("RGB")

    val_transform = transforms.Compose([
        transforms.Resize(args.image_size, transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.image_size),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
    ])

    val_dataset = datasets.ImageFolder(
        os.path.join(args.imagenet_path, 'val'), transform=val_transform)
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=args.workers)

    model.eval()
    classifier = zero_shot_classifier(
        model, imagenet_classnames, openai_imagenet_template, device)
    top1, top5 = zero_shot_eval(model, classifier, val_loader, device)
    print(f'ImageNet zeroshot top1: {top1:.4f}, top5: {top5:.4f}')


def zero_shot_classifier(model, classnames, templates, device):
    """Build the zero-shot classification weights from text prompts.

    For every class name each template is applied, the resulting prompts are
    tokenized and encoded with ``model.encode_text``, and the L2-normalised
    prompt embeddings are averaged and re-normalised into one vector per
    class.

    Args:
        model: object exposing ``encode_text``.
        classnames: iterable of class-name strings.
        templates: iterable of callables mapping a class name to a prompt.
        device: device on which tokens and the resulting weights are placed.

    Returns:
        Tensor with the per-class embeddings stacked along ``dim=1``
        (presumably ``(embed_dim, num_classes)`` -- confirm against the
        model's text encoder output).
    """
    tokenizer = tokenize
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = tokenizer(texts).to(device=device)  # tokenize
            with _autocast(device):
                class_embeddings = model.encode_text(texts)
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


def accuracy(output, target, topk=(1,)):
    """Count correct predictions within the top-k scores for each k.

    Args:
        output: 2-D score tensor; the top-k is taken along dimension 1.
        target: tensor of ground-truth class indices, one per row of
            ``output``.
        topk: tuple of k values to evaluate.

    Returns:
        List of Python floats, one per entry of ``topk``, each the *number*
        (not the fraction) of samples whose target is among the k
        highest-scoring classes.
    """
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # ``.item()`` avoids the deprecated float() conversion of a 1-element
    # NumPy array while yielding the same Python float.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]


def zero_shot_eval(model, classifier, dataloader, device):
    """Evaluate top-1 / top-5 accuracy of the zero-shot classifier.

    Args:
        model: object exposing ``encode_image``.
        classifier: weight tensor as returned by ``zero_shot_classifier``.
        dataloader: iterable yielding ``(images, target)`` batches.
        device: device on which the batches are evaluated.

    Returns:
        Tuple ``(top1, top5)`` of accuracies as fractions of all evaluated
        samples.
    """
    top1, top5, n = 0., 0., 0.
    # Take the progress scale from the loader itself rather than from a
    # module-level ``args`` global, which only exists when run as a script.
    batch_size = getattr(dataloader, 'batch_size', None) or 1
    with torch.no_grad():
        for images, target in tqdm(dataloader, unit_scale=batch_size):
            images = images.to(device=device)
            target = target.to(device=device)

            with _autocast(device):
                image_features = model.encode_image(images)
                image_features = F.normalize(image_features, dim=-1)
                logits = 100. * image_features @ classifier

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    top1 = (top1 / n)
    top5 = (top5 / n)
    return top1, top5


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ImageNet zero shot evaluations')
    parser.add_argument('--imagenet-path', default='path/to/imagenet', type=str, help='path to imagenet dataset')
    parser.add_argument('--ckpt-path', default='path/to/ckpt', type=str, help='path to checkpoint')
    parser.add_argument('--batch-size', default=64, type=int, help='batch size')
    parser.add_argument('--model', default='eva_base_p16', type=str, help='model')
    parser.add_argument('--image-size', default=224, type=int, help='image size for evaluation')
    parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
    args = parser.parse_args()
    main(args)