Mamba Out

Published October 18, 2024

Structured like a State Space Model (SSM) but shedding that skin to reveal something leaner and meaner, MambaOut is the Bride of machine learning models: a warrior trained in the arts of data kung fu, ready to wreak havoc on any dataset that crosses its path. As Kobe would fade away to sink an impossible shot, MambaOut fades out unnecessary computations, leaving only pure, venomous performance in its wake.

</LLM off>

Sorry, I asked for a catchy intro with a pop-culture reference to 'black mamba' and that was too funny to leave out. In all seriousness, this model, included in timm as of v1.0.11 (https://github.com/huggingface/pytorch-image-models), is worth a look. I've spent a bit of time experimenting with it and training a few models from scratch, extending the original weights with a few of my own.

So what exactly is the model? As per the cheeky name, the core block is structured like other Mamba (SSM) vision models but without including the SSM itself. No SSM means no custom compiled kernels or extra dependencies, yay! There's just a 7x7 DW convolution with a gating mechanism sandwiched between two FC layers. See the diagram from the paper (https://arxiv.org/abs/2405.07992) below:

[Figure: the Gated CNN (MambaOut) block, from the paper]
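
To make that structure concrete, here's a rough PyTorch sketch of such a gated block. The layer names, the ~8/3 expansion, the conv_ratio split, and the channels-last tensor layout are illustrative assumptions on my part, not a copy of the timm or reference implementation:

```python
import torch
import torch.nn as nn

class GatedCNNBlock(nn.Module):
    """Illustrative sketch of a MambaOut-style gated block (not the exact timm code)."""
    def __init__(self, dim, expansion=8 / 3, kernel_size=7, conv_ratio=1.0):
        super().__init__()
        hidden = int(expansion * dim)
        conv_chs = int(conv_ratio * hidden)              # channels that pass through the DW conv
        self.split_sizes = [hidden, hidden - conv_chs, conv_chs]
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden * 2)            # expand into gate + value branches
        self.conv = nn.Conv2d(conv_chs, conv_chs, kernel_size,
                              padding=kernel_size // 2, groups=conv_chs)  # 7x7 depthwise conv
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x):                                # x: (B, H, W, C), channels-last
        shortcut = x
        x = self.norm(x)
        g, i, c = torch.split(self.fc1(x), self.split_sizes, dim=-1)
        c = self.conv(c.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)   # spatial mixing on the conv slice
        x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))      # gate * value, project back down
        return x + shortcut

# quick shape check
blk = GatedCNNBlock(96)
print(blk(torch.randn(2, 56, 56, 96)).shape)             # torch.Size([2, 56, 56, 96])
```

With conv_ratio below 1.0 only a slice of the hidden channels goes through the depthwise conv, which (as I understand it) is where the extra split/concat traffic mentioned below comes from.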

From my vantage point, the model has a LOT in common with the ConvNeXt family -- no BatchNorm, purely convolutional with a 7x7 DW conv for spatial mixing and 1x1 PW / FC layers for channel mixing. Indeed, in runtime performance and training behaviour it feels like an extension of that family as well. Comparing the core building block of each, we can see why:

[Figure: MambaOut block vs ConvNeXt block, side by side]

Comparing runtime performance with ConvNeXt, the extra splits, concats, etc. add a bit of overhead, but with torch.compile() that's nullified and the two are quite competitive. Bringing the models into timm, I experimented with a few small tweaks to the model structure and training recipe. You can see them in the models below with the _rw suffix. In the base size range the changes yielded some slightly faster variants with more params that were able to eke out slightly higher accuracy. Of the tall, wide, and short base variants, I'd consider tall (slightly deeper and slightly narrower) the most worthwhile.
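
If you want to sanity-check that yourself, a rough (and admittedly unscientific) throughput comparison could look like the sketch below. It assumes a CUDA GPU and uses pretrained=False since only relative speed matters here:

```python
import time
import torch
import timm

def bench(name, compile_model=False, steps=20, batch=32, size=224):
    """Rough img/s estimate; not a rigorous benchmark (short warmup, single run)."""
    m = timm.create_model(name, pretrained=False).cuda().eval()
    if compile_model:
        m = torch.compile(m)
    x = torch.randn(batch, 3, size, size, device='cuda')
    with torch.inference_mode():
        for _ in range(3):              # warmup (also triggers compilation)
            m(x)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(steps):
            m(x)
        torch.cuda.synchronize()
    return steps * batch / (time.perf_counter() - t0)

for name in ('convnext_base', 'mambaout_base'):
    print(name,
          'eager:', f"{bench(name):.0f} img/s",
          'compiled:', f"{bench(name, compile_model=True):.0f} img/s")
```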

The most interesting addition of mine, though, is base_plus. It increases both depth and width a little AND is pretrained on ImageNet-12k, which puts it right up there with the best ImageNet-22k pretrained models in timm. One of the first questions I had when looking at the original pretrained weights was: hmm, the smaller models are pretty good for their size, but the base barely moved the needle from there -- what happened? Is scaling broken?

No. At 102M params, base_plus matches or passes the accuracy of ImageNet-22k pretrained ConvNeXt-Large (~200M params), isn't far off the best 22k-trained ViT-Large (DeiT-III, ~300M params), and sits above Swin/SwinV2-Large -- resoundingly so if you consider runtime performance.

| model | img_size | top-1 | top-5 | params (M) |
|---|---|---|---|---|
| mambaout_base_plus_rw.sw_e150_r384_in12k_ft_in1k | 384 | 87.506 | 98.428 | 101.66 |
| mambaout_base_plus_rw.sw_e150_in12k_ft_in1k | 288 | 86.912 | 98.236 | 101.66 |
| mambaout_base_plus_rw.sw_e150_in12k_ft_in1k | 224 | 86.632 | 98.156 | 101.66 |
| mambaout_base_tall_rw.sw_e500_in1k | 288 | 84.974 | 97.332 | 86.48 |
| mambaout_base_wide_rw.sw_e500_in1k | 288 | 84.962 | 97.208 | 94.45 |
| mambaout_base_short_rw.sw_e500_in1k | 288 | 84.832 | 97.27 | 88.83 |
| mambaout_base.in1k | 288 | 84.72 | 96.93 | 84.81 |
| mambaout_small_rw.sw_e450_in1k | 288 | 84.598 | 97.098 | 48.5 |
| mambaout_small.in1k | 288 | 84.5 | 96.974 | 48.49 |
| mambaout_base_wide_rw.sw_e500_in1k | 224 | 84.454 | 96.864 | 94.45 |
| mambaout_base_tall_rw.sw_e500_in1k | 224 | 84.434 | 96.958 | 86.48 |
| mambaout_base_short_rw.sw_e500_in1k | 224 | 84.362 | 96.952 | 88.83 |
| mambaout_base.in1k | 224 | 84.168 | 96.68 | 84.81 |
| mambaout_small.in1k | 224 | 84.086 | 96.63 | 48.49 |
| mambaout_small_rw.sw_e450_in1k | 224 | 84.024 | 96.752 | 48.5 |
| mambaout_tiny.in1k | 288 | 83.448 | 96.538 | 26.55 |
| mambaout_tiny.in1k | 224 | 82.736 | 96.1 | 26.55 |
| mambaout_kobe.in1k | 288 | 81.054 | 95.718 | 9.14 |
| mambaout_kobe.in1k | 224 | 79.986 | 94.986 | 9.14 |
| mambaout_femto.in1k | 288 | 79.848 | 95.14 | 7.3 |
| mambaout_femto.in1k | 224 | 78.87 | 94.408 | 7.3 |

So, `pip install --upgrade timm` and give it a try :)
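
For a quick spin with one of the pretrained checkpoints (model name taken from the table above), something like this should do -- the random tensor is just a stand-in for a transformed image:

```python
import torch
import timm
from timm.data import resolve_data_config, create_transform

# Load a pretrained MambaOut checkpoint and its matching preprocessing config.
model = timm.create_model('mambaout_base_plus_rw.sw_e150_in12k_ft_in1k', pretrained=True).eval()
cfg = resolve_data_config({}, model=model)
transform = create_transform(**cfg)               # apply this to a PIL image in real use

with torch.inference_mode():
    x = torch.randn(1, *cfg['input_size'])        # stand-in for transform(img).unsqueeze(0)
    probs = model(x).softmax(dim=-1)
    print(probs.topk(5))
```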

The original research codebase from the paper authors: https://github.com/yuweihao/MambaOut

```bibtex
@article{yu2024mambaout,
  title={MambaOut: Do We Really Need Mamba for Vision?},
  author={Yu, Weihao and Wang, Xinchao},
  journal={arXiv preprint arXiv:2405.07992},
  year={2024}
}
```