Mamba Out
Structured like a State Space Model (SSM) but shedding that skin to reveal something leaner and meaner, MambaOut is the Bride of machine learning models: a warrior trained in the arts of data kung fu, ready to wreak havoc on any dataset that crosses its path. As Kobe would fade away to sink an impossible shot, MambaOut fades out unnecessary computations, leaving only pure, venomous performance in its wake.
</LLM off>
Sorry, I asked for a catchy intro with reference to pop culture re 'black mamba' and that was too funny. In all seriousness, this model, now included in `timm` as of v1.0.11 (https://github.com/huggingface/pytorch-image-models), is worth a look. I've spent a bit of time experimenting with it and training a few models from scratch, extending the original weights with a few of my own.
So what exactly is the model? As per the cheeky name, the core block is structured like other Mamba (SSM) vision models but without including the SSM itself. No SSM means no custom compiled kernels or extra dependencies, yay! There's just a 7x7 DW convolution with a gating mechanism sandwiched between two FC layers. See the diagram from the paper (https://arxiv.org/abs/2405.07992) below:
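In code, a minimal sketch of that block might look like the following. This is simplified from the paper/`timm` versions (which, among other details, route only a fraction of the hidden channels through the conv), so treat the names and shapes as illustrative rather than the actual implementation:

```python
import torch
import torch.nn as nn

class GatedConvBlock(nn.Module):
    """Simplified MambaOut-style block: FC -> (gate * 7x7 DW conv) -> FC, no SSM.
    Illustrative sketch only; the real block differs in detail."""
    def __init__(self, dim, expansion=8 / 3, kernel_size=7):
        super().__init__()
        hidden = int(dim * expansion)
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, 2 * hidden)  # produces gate + conv branch
        self.conv = nn.Conv2d(
            hidden, hidden, kernel_size,
            padding=kernel_size // 2, groups=hidden)  # 7x7 depthwise, spatial mixing
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)  # channel mixing back down to dim

    def forward(self, x):  # x: (B, H, W, C), channels-last like ConvNeXt in timm
        shortcut = x
        x = self.norm(x)
        gate, value = self.fc1(x).chunk(2, dim=-1)
        value = value.permute(0, 3, 1, 2)  # to (B, C, H, W) for the DW conv
        value = self.conv(value).permute(0, 2, 3, 1)
        x = self.fc2(self.act(gate) * value)  # gating, then FC back down
        return x + shortcut
```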
From my vantage point, the model has a LOT in common with the ConvNeXt family -- no BatchNorm, purely convolutional w/ 7x7 DW conv for spatial mixing, and some 1x1 PW / FC layers for channel mixing. Indeed, in runtime performance and training behaviour it feels like an extension of that family as well. Comparing the core building block of each, we can see why:
Comparing runtime performance with ConvNeXt, the extra splits, concats, etc. add a bit of overhead, but with `torch.compile()` that's nullified and they are quite competitive.
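As a quick sketch of what that looks like in practice (the model name here is just an example):

```python
import timm
import torch

# Compile away the overhead from the extra splits/concats.
# 'mambaout_small' is an example name; any mambaout_* variant works the same way.
model = timm.create_model('mambaout_small').eval()
model = torch.compile(model)

x = torch.randn(8, 3, 224, 224)
with torch.inference_mode():
    y = model(x)  # first call triggers compilation; time subsequent calls
print(y.shape)    # torch.Size([8, 1000])
```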
Bringing the models into `timm`, I experimented with a few small tweaks in model structure and training recipe. You can see them in the models below with the `_rw` suffix. In the base size range, the changes yielded some slightly faster variants with more params that were able to eke out slightly higher accuracy. Of the `tall`, `wide`, and `short` base variants, I'd consider `tall` (slightly deeper and slightly narrower) to be the most worthwhile.
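To see the full set, including the `_rw` variants, you can list them; the exact names are whatever your installed `timm` version reports:

```python
import timm

# List all MambaOut variants, including the _rw tweaks described above.
print(timm.list_models('mambaout*'))

# List only those with pretrained weights available.
print(timm.list_models('mambaout*', pretrained=True))
```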
The most interesting addition of mine, though, is `base_plus`. Increasing both depth and width a little AND pretraining on ImageNet-12k, it's right up there with the best ImageNet-22k pretrained models in `timm`.
One of the first questions I had looking at the original pretrained weights was: hmm, the smaller models are pretty good for their size, but `base` barely moved the needle from there. What happened? Is scaling broken?
No. At 102M params, `base_plus` matches or passes the accuracy of ImageNet-22k pretrained ConvNeXt-Large (~200M params), it's not far from the best 22k-trained ViT-Large (DeiT-III, ~300M params), and it's above Swin/SwinV2-Large, resoundingly so if you consider the runtime performance.
So, `pip install --upgrade timm` and give it a try :)
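For example, a minimal smoke test. The model name below is an assumption on my part; check `timm.list_models('mambaout*', pretrained=True)` for the exact names your install has:

```python
import timm
import torch

# 'mambaout_base_plus_rw' is an assumed name; substitute one from list_models().
model = timm.create_model('mambaout_base_plus_rw', pretrained=True).eval()

x = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image
with torch.inference_mode():
    top5 = model(x).softmax(dim=-1).topk(5)
print(top5.indices, top5.values)
```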
The original research codebase from the paper authors: https://github.com/yuweihao/MambaOut
@article{yu2024mambaout,
title={MambaOut: Do We Really Need Mamba for Vision?},
author={Yu, Weihao and Wang, Xinchao},
journal={arXiv preprint arXiv:2405.07992},
year={2024}
}