# Quickstart

This quickstart is intended for developers who are ready to dive into the code and see an example of how to integrate `timm` into their model training workflow.

First, you'll need to install `timm`:

```bash
pip install timm
```

## Load a Pretrained Model

Pretrained models can be loaded using [`create_model`].

Here, we load the pretrained `mobilenetv3_large_100` model.

```py
>>> import timm
>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```

<Tip>
The returned PyTorch model is set to train mode by default, so you must call `.eval()` on it if you plan to use it for inference.
</Tip>

## List Models with Pretrained Weights

To list models packaged with `timm`, you can use [`list_models`]. If you specify `pretrained=True`, this function will only return model names that have associated pretrained weights available.

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
 ...
]
```

You can also list models with a specific pattern in their name.

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resnet*')
>>> pprint(model_names)
[
 ...
]
```

## Fine-Tune a Pretrained Model

You can fine-tune any of the pretrained models just by changing the classifier (the last layer).

```py
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
```

To fine-tune on your own dataset, you have to write a PyTorch training loop or adapt `timm`'s [training script](training_script) to use your dataset.

## Use a Pretrained Model for Feature Extraction

Without modifying the network, one can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This will bypass the network's head classifier and global pooling.

For a more in-depth guide to using `timm` for feature extraction, see [Feature Extraction](feature_extraction).

```py
>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])
```

## Image Augmentation

To transform images into valid inputs for a model, you can use [`timm.data.create_transform`], providing the desired `input_size` that the model expects.

This will return a generic transform that uses reasonable defaults.

```py
>>> timm.data.create_transform((3, 224, 224))
Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
```

Pretrained models have specific transforms that were applied to the images they were trained on. If you use the wrong transform on your image, the model won't be able to make sense of what it's seeing!

To figure out which transformations were used for a given pretrained model, we can start by taking a look at its `pretrained_cfg`:

```py
>>> model.pretrained_cfg
{'crop_pct': 0.875,
 'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 ...}
```

We can then resolve just the data-related configuration by using [`timm.data.resolve_data_config`].

```py
>>> timm.data.resolve_data_config(model.pretrained_cfg)
{'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'crop_pct': 0.875}
```

We can pass this data config to [`timm.data.create_transform`] to initialize the model's specific transform.

```py
>>> data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
>>> transform = timm.data.create_transform(**data_cfg)
>>> transform
Compose(
    Resize(size=256, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
```

<Tip>
Here, the pretrained model's transform happens to match the generic transform above, with the exception of the `interpolation` mode (`bicubic` rather than `bilinear`).
</Tip>

## Using Pretrained Models for Inference

Here, we will put together the above sections and use a pretrained model for inference.

First we'll need an image to run inference on. Here we load an image from the web:

```py
>>> import requests
>>> from PIL import Image
>>> url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image
```

Here's what the image looks like:

<img src="https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg" alt="An Image from a link" width="300"/>

Now, we'll create our model and transforms again. This time, we make sure to set our model in evaluation mode.

```py
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
>>> transform = timm.data.create_transform(
...     **timm.data.resolve_data_config(model.pretrained_cfg)
... )
```

We can prepare this image for the model by passing it to the transform.

```py
>>> image_tensor = transform(image)
>>> image_tensor.shape
torch.Size([3, 224, 224])
```

Now we can pass that image to the model to get the predictions. We use `unsqueeze(0)` in this case, as the model is expecting a batch dimension.

```py
>>> output = model(image_tensor.unsqueeze(0))
>>> output.shape
torch.Size([1, 1000])
```

To get the predicted probabilities, we apply softmax to the output. This leaves us with a tensor of shape `(num_classes,)`.

```py
>>> probabilities = torch.nn.functional.softmax(output[0], dim=0)
>>> probabilities.shape
torch.Size([1000])
```

Now we'll find the top 5 predicted class indices and values using `torch.topk`.

```py
>>> values, indices = torch.topk(probabilities, 5)
>>> indices
tensor([162, 166, 161, 164, 167])
```

If we check the ImageNet labels for the top indices, we can see what the model predicted:

```py
>>> IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
>>> IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
>>> [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
[{'label': 'beagle', 'value': 0.8486...},
 ...]
```