Image Classification
mlx-image
Safetensors
MLX
vision
riccardomusmeci commited on
Commit
f3d261b
1 Parent(s): 70299d1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +57 -59
README.md CHANGED
@@ -1,81 +1,79 @@
1
- ---
2
- {}
3
  ---
4
 
5
- ---
6
- license: apache-2.0
7
- tags:
8
- - mlx
9
- - mlx-image
10
- - vision
11
- - image-classification
12
- datasets:
13
- - imagenet-1k
14
- library_name: mlx-image
15
- ---
16
- # vit_base_patch16_224.dino
17
 
18
- A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
19
 
20
- The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
21
 
22
- Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
23
 
24
- <div align="center">
25
- <img width="100%" alt="DINO illustration" src="dino.gif">
26
- </div>
27
 
28
 
29
- ## How to use
30
- ```bash
31
- pip install mlx-image
32
- ```
33
 
34
- Here is how to use this model for image classification:
35
 
36
- ```python
37
- from mlxim.model import create_model
38
- from mlxim.io import read_rgb
39
- from mlxim.transform import ImageNetTransform
40
 
41
- transform = ImageNetTransform(train=False, img_size=224)
42
- x = transform(read_rgb("cat.png"))
43
- x = mx.expand_dims(x, 0)
44
 
45
- model = create_model("vit_base_patch16_224.dino")
46
- model.eval()
47
 
48
- logits, attn_masks = model(x, attn_masks=True)
49
- ```
50
 
51
- You can also use the embeds from layer before head:
52
- ```python
53
- from mlxim.model import create_model
54
- from mlxim.io import read_rgb
55
- from mlxim.transform import ImageNetTransform
56
 
57
- transform = ImageNetTransform(train=False, img_size=512)
58
- x = transform(read_rgb("cat.png"))
59
- x = mx.expand_dims(x, 0)
60
 
61
- # first option
62
- model = create_model("vit_base_patch16_224.dino", num_classes=0)
63
- model.eval()
64
 
65
- embeds = model(x)
66
 
67
- # second option
68
- model = create_model("vit_base_patch16_224.dino")
69
- model.eval()
70
 
71
- embeds, attn_masks = model.get_features(x)
72
- ```
73
 
74
- ## Attention maps
75
- You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb).
76
 
77
- <div align="center">
78
- <img width="100%" alt="Attention Map" src="attention_maps.png">
79
- </div>
80
 
81
-
 
1
+
 
2
  ---
3
 
4
+ ---
5
+ license: apache-2.0
6
+ tags:
7
+ - mlx
8
+ - mlx-image
9
+ - vision
10
+ - image-classification
11
+ datasets:
12
+ - imagenet-1k
13
+ library_name: mlx-image
14
+ ---
15
+ # vit_base_patch16_224.dino
16
 
17
+ A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
18
 
19
+ The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
20
 
21
+ Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
22
 
23
+ <div align="center">
24
+ <img width="100%" alt="DINO illustration" src="dino.gif">
25
+ </div>
26
 
27
 
28
+ ## How to use
29
+ ```bash
30
+ pip install mlx-image
31
+ ```
32
 
33
+ Here is how to use this model for image classification:
34
 
35
+ ```python
36
+ from mlxim.model import create_model
37
+ from mlxim.io import read_rgb
38
+ from mlxim.transform import ImageNetTransform
39
 
40
+ transform = ImageNetTransform(train=False, img_size=224)
41
+ x = transform(read_rgb("cat.png"))
42
+ x = mx.expand_dims(x, 0)
43
 
44
+ model = create_model("vit_base_patch16_224.dino")
45
+ model.eval()
46
 
47
+ logits, attn_masks = model(x, attn_masks=True)
48
+ ```
49
 
50
+ You can also use the embeds from layer before head:
51
+ ```python
52
+ from mlxim.model import create_model
53
+ from mlxim.io import read_rgb
54
+ from mlxim.transform import ImageNetTransform
55
 
56
+ transform = ImageNetTransform(train=False, img_size=512)
57
+ x = transform(read_rgb("cat.png"))
58
+ x = mx.expand_dims(x, 0)
59
 
60
+ # first option
61
+ model = create_model("vit_base_patch16_224.dino", num_classes=0)
62
+ model.eval()
63
 
64
+ embeds = model(x)
65
 
66
+ # second option
67
+ model = create_model("vit_base_patch16_224.dino")
68
+ model.eval()
69
 
70
+ embeds, attn_masks = model.get_features(x)
71
+ ```
72
 
73
+ ## Attention maps
74
+ You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb).
75
 
76
+ <div align="center">
77
+ <img width="100%" alt="Attention Map" src="attention_maps.png">
78
+ </div>
79