AIM: Autoregressive Image Models
Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar, Joshua M Susskind, and Armand Joulin
This software project accompanies the research paper, Scalable Pre-training of Large Autoregressive Image Models.
We introduce AIM, a collection of vision models pre-trained with an autoregressive generative objective. We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its textual counterpart, i.e. Large Language Models. Specifically, we highlight two findings:
- the model capacity can be trivially scaled to billions of parameters, and
- AIM effectively leverages large collections of uncurated image data.
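To make the pre-training objective more concrete, here is a minimal NumPy sketch of an autoregressive patch-prediction loss. It assumes the image is cut into non-overlapping patches read in raster order, that the prediction at position k targets patch k+1, and that targets are per-patch normalized pixels scored with mean squared error. The `predictor` argument is a hypothetical stand-in for the causal transformer, and all function names are illustrative; this is not the training code in this repository.

```python
import numpy as np

def patchify(img, patch):
    """Split an (H, W, C) image into a (num_patches, patch*patch*C) array in raster order."""
    h, w, c = img.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    gh, gw = h // patch, w // patch
    x = img.reshape(gh, patch, gw, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (gh, gw, patch, patch, c)
    return x.reshape(gh * gw, patch * patch * c)

def normalize_patches(patches, eps=1e-6):
    """Standardize each patch to zero mean and unit variance (assumed target normalization)."""
    mean = patches.mean(axis=-1, keepdims=True)
    var = patches.var(axis=-1, keepdims=True)
    return (patches - mean) / np.sqrt(var + eps)

def autoregressive_loss(img, patch, predictor):
    """MSE between the prediction made at each position and the normalized next patch."""
    patches = patchify(img.astype(np.float64), patch)
    targets = normalize_patches(patches)
    # predictor maps the sequence patches[:-1] of shape (n-1, d) to one prediction per
    # position (n-1, d); a causal model may only use patches[:k+1] to produce row k.
    preds = predictor(patches[:-1])
    return float(np.mean((preds - targets[1:]) ** 2))
```

With a predictor that always outputs zeros, the loss is close to 1.0, because every normalized target patch has (almost exactly) unit variance; a trained model should push it well below that.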
Installation
Please install PyTorch using the official installation instructions. Afterward, install the package as:
pip install git+https://git@github.com/apple/ml-aim.git
We also offer MLX backend support for research and experimentation on Apple silicon. To enable MLX support, simply run:
pip install mlx
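Because MLX (and JAX) are optional, a script meant to run on several machines can check which frameworks are importable before choosing the `backend` argument of `load_pretrained`. The helper below is a hypothetical convenience, not part of the `aim` package; it relies only on the standard library's `importlib.util.find_spec` and assumes the backend strings `"torch"`, `"mlx"`, and `"jax"` used in the usage examples.

```python
import importlib.util

# Backend names as passed to load_pretrained(..., backend=...), mapped to the
# top-level module each one needs (assumed from the usage examples).
_BACKEND_MODULES = {"mlx": "mlx", "jax": "jax", "torch": "torch"}

def available_backends():
    """Return the backends whose framework can be imported in this environment."""
    return [name for name, module in _BACKEND_MODULES.items()
            if importlib.util.find_spec(module) is not None]

def pick_backend(preferred=("mlx", "jax", "torch")):
    """Pick the first preferred backend that is installed; raise if none is."""
    installed = set(available_backends())
    for name in preferred:
        if name in installed:
            return name
    raise RuntimeError(f"none of the requested backends are installed: {preferred}")
```

For example, `load_pretrained("aim-600M-2B-imgs", backend=pick_backend())` would prefer MLX on a machine where it is installed and fall back to the other frameworks otherwise.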
Usage
Below we provide an example of usage in PyTorch:
from PIL import Image
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
logits, _ = model(inp)
and in MLX:
from PIL import Image
import mlx.core as mx
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, _ = model(inp)
and in JAX:
from PIL import Image
import jax.numpy as jnp
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
(logits, _), _ = model.apply(params, inp, mutable=['batch_stats'])
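The `logits` returned by the examples above are raw class scores. Assuming they form a `(batch, 1000)` array over the ImageNet-1k classes, a backend-independent NumPy sketch for turning them into top-k class indices and probabilities looks like this. Convert first with, e.g., `logits.detach().cpu().numpy()` in PyTorch or `np.array(logits)` for MLX and JAX; the helper name `topk_predictions` is hypothetical.

```python
import numpy as np

def topk_predictions(logits, k=5):
    """Return (indices, probabilities) of the k highest-scoring classes for each row."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the row max before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # argsort is ascending: keep the last k columns and reverse so the best class is first.
    idx = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
    return idx, np.take_along_axis(probs, idx, axis=-1)
```

Mapping the indices to human-readable labels requires an ImageNet-1k class list, which is not shipped with this sketch.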
Pre-trained checkpoints
The pre-trained models can be accessed either via Hugging Face:
# after running pip install git+https://git@github.com/apple/ml-aim.git
from aim.torch.models import AIMForImageClassification
aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b = AIMForImageClassification.from_pretrained("apple/aim-7B")
or PyTorch Hub as:
import torch
aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b = torch.hub.load("apple/ml-aim", "aim_7B")
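The Hugging Face repository ids and the PyTorch Hub entrypoints above follow the same naming pattern, so a script that takes the model size as a parameter can build both from one place. This is a hypothetical helper derived only from the identifiers listed above; it does not download anything.

```python
# Released model sizes, spelled as they appear in the identifiers above.
_SIZES = ("600M", "1B", "3B", "7B")

def checkpoint_ids(size):
    """Return (hugging_face_repo_id, torch_hub_entrypoint) for a released AIM size."""
    if size not in _SIZES:
        raise ValueError(f"unknown AIM size {size!r}; expected one of {_SIZES}")
    return f"apple/aim-{size}", f"aim_{size}"
```

For instance, `hf_id, hub_name = checkpoint_ids("3B")` yields the arguments for `AIMForImageClassification.from_pretrained(hf_id)` or `torch.hub.load("apple/ml-aim", hub_name)`.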
Pre-trained backbones
The following table contains pre-trained backbones used in our paper.
model | #params | top-1 IN-1k, attention probe (best layer) | backbone, SHA256 |
---|---|---|---|
AIM-0.6B | 0.6B | 79.4% | link, 0d6f6b8f |
AIM-1B | 1B | 82.3% | link, d254ecd3 |
AIM-3B | 3B | 83.3% | link, 8475ce4e |
AIM-7B | 7B | 84.0% | link, 184ed94c |
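The SHA256 column shows only the first eight hex digits of each checkpoint's digest. A minimal sketch for checking a downloaded file against that prefix with Python's standard `hashlib` is shown below; the function name is hypothetical and the path in the usage line is a placeholder.

```python
import hashlib

def sha256_prefix_matches(path, expected_prefix, chunk_size=1 << 20):
    """Hash the file in 1 MiB chunks (so multi-GB checkpoints need not fit in memory)
    and compare the start of the hex digest with the prefix listed in the table."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(expected_prefix.lower())
```

For example, `sha256_prefix_matches("/path/to/backbone_ckpt.pth", "184ed94c")` should return `True` for an intact AIM-7B backbone. An eight-digit prefix catches truncated or corrupted downloads but is not a full integrity guarantee.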
Pre-trained attention heads
The table below contains the classification results on the ImageNet-1k validation set.
model | top-1 IN-1k (last layer) | top-1 IN-1k (best layer) | attention head, SHA256 (last layer) | attention head, SHA256 (best layer) |
---|---|---|---|---|
AIM-0.6B | 78.5% | 79.4% | link, 5ce5a341 | link, ebd45c05 |
AIM-1B | 80.6% | 82.3% | link, db3be2ad | link, f1ed7852 |
AIM-3B | 82.2% | 83.3% | link, 5c057b30 | link, ad380e16 |
AIM-7B | 82.4% | 84.0% | link, 1e5c99ba | link, 73ecd732 |
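The benefit of probing the best layer rather than the last one can be read directly off this table. The snippet below simply recomputes that gap, in percentage points, from the reported top-1 numbers.

```python
# Top-1 IN-1k accuracy (%) copied from the table above: (last layer, best layer).
TOP1 = {
    "AIM-0.6B": (78.5, 79.4),
    "AIM-1B": (80.6, 82.3),
    "AIM-3B": (82.2, 83.3),
    "AIM-7B": (82.4, 84.0),
}

def best_layer_gain(table):
    """Percentage-point improvement of the best-layer head over the last-layer head."""
    return {model: round(best - last, 1) for model, (last, best) in table.items()}

print(best_layer_gain(TOP1))
```

For these four models the best-layer head is between 0.9 and 1.7 points ahead of the last-layer head.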
Reproducing the IN-1k classification results
The command below reproduces the attention probe results on the ImageNet-1k validation set. We run the evaluation on 1 node with 8 GPUs:
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
--model=aim-7B \
--batch-size=64 \
--data-path=/path/to/imagenet \
--probe-layers=last \
--backbone-ckpt-path=/path/to/backbone_ckpt.pth \
--head-ckpt-path=/path/to/head_ckpt.pth
By default, we probe the last 6 layers. To change this, simply pass --probe-layers=best.
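For readers new to attention probing: rather than average-pooling the frozen backbone's patch features, a small head learns a query that attends over all patch tokens, and a linear classifier is applied to the pooled vector. The NumPy sketch below shows a single-query, single-head version of that idea. It is an illustration under those assumptions, not the head architecture used by `main_attnprobe.py`, and all names and shapes are hypothetical.

```python
import numpy as np

def attention_probe(features, query, w_key, w_value, w_cls, b_cls):
    """Pool frozen patch features with one learned query, then classify.

    features: (batch, tokens, dim) frozen backbone outputs
    query: (dim,) learned query vector
    w_key, w_value: (dim, dim) learned key/value projections
    w_cls: (dim, num_classes), b_cls: (num_classes,) linear classifier
    """
    keys = features @ w_key                           # (batch, tokens, dim)
    values = features @ w_value                       # (batch, tokens, dim)
    scores = keys @ query / np.sqrt(query.shape[0])   # (batch, tokens)
    scores -= scores.max(axis=-1, keepdims=True)      # stable softmax over tokens
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    pooled = np.einsum("bt,btd->bd", weights, values)  # attention-weighted average
    return pooled @ w_cls + b_cls, weights
```

Only `query`, the projections, and the classifier are trained; the backbone features stay fixed. With a zero query the attention is uniform and the head reduces to average pooling followed by a linear layer, which is the baseline the attentive probe improves on.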