laizeqiang committed on
Commit
c43b0d6
1 Parent(s): 35779aa

First model version

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
2
+
3
+ ### Python ###
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+
165
+ ### Python Patch ###
166
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
167
+ poetry.toml
168
+
169
+ # ruff
170
+ .ruff_cache/
171
+
172
+ # LSP config files
173
+ pyrightconfig.json
174
+
175
+ # End of https://www.toptal.com/developers/gitignore/api/python
README.md CHANGED
@@ -9,4 +9,75 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  pinned: false
10
  ---
11
 
12
+ # Anything To Image
13
+
14
+ Generate images from anything using [ImageBind](https://github.com/facebookresearch/ImageBind)'s unified latent space and [stable-diffusion-2-1-unclip](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip).
15
+
16
+ - No training is needed.
17
+ - Integration with 🤗 [Diffusers](https://github.com/huggingface/diffusers).
18
+ - `imagebind` is copied directly from the [official repo](https://github.com/facebookresearch/ImageBind), with modifications.
19
+ - Gradio Demo.
20
+
21
+ ## Audio to Image
22
+
23
+ | `assets/wav/bird_audio.wav` | `assets/wav/dog_audio.wav` | `assets/wav/cattle.wav` |
24
+ | --- | --- | --- |
25
+ | ![](assets/generated/bird_audio.png) | ![](assets/generated/dog_audio.png) |![](assets/generated/cattle.png) |
26
+
27
+ ```python
28
+ import imagebind
29
+ import torch
30
+ from diffusers import StableUnCLIPImg2ImgPipeline
31
+
32
+ # construct models
33
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
34
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
35
+ "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
36
+ )
37
+ pipe = pipe.to(device)
38
+
39
+ model = imagebind.imagebind_huge(pretrained=True)
40
+ model.eval()
41
+ model.to(device)
42
+
43
+ # generate image
44
+ with torch.no_grad():
45
+ audio_paths=["assets/wav/bird_audio.wav"]
46
+ embeddings = model.forward({
47
+ imagebind.ModalityType.AUDIO: imagebind.load_and_transform_audio_data(audio_paths, device),
48
+ })
49
+ embeddings = embeddings[imagebind.ModalityType.AUDIO]
50
+ images = pipe(image_embeds=embeddings.half()).images
51
+ images[0].save("bird_audio.png")
52
+ ```
53
+
54
+ ## More
55
+
56
+ Under construction
57
+
58
+
59
+ ## Citation
60
+
61
+ Latent Diffusion
62
+
63
+ ```bibtex
64
+ @InProceedings{Rombach_2022_CVPR,
65
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
66
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
67
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
68
+ month = {June},
69
+ year = {2022},
70
+ pages = {10684-10695}
71
+ }
72
+ ```
73
+
74
+ ImageBind
75
+ ```bibtex
76
+ @inproceedings{girdhar2023imagebind,
77
+ title={ImageBind: One Embedding Space To Bind Them All},
78
+ author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang
79
+ and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan},
80
+ booktitle={CVPR},
81
+ year={2023}
82
+ }
83
+ ```
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import imagebind
3
+ import torch
4
+ from diffusers import StableUnCLIPImg2ImgPipeline
5
+ import soundfile as sf
6
+
7
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
8
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
9
+ "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16"
10
+ )
11
+ pipe = pipe.to(device)
12
+
13
+ model = imagebind.imagebind_huge(pretrained=True)
14
+ model.eval()
15
+ model.to(device)
16
+
17
+ @torch.no_grad()
18
+ def anything2img(prompt, audio):
19
+ sr, waveform = audio
20
+ audio_path = 'tmp.wav'
21
+ sf.write(audio_path, waveform, sr)
22
+ audio_paths=[audio_path]
23
+ embeddings = model.forward({
24
+ imagebind.ModalityType.AUDIO: imagebind.load_and_transform_audio_data(audio_paths, device),
25
+ })
26
+ embeddings = embeddings[imagebind.ModalityType.AUDIO]
27
+ images = pipe(prompt=prompt, image_embeds=embeddings.half()).images
28
+ return images[0]
29
+
30
+
31
+ demo = gr.Interface(fn=anything2img, inputs=["text", "audio"], outputs="image")
32
+ demo.launch(server_name='0.0.0.0', server_port=10051, share=True)
assets/bird_image.jpg ADDED
assets/car_image.jpg ADDED
assets/dog_image.jpg ADDED
assets/generated/bird_audio.png ADDED
assets/generated/cattle.png ADDED
assets/generated/dog_audio.png ADDED
assets/generated/goat.png ADDED
assets/wav/bird_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8b0c17e3b8b3e5b1324d83a8f598ac9129998877b697d32c194b2a1fe11681a
3
+ size 882078
assets/wav/boy_laugh.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b832d14ab8ec4594940ce73e09bf9e63dfe51743ad0830aa820c12930834598f
3
+ size 760044
assets/wav/car_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79e2335cf1fa4b0a6be7d2a93007f630acd89e178405ae2fabf45ee3af801fda
3
+ size 441044
assets/wav/cat_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c5cfac8d4a1d7a8fd9e67e1605cab5ea3362dd332e7a1645d11eddbfc0c0d1e
3
+ size 210044
assets/wav/cattle.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea07de05dcbf737c34ee98859d83ac58544887ea61d8a97757c32d2aabd660c1
3
+ size 1350222
assets/wav/chick_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90121b234635af7b4ab86edb385afe8012e090e2ced563e8080e08c6657580e7
3
+ size 483918
assets/wav/dog_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b4dca689971840c140e4d206ad43dc35ecf7c2c1b661f478268bc6de9ad50b3
3
+ size 460518
assets/wav/goat.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c5051a17d8ed428b7fd0c299c11a87f27c6b0908975e1a56e0c6cf086f92e11
3
+ size 3450044
assets/wav/rain.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc8119be878eafdc69f7408dad614493a238174372eebf75e4c8314d1bb9807d
3
+ size 441044
imagebind/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .data import load_and_transform_text, load_and_transform_audio_data, load_and_transform_video_data, load_and_transform_vision_data
2
+ from .models.imagebind_model import imagebind_huge, ModalityType
imagebind/bpe/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
imagebind/data.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import logging
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torchaudio
14
+ from PIL import Image
15
+ from pytorchvideo import transforms as pv_transforms
16
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
17
+ from pytorchvideo.data.encoded_video import EncodedVideo
18
+ from torchvision import transforms
19
+ from torchvision.transforms._transforms_video import NormalizeVideo
20
+
21
+ from .models.multimodal_preprocessors import SimpleTokenizer
22
+
23
+ DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
24
+
25
+ BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
26
+
27
+
28
+ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
29
+ # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
30
+ waveform -= waveform.mean()
31
+ fbank = torchaudio.compliance.kaldi.fbank(
32
+ waveform,
33
+ htk_compat=True,
34
+ sample_frequency=sample_rate,
35
+ use_energy=False,
36
+ window_type="hanning",
37
+ num_mel_bins=num_mel_bins,
38
+ dither=0.0,
39
+ frame_length=25,
40
+ frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
41
+ )
42
+ # Convert to [mel_bins, num_frames] shape
43
+ fbank = fbank.transpose(0, 1)
44
+ # Pad to target_length
45
+ n_frames = fbank.size(1)
46
+ p = target_length - n_frames
47
+ # if p is too large (say >20%), flash a warning
48
+ if abs(p) / n_frames > 0.2:
49
+ logging.warning(
50
+ "Large gap between audio n_frames(%d) and "
51
+ "target_length (%d). Is the audio_target_length "
52
+ "setting correct?",
53
+ n_frames,
54
+ target_length,
55
+ )
56
+ # cut and pad
57
+ if p > 0:
58
+ fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
59
+ elif p < 0:
60
+ fbank = fbank[:, 0:target_length]
61
+ # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
62
+ # channel image
63
+ fbank = fbank.unsqueeze(0)
64
+ return fbank
65
+
66
+
67
+ def get_clip_timepoints(clip_sampler, duration):
68
+ # Read out all clips in this video
69
+ all_clips_timepoints = []
70
+ is_last_clip = False
71
+ end = 0.0
72
+ while not is_last_clip:
73
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
74
+ all_clips_timepoints.append((start, end))
75
+ return all_clips_timepoints
76
+
77
+
78
+ def load_and_transform_vision_data(image_paths, device):
79
+ if image_paths is None:
80
+ return None
81
+
82
+ image_ouputs = []
83
+ for image_path in image_paths:
84
+ data_transform = transforms.Compose(
85
+ [
86
+ transforms.Resize(
87
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
88
+ ),
89
+ transforms.CenterCrop(224),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize(
92
+ mean=(0.48145466, 0.4578275, 0.40821073),
93
+ std=(0.26862954, 0.26130258, 0.27577711),
94
+ ),
95
+ ]
96
+ )
97
+ with open(image_path, "rb") as fopen:
98
+ image = Image.open(fopen).convert("RGB")
99
+
100
+ image = data_transform(image).to(device)
101
+ image_ouputs.append(image)
102
+ return torch.stack(image_ouputs, dim=0)
103
+
104
+
105
+ def load_and_transform_text(text, device):
106
+ if text is None:
107
+ return None
108
+ tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
109
+ tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
110
+ tokens = torch.cat(tokens, dim=0)
111
+ return tokens
112
+
113
+
114
+ def load_and_transform_audio_data(
115
+ audio_paths,
116
+ device,
117
+ num_mel_bins=128,
118
+ target_length=204,
119
+ sample_rate=16000,
120
+ clip_duration=2,
121
+ clips_per_video=3,
122
+ mean=-4.268,
123
+ std=9.138,
124
+ ):
125
+ if audio_paths is None:
126
+ return None
127
+
128
+ audio_outputs = []
129
+ clip_sampler = ConstantClipsPerVideoSampler(
130
+ clip_duration=clip_duration, clips_per_video=clips_per_video
131
+ )
132
+
133
+ for audio_path in audio_paths:
134
+ waveform, sr = torchaudio.load(audio_path)
135
+ if sample_rate != sr:
136
+ waveform = torchaudio.functional.resample(
137
+ waveform, orig_freq=sr, new_freq=sample_rate
138
+ )
139
+ all_clips_timepoints = get_clip_timepoints(
140
+ clip_sampler, waveform.size(1) / sample_rate
141
+ )
142
+ all_clips = []
143
+ for clip_timepoints in all_clips_timepoints:
144
+ waveform_clip = waveform[
145
+ :,
146
+ int(clip_timepoints[0] * sample_rate) : int(
147
+ clip_timepoints[1] * sample_rate
148
+ ),
149
+ ]
150
+ waveform_melspec = waveform2melspec(
151
+ waveform_clip, sample_rate, num_mel_bins, target_length
152
+ )
153
+ all_clips.append(waveform_melspec)
154
+
155
+ normalize = transforms.Normalize(mean=mean, std=std)
156
+ all_clips = [normalize(ac).to(device) for ac in all_clips]
157
+
158
+ all_clips = torch.stack(all_clips, dim=0)
159
+ audio_outputs.append(all_clips)
160
+
161
+ return torch.stack(audio_outputs, dim=0)
162
+
163
+ def get_clip_timepoints(clip_sampler, duration):
164
+ # Read out all clips in this video
165
+ all_clips_timepoints = []
166
+ is_last_clip = False
167
+ end = 0.0
168
+ while not is_last_clip:
169
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
170
+ all_clips_timepoints.append((start, end))
171
+ return all_clips_timepoints
172
+
173
+
174
+ def crop_boxes(boxes, x_offset, y_offset):
175
+ """
176
+ Perform crop on the bounding boxes given the offsets.
177
+ Args:
178
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
179
+ is `num boxes` x 4.
180
+ x_offset (int): cropping offset in the x axis.
181
+ y_offset (int): cropping offset in the y axis.
182
+ Returns:
183
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
184
+ `num boxes` x 4.
185
+ """
186
+ cropped_boxes = boxes.copy()
187
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
188
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
189
+
190
+ return cropped_boxes
191
+
192
+
193
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
194
+ """
195
+ Perform uniform spatial sampling on the images and corresponding boxes.
196
+ Args:
197
+ images (tensor): images to perform uniform crop. The dimension is
198
+ `num frames` x `channel` x `height` x `width`.
199
+ size (int): size of height and width to crop the images.
200
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
201
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
202
+ crop if height is larger than width.
203
+ boxes (ndarray or None): optional. Corresponding boxes to images.
204
+ Dimension is `num boxes` x 4.
205
+ scale_size (int): optional. If not None, resize the images to scale_size before
206
+ performing any crop.
207
+ Returns:
208
+ cropped (tensor): images with dimension of
209
+ `num frames` x `channel` x `size` x `size`.
210
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
211
+ `num boxes` x 4.
212
+ """
213
+ assert spatial_idx in [0, 1, 2]
214
+ ndim = len(images.shape)
215
+ if ndim == 3:
216
+ images = images.unsqueeze(0)
217
+ height = images.shape[2]
218
+ width = images.shape[3]
219
+
220
+ if scale_size is not None:
221
+ if width <= height:
222
+ width, height = scale_size, int(height / width * scale_size)
223
+ else:
224
+ width, height = int(width / height * scale_size), scale_size
225
+ images = torch.nn.functional.interpolate(
226
+ images,
227
+ size=(height, width),
228
+ mode="bilinear",
229
+ align_corners=False,
230
+ )
231
+
232
+ y_offset = int(math.ceil((height - size) / 2))
233
+ x_offset = int(math.ceil((width - size) / 2))
234
+
235
+ if height > width:
236
+ if spatial_idx == 0:
237
+ y_offset = 0
238
+ elif spatial_idx == 2:
239
+ y_offset = height - size
240
+ else:
241
+ if spatial_idx == 0:
242
+ x_offset = 0
243
+ elif spatial_idx == 2:
244
+ x_offset = width - size
245
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
246
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
247
+ if ndim == 3:
248
+ cropped = cropped.squeeze(0)
249
+ return cropped, cropped_boxes
250
+
251
+
252
+ class SpatialCrop(nn.Module):
253
+ """
254
+ Convert the video into 3 smaller clips spatially. Must be used after the
255
+ temporal crops to get spatial crops, and should be used with
256
+ -2 in the spatial crop at the slowfast augmentation stage (so full
257
+ frames are passed in here). Will return a larger list with the
258
+ 3x spatial crops as well.
259
+ """
260
+
261
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
262
+ super().__init__()
263
+ self.crop_size = crop_size
264
+ if num_crops == 3:
265
+ self.crops_to_ext = [0, 1, 2]
266
+ self.flipped_crops_to_ext = []
267
+ elif num_crops == 1:
268
+ self.crops_to_ext = [1]
269
+ self.flipped_crops_to_ext = []
270
+ else:
271
+ raise NotImplementedError("Nothing else supported yet")
272
+
273
+ def forward(self, videos):
274
+ """
275
+ Args:
276
+ videos: A list of C, T, H, W videos.
277
+ Returns:
278
+ videos: A list with 3x the number of elements. Each video converted
279
+ to C, T, H', W' by spatial cropping.
280
+ """
281
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
282
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
283
+ res = []
284
+ for video in videos:
285
+ for spatial_idx in self.crops_to_ext:
286
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
287
+ if not self.flipped_crops_to_ext:
288
+ continue
289
+ flipped_video = transforms.functional.hflip(video)
290
+ for spatial_idx in self.flipped_crops_to_ext:
291
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
292
+ return res
293
+
294
+
295
+ def load_and_transform_video_data(
296
+ video_paths,
297
+ device,
298
+ clip_duration=2,
299
+ clips_per_video=5,
300
+ sample_rate=16000,
301
+ ):
302
+ if video_paths is None:
303
+ return None
304
+
305
+ video_outputs = []
306
+ video_transform = transforms.Compose(
307
+ [
308
+ pv_transforms.ShortSideScale(224),
309
+ NormalizeVideo(
310
+ mean=(0.48145466, 0.4578275, 0.40821073),
311
+ std=(0.26862954, 0.26130258, 0.27577711),
312
+ ),
313
+ ]
314
+ )
315
+
316
+ clip_sampler = ConstantClipsPerVideoSampler(
317
+ clip_duration=clip_duration, clips_per_video=clips_per_video
318
+ )
319
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
320
+
321
+ for video_path in video_paths:
322
+ video = EncodedVideo.from_path(
323
+ video_path,
324
+ decoder="decord",
325
+ decode_audio=False,
326
+ **{"sample_rate": sample_rate},
327
+ )
328
+
329
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
330
+
331
+ all_video = []
332
+ for clip_timepoints in all_clips_timepoints:
333
+ # Read the clip, get frames
334
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
335
+ if clip is None:
336
+ raise ValueError("No clip found")
337
+ video_clip = frame_sampler(clip["video"])
338
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
339
+
340
+ all_video.append(video_clip)
341
+
342
+ all_video = [video_transform(clip) for clip in all_video]
343
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
344
+
345
+ all_video = torch.stack(all_video, dim=0)
346
+ video_outputs.append(all_video)
347
+
348
+ return torch.stack(video_outputs, dim=0).to(device)
imagebind/models/__init__.py ADDED
File without changes
imagebind/models/helpers.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+
10
+ import einops
11
+ import numpy as np
12
+ import torch
13
+
14
+ import torch.nn as nn
15
+
16
+
17
+ class Normalize(nn.Module):
18
+ def __init__(self, dim: int) -> None:
19
+ super().__init__()
20
+ self.dim = dim
21
+
22
+ def forward(self, x):
23
+ return torch.nn.functional.normalize(x, dim=self.dim, p=2)
24
+
25
+
26
+ class LearnableLogitScaling(nn.Module):
27
+ def __init__(
28
+ self,
29
+ logit_scale_init: float = 1 / 0.07,
30
+ learnable: bool = True,
31
+ max_logit_scale: float = 100,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.max_logit_scale = max_logit_scale
35
+ self.logit_scale_init = logit_scale_init
36
+ self.learnable = learnable
37
+ log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
38
+ if learnable:
39
+ self.log_logit_scale = nn.Parameter(log_logit_scale)
40
+ else:
41
+ self.register_buffer("log_logit_scale", log_logit_scale)
42
+
43
+ def forward(self, x):
44
+ return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
45
+
46
+ def extra_repr(self):
47
+ st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
48
+ return st
49
+
50
+
51
+ class EinOpsRearrange(nn.Module):
52
+ def __init__(self, rearrange_expr: str, **kwargs) -> None:
53
+ super().__init__()
54
+ self.rearrange_expr = rearrange_expr
55
+ self.kwargs = kwargs
56
+
57
+ def forward(self, x):
58
+ assert isinstance(x, torch.Tensor)
59
+ return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
60
+
61
+
62
+ class VerboseNNModule(nn.Module):
63
+ """
64
+ Wrapper around nn.Module that prints registered buffers and parameter names.
65
+ """
66
+
67
+ @staticmethod
68
+ def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
69
+ st = (
70
+ "("
71
+ + name
72
+ + "): "
73
+ + "tensor("
74
+ + str(tuple(tensor[1].shape))
75
+ + ", requires_grad="
76
+ + str(tensor[1].requires_grad)
77
+ + ")\n"
78
+ )
79
+ return st
80
+
81
+ def extra_repr(self) -> str:
82
+ named_modules = set()
83
+ for p in self.named_modules():
84
+ named_modules.update([p[0]])
85
+ named_modules = list(named_modules)
86
+
87
+ string_repr = ""
88
+ for p in self.named_parameters():
89
+ name = p[0].split(".")[0]
90
+ if name not in named_modules:
91
+ string_repr += self.get_readable_tensor_repr(name, p)
92
+
93
+ for p in self.named_buffers():
94
+ name = p[0].split(".")[0]
95
+ string_repr += self.get_readable_tensor_repr(name, p)
96
+
97
+ return string_repr
98
+
99
+
100
+ def cast_if_src_dtype(
101
+ tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
102
+ ):
103
+ updated = False
104
+ if tensor.dtype == src_dtype:
105
+ tensor = tensor.to(dtype=tgt_dtype)
106
+ updated = True
107
+ return tensor, updated
108
+
109
+
110
+ class QuickGELU(nn.Module):
111
+ # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
112
+ def forward(self, x: torch.Tensor):
113
+ return x * torch.sigmoid(1.702 * x)
114
+
115
+
116
+ class SelectElement(nn.Module):
117
+ def __init__(self, index) -> None:
118
+ super().__init__()
119
+ self.index = index
120
+
121
+ def forward(self, x):
122
+ assert x.ndim >= 3
123
+ return x[:, self.index, ...]
124
+
125
+
126
+ class SelectEOSAndProject(nn.Module):
127
+ """
128
+ Text Pooling used in OpenCLIP
129
+ """
130
+
131
+ def __init__(self, proj: nn.Module) -> None:
132
+ super().__init__()
133
+ self.proj = proj
134
+
135
+ def forward(self, x, seq_len):
136
+ assert x.ndim == 3
137
+ # x is of shape B x L x D
138
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
139
+ x = x[torch.arange(x.shape[0]), seq_len]
140
+ x = self.proj(x)
141
+ return x
imagebind/models/imagebind_model.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import os
10
+ import urllib
11
+ from functools import partial
12
+ from types import SimpleNamespace
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from .helpers import (
18
+ EinOpsRearrange,
19
+ LearnableLogitScaling,
20
+ Normalize,
21
+ SelectElement,
22
+ SelectEOSAndProject,
23
+ )
24
+ from .multimodal_preprocessors import (
25
+ AudioPreprocessor,
26
+ IMUPreprocessor,
27
+ PadIm2Video,
28
+ PatchEmbedGeneric,
29
+ RGBDTPreprocessor,
30
+ SpatioTemporalPosEmbeddingHelper,
31
+ TextPreprocessor,
32
+ ThermalPreprocessor,
33
+ )
34
+
35
+ from .transformer import MultiheadAttention, SimpleTransformer
36
+
37
+
38
+ ModalityType = SimpleNamespace(
39
+ VISION="vision",
40
+ TEXT="text",
41
+ AUDIO="audio",
42
+ THERMAL="thermal",
43
+ DEPTH="depth",
44
+ IMU="imu",
45
+ )
46
+
47
+
48
class ImageBindModel(nn.Module):
    """Multi-modal encoder with one processing stack per modality.

    For every key in ``ModalityType`` the model holds four stages, each stored
    in an ``nn.ModuleDict`` indexed by the modality key:

    * ``modality_preprocessors`` - turn raw input into trunk/head inputs,
    * ``modality_trunks`` - a ``SimpleTransformer`` encoder,
    * ``modality_heads`` - LayerNorm + token selection + linear projection to
      ``out_embed_dim``,
    * ``modality_postprocessors`` - ``Normalize`` and, for every modality
      except vision, a ``LearnableLogitScaling``.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build the four per-modality stages from the given hyper-parameters.

        ``*_embed_dim`` is the trunk width of a modality, ``*_num_blocks`` and
        ``*_num_heads`` size its transformer, ``*_drop_path`` is forwarded as
        the trunk's ``drop_path_rate`` and ``out_embed_dim`` is the width of
        the shared output embedding.
        """
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
            # Previously accepted by __init__ but silently ignored.
            imu_kernel_size=imu_kernel_size,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
        imu_kernel_size=8,
    ):
        """Return an ``nn.ModuleDict`` of input preprocessors keyed by modality.

        ``imu_kernel_size`` is the IMU patch length along the time axis; it
        defaults to 8, the value that used to be hard-coded here.
        """
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        depth_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=depth_kernel_size,
                    in_channels=1,
                    out_channels=depth_embed_dim,
                    stride=depth_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        )

        depth_preprocessor = RGBDTPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=None,
            depth_stem=depth_stem,
        )

        thermal_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=thermal_kernel_size,
                    in_channels=1,
                    out_channels=thermal_embed_dim,
                    stride=thermal_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        )
        thermal_preprocessor = ThermalPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            thermal_stem=thermal_stem,
        )

        # IMUPreprocessor unfolds the last axis into windows of
        # `imu_kernel_size` samples and flattens (channels, window) into one
        # feature vector, so the linear stem sees channels * kernel_size
        # features (6 * 8 = 48 with the defaults).
        imu_channels = 6
        imu_stem = PatchEmbedGeneric(
            [
                nn.Linear(
                    in_features=imu_channels * imu_kernel_size,
                    out_features=imu_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        )

        imu_preprocessor = IMUPreprocessor(
            img_size=[imu_channels, 2000],
            num_cls_tokens=1,
            kernel_size=imu_kernel_size,
            embed_dim=imu_embed_dim,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            imu_stem=imu_stem,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            ModalityType.DEPTH: depth_preprocessor,
            ModalityType.THERMAL: thermal_preprocessor,
            ModalityType.IMU: imu_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Return an ``nn.ModuleDict`` of transformer trunks keyed by modality."""

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # Tokens enter as (batch, length, dim); the trunk is wrapped with
            # rearranges to (length, batch, dim) and back.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=depth_drop_path,
        )
        modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=thermal_drop_path,
        )
        modality_trunks[ModalityType.IMU] = instantiate_trunk(
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=imu_drop_path,
        )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Return an ``nn.ModuleDict`` of projection heads keyed by modality.

        Every head ends in a bias-free linear layer mapping the trunk width to
        ``out_embed_dim``.
        """
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.DEPTH] = nn.Sequential(
            nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.THERMAL] = nn.Sequential(
            nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.IMU] = nn.Sequential(
            nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Dropout(p=0.5),
            nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Return an ``nn.ModuleDict`` of output postprocessors keyed by modality.

        ``out_embed_dim`` is accepted for signature symmetry but is not used.
        """
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        )
        modality_postprocessors[ModalityType.IMU] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )

        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """Encode every modality present in ``inputs``.

        Args:
            inputs: mapping from modality key to an input tensor. Entries
                whose value is ``None`` are skipped.

        Returns:
            dict mapping each encoded modality key to its output of the
            postprocessor. Inputs with 5 or more dimensions are treated as
            ``(B, S, ...)`` stacks of ``S`` clips: they are flattened to
            ``B * S`` samples for encoding and the ``S`` results are averaged.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Skip absent modalities before touching tensor attributes; the
            # previous ordering read `.ndim` on None and crashed.
            if modality_value is None:
                continue

            reduce_list = (
                modality_value.ndim >= 5
            )  # Audio and Video inputs consist of multiple clips
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            # Each preprocessor takes its input as a keyword argument named
            # after the modality key.
            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value = self.modality_trunks[modality_key](**trunk_inputs)
            modality_value = self.modality_heads[modality_key](
                modality_value, **head_inputs
            )
            modality_value = self.modality_postprocessors[modality_key](
                modality_value
            )

            if reduce_list:
                # Undo the clip flattening and average over the clips.
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value

        return outputs
488
+
489
+
490
def imagebind_huge(pretrained=False):
    """Build the "huge" ``ImageBindModel`` configuration.

    Args:
        pretrained: when true, load weights from
            ``checkpoints/imagebind_huge.pth`` (relative to the current
            working directory), downloading the file first if it is missing.

    Returns:
        The constructed ``ImageBindModel``.
    """
    model = ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )

    if pretrained:
        checkpoint_path = "checkpoints/imagebind_huge.pth"
        if not os.path.exists(checkpoint_path):
            # The message now names the path actually written to (it used to
            # say ".checkpoints/...").
            print(f"Downloading imagebind weights to {checkpoint_path} ...")
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.hub.download_url_to_file(
                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
                checkpoint_path,
                progress=True,
            )

        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies the values into
        # the model's own parameters.
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

    return model
imagebind/models/multimodal_preprocessors.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import gzip
9
+ import html
10
+ import io
11
+ import math
12
+ from functools import lru_cache
13
+ from typing import Callable, List, Optional
14
+
15
+ import ftfy
16
+
17
+ import numpy as np
18
+ import regex as re
19
+ import torch
20
+ import torch.nn as nn
21
+ from iopath.common.file_io import g_pathmgr
22
+ from timm.models.layers import trunc_normal_
23
+
24
+ from .helpers import cast_if_src_dtype, VerboseNNModule
25
+
26
+
27
def get_sinusoid_encoding_table(n_position, d_hid):
    """Return a fixed sinusoidal position-encoding table.

    Entry ``(pos, j)`` is ``sin`` (even ``j``) or ``cos`` (odd ``j``) of
    ``pos / 10000 ** (2 * (j // 2) / d_hid)``.

    Args:
        n_position: number of positions (rows).
        d_hid: encoding width (columns).

    Returns:
        A float32 tensor of shape ``(1, n_position, d_hid)``. An empty table
        is returned when ``n_position`` is 0.
    """
    # Vectorized over positions and channels instead of nested Python loops.
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    sinusoid_table = positions / np.power(10000, exponents)[None, :]

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.from_numpy(sinusoid_table).float().unsqueeze(0)
44
+
45
+
46
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
    """Bicubically resize a square grid of position embeddings.

    ``pos_embed`` is indexed as ``(1, N, dim)`` and is viewed as a
    ``sqrt(N) x sqrt(N)`` grid. It is returned unchanged when ``N`` already
    equals ``target_spatial_size``; otherwise the grid is scaled by
    ``sqrt(target_spatial_size / N)`` per side and flattened back to
    ``(1, -1, dim)``.
    """
    num_tokens = pos_embed.shape[1]
    if num_tokens == target_spatial_size:
        return pos_embed

    channels = pos_embed.shape[-1]
    side = int(math.sqrt(num_tokens))

    # nn.functional.interpolate doesn't work with bfloat16, so go through
    # float32 and cast back afterwards if a cast actually happened.
    pos_embed, was_cast = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
    grid = pos_embed.reshape(1, side, side, channels).permute(0, 3, 1, 2)
    grid = nn.functional.interpolate(
        grid,
        scale_factor=math.sqrt(target_spatial_size / num_tokens),
        mode="bicubic",
    )
    if was_cast:
        grid, _ = cast_if_src_dtype(grid, torch.float32, torch.bfloat16)

    return grid.permute(0, 2, 3, 1).view(1, -1, channels)
64
+
65
+
66
def interpolate_pos_encoding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape=None,
    first_patch_idx=1,
):
    """Adapt position embeddings to ``npatch_per_img`` patch tokens.

    The first ``first_patch_idx`` entries of ``pos_embed`` (0 or 1 CLS slots)
    are kept as they are. The remaining patch embeddings are resized with
    ``interpolate_pos_encoding_2d``; when the layout has more than one frame,
    only the embeddings of frame 0 are used as the source. ``pos_embed`` is
    returned untouched if it already has ``npatch_per_img`` patch entries.
    """
    assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
    stored_patches = pos_embed.shape[1] - first_patch_idx  # excludes the CLS slot
    if npatch_per_img == stored_patches:
        return pos_embed

    assert (
        patches_layout[-1] == patches_layout[-2]
    ), "Interpolation of pos embed not supported for non-square layouts"

    cls_part = pos_embed[:, :first_patch_idx]
    patch_part = pos_embed[:, first_patch_idx:]

    if input_shape is not None and patches_layout[0] != 1:
        if not (patches_layout[0] > 1):
            raise ValueError("This type of interpolation isn't implemented")
        # The embedding has a temporal component; only spatial (2D)
        # interpolation is supported, using the zeroth frame as the source.
        assert len(input_shape) == 4, "temporal interpolation not supported"
        frames = patches_layout[0]
        tokens_per_frame = patches_layout[1] * patches_layout[2]
        per_frame = patch_part.view(1, frames, tokens_per_frame, -1)
        patch_part = per_frame[0, 0, ...].unsqueeze(0)

    patch_part = interpolate_pos_encoding_2d(npatch_per_img, patch_part)
    return torch.cat((cls_part, patch_part), dim=1)
103
+
104
+
105
def _get_pos_embedding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape,
    first_patch_idx=1,
):
    """Thin wrapper that forwards all arguments to ``interpolate_pos_encoding``."""
    return interpolate_pos_encoding(
        npatch_per_img,
        pos_embed,
        patches_layout,
        input_shape=input_shape,
        first_patch_idx=first_patch_idx,
    )
120
+
121
+
122
class PatchEmbedGeneric(nn.Module):
    """
    PatchEmbed from Hydra

    Wraps a projection stem (one module, or several chained in an
    ``nn.Sequential``) and an optional normalization layer applied to the
    flattened tokens.
    """

    def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
        super().__init__()

        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer

    def get_patch_layout(self, img_size):
        """Probe the stem with a zero input to discover its output layout.

        Args:
            img_size: per-sample input shape (without the batch dimension) as
                any sequence of ints - list or tuple.

        Returns:
            ``(patches_layout, num_patches, embed_dim)`` where
            ``patches_layout`` is the stem output shape after the channel
            dimension, ``num_patches`` is its product and ``embed_dim`` is the
            size of output dimension 1.
        """
        with torch.no_grad():
            # Unpack rather than concatenate so tuples work too; `[1] + (..)`
            # raised TypeError for tuple sizes.
            dummy_img = torch.zeros([1, *img_size])
            dummy_out = self.proj(dummy_img)
        embed_dim = dummy_out.shape[1]
        patches_layout = tuple(dummy_out.shape[2:])
        num_patches = np.prod(patches_layout)
        return patches_layout, num_patches, embed_dim

    def forward(self, x):
        """Project ``x`` and flatten everything after the channel dim into tokens."""
        x = self.proj(x)
        # B C (T) H W -> B (T)HW C
        x = x.flatten(2).transpose(1, 2)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x
159
+
160
+
161
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
    """Owns the position embedding for a patch layout plus optional CLS tokens.

    The embedding covers ``num_cls_tokens + num_patches`` positions. It is a
    trainable parameter when ``learnable`` is true, otherwise a sinusoidal
    table registered as a buffer.
    """

    def __init__(
        self,
        patches_layout: List,
        num_patches: int,
        num_cls_tokens: int,
        embed_dim: int,
        learnable: bool,
    ) -> None:
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.patches_layout = patches_layout
        self.num_patches = num_patches
        self.num_tokens = num_cls_tokens + num_patches
        self.learnable = learnable

        if not self.learnable:
            # Fixed encoding: stored as a buffer so it follows the module
            # across devices without being trained.
            fixed_table = get_sinusoid_encoding_table(self.num_tokens, embed_dim)
            self.register_buffer("pos_embed", fixed_table)
            return

        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
        trunc_normal_(self.pos_embed, std=0.02)

    def get_pos_embedding(self, vision_input, all_vision_tokens):
        """Return the position embedding matched to the given token count.

        The number of patch tokens is taken from ``all_vision_tokens`` (its
        size along dim 1 minus the CLS tokens); ``vision_input`` only
        contributes its shape.
        """
        patch_token_count = all_vision_tokens.size(1) - self.num_cls_tokens
        return _get_pos_embedding(
            patch_token_count,
            pos_embed=self.pos_embed,
            patches_layout=self.patches_layout,
            input_shape=vision_input.shape,
            first_patch_idx=self.num_cls_tokens,
        )
194
+
195
+
196
class RGBDTPreprocessor(VerboseNNModule):
    """Tokenize image-like inputs with a patch stem, CLS tokens and pos embeds.

    One of ``rgbt_stem`` / ``depth_stem`` may be ``None``; the patch layout is
    probed from ``rgbt_stem`` when it is given, otherwise from ``depth_stem``.
    """

    def __init__(
        self,
        rgbt_stem: PatchEmbedGeneric,
        depth_stem: PatchEmbedGeneric,
        img_size: List = (3, 224, 224),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        use_type_embed: bool = False,
        init_param_style: str = "openclip",
    ) -> None:
        """
        Args:
            rgbt_stem: stem used for the ``vision`` input (may be ``None``).
            depth_stem: stem used for the ``depth`` input (may be ``None``).
            img_size: per-sample input shape used to probe the stem.
            num_cls_tokens: number of learnable CLS tokens prepended.
            pos_embed_fn: factory called with ``patches_layout``,
                ``num_cls_tokens``, ``num_patches`` and ``embed_dim``; no
                position embedding is used when ``None``.
            use_type_embed: add a learnable type embedding to every token.
            init_param_style: ``"openclip"`` or ``"vit"``.
        """
        super().__init__()
        stem = rgbt_stem if rgbt_stem is not None else depth_stem
        (
            self.patches_layout,
            self.num_patches,
            self.embed_dim,
            # list(...) so the tuple default also works with stems that
            # build the probe shape by list concatenation.
        ) = stem.get_patch_layout(list(img_size))
        self.rgbt_stem = rgbt_stem
        self.depth_stem = depth_stem
        self.use_pos_embed = pos_embed_fn is not None
        self.use_type_embed = use_type_embed
        self.num_cls_tokens = num_cls_tokens

        if self.use_pos_embed:
            self.pos_embedding_helper = pos_embed_fn(
                patches_layout=self.patches_layout,
                num_cls_tokens=num_cls_tokens,
                num_patches=self.num_patches,
                embed_dim=self.embed_dim,
            )
        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )
        if self.use_type_embed:
            self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialise CLS / position / type embeddings.

        Raises:
            ValueError: if ``init_param_style`` is not ``"openclip"`` or
                ``"vit"``.
        """
        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.use_pos_embed:
                # NOTE(review): this also overwrites a non-learnable
                # (sinusoidal) table if the helper holds one - confirm that
                # is intended.
                nn.init.normal_(self.pos_embedding_helper.pos_embed)
                self.pos_embedding_helper.pos_embed *= scale

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when at least one CLS token was requested.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

        if self.use_type_embed:
            nn.init.normal_(self.type_embed)

    def tokenize_input_and_cls_pos(self, input, stem, mask):
        """Run ``stem`` on ``input`` and add CLS, position and type embeddings.

        ``mask`` is accepted for interface compatibility and is not used.
        """
        # tokens is of shape B x L x D
        tokens = stem(input)
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
            tokens = tokens + pos_embed
        if self.use_type_embed:
            tokens = tokens + self.type_embed.expand(B, -1, -1)
        return tokens

    def forward(self, vision=None, depth=None, patch_mask=None):
        """Tokenize ``vision`` and/or ``depth``.

        Returns:
            ``{"trunk": {"tokens": ...}, "head": {}}``. When both inputs are
            given their token tensors are summed.

        Raises:
            NotImplementedError: if ``patch_mask`` is given.
            ValueError: if neither ``vision`` nor ``depth`` is given.
        """
        if patch_mask is not None:
            raise NotImplementedError()

        if vision is None and depth is None:
            # Previously surfaced as an UnboundLocalError further down.
            raise ValueError("At least one of `vision` or `depth` must be provided")

        if vision is not None:
            vision_tokens = self.tokenize_input_and_cls_pos(
                vision, self.rgbt_stem, patch_mask
            )

        if depth is not None:
            depth_tokens = self.tokenize_input_and_cls_pos(
                depth, self.depth_stem, patch_mask
            )

        # aggregate tokens
        if vision is not None and depth is not None:
            final_tokens = vision_tokens + depth_tokens
        else:
            final_tokens = vision_tokens if vision is not None else depth_tokens
        return_dict = {
            "trunk": {
                "tokens": final_tokens,
            },
            "head": {},
        }
        return return_dict
300
+
301
+
302
class AudioPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` specialisation that takes its input as ``audio``.

    The audio stem is installed in the parent's ``rgbt_stem`` slot and no
    depth stem is used.
    """

    def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(audio_stem, None, **kwargs)

    def forward(self, audio=None):
        # The parent routes its `vision` argument through rgbt_stem, which
        # holds the audio stem here.
        tokenized = super().forward(vision=audio)
        return tokenized
308
+
309
+
310
class ThermalPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` specialisation that takes its input as ``thermal``.

    The thermal stem is installed in the parent's ``rgbt_stem`` slot and no
    depth stem is used.
    """

    def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(thermal_stem, None, **kwargs)

    def forward(self, thermal=None):
        # The parent routes its `vision` argument through rgbt_stem, which
        # holds the thermal stem here.
        tokenized = super().forward(vision=thermal)
        return tokenized
316
+
317
+
318
def build_causal_attention_mask(context_length):
    """Return an additive causal attention mask.

    The result is a ``(context_length, context_length)`` float tensor that is
    ``-inf`` strictly above the main diagonal and ``0`` on and below it, so
    each position can attend to itself and earlier positions only.
    """
    # PyTorch attention masks are additive: start from all -inf, then let
    # triu_(1) zero out the diagonal and everything below it.
    blocked = torch.full((context_length, context_length), float("-inf"))
    return blocked.triu_(1)
325
+
326
+
327
class TextPreprocessor(VerboseNNModule):
    """Embed token ids and add learned position embeddings for the text trunk."""

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        embed_dim: int,
        causal_masking: bool,
        supply_seq_len_to_head: bool = True,
        num_cls_tokens: int = 0,
        init_param_style: str = "openclip",
    ) -> None:
        """
        Args:
            vocab_size: size of the token embedding table.
            context_length: number of token positions per sequence.
            embed_dim: embedding width.
            causal_masking: attach a causal attention mask to the trunk input.
            supply_seq_len_to_head: pass a per-sequence ``seq_len`` to the head.
            num_cls_tokens: number of CLS tokens to prepend (must be 0 when
                ``causal_masking`` is on).
            init_param_style: ``"openclip"`` or ``"vit"``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
        )
        self.causal_masking = causal_masking
        if self.causal_masking:
            mask = build_causal_attention_mask(self.context_length)
            # register the mask as a buffer so it can be moved to the right device
            self.register_buffer("mask", mask)

        self.supply_seq_len_to_head = supply_seq_len_to_head
        self.num_cls_tokens = num_cls_tokens
        self.embed_dim = embed_dim
        if num_cls_tokens > 0:
            assert self.causal_masking is False, "Masking + CLS token isn't implemented"
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style="openclip"):
        """Initialise token, position and (if present) CLS embeddings.

        Raises:
            ValueError: if ``init_param_style`` is not ``"openclip"`` or
                ``"vit"``.
        """
        # OpenCLIP style initialization
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when num_cls_tokens > 0 (the default is
            # 0), so guard the access instead of raising AttributeError.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def forward(self, text):
        """Embed ``text`` token ids.

        Returns:
            A dict with ``"trunk"`` (``tokens`` and, with causal masking,
            ``attn_mask``) and ``"head"`` (``seq_len`` when
            ``supply_seq_len_to_head`` is set, otherwise empty).
        """
        # text tokens are of shape B x L x D
        text_tokens = self.token_embedding(text)
        # concat CLS tokens if any
        if self.num_cls_tokens > 0:
            B = text_tokens.shape[0]
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
        text_tokens = text_tokens + self.pos_embed
        return_dict = {
            "trunk": {
                "tokens": text_tokens,
            },
            "head": {},
        }
        if self.supply_seq_len_to_head:
            # Index of the largest token id in each row.
            # NOTE(review): presumably the end-of-text position (CLIP-style
            # vocab where EOT has the highest id); it is not offset by
            # num_cls_tokens - confirm against the head.
            text_lengths = text.argmax(dim=-1)
            return_dict["head"] = {
                "seq_len": text_lengths,
            }
        if self.causal_masking:
            return_dict["trunk"].update({"attn_mask": self.mask})
        return return_dict
405
+
406
+
407
class Im2Video(nn.Module):
    """Convert an image into a trivial video."""

    def __init__(self, time_dim=2):
        super().__init__()
        self.time_dim = time_dim

    def forward(self, x):
        """Insert a singleton time axis into 4-D input; pass 5-D input through.

        Raises:
            ValueError: if ``x`` is neither 4-D nor 5-D.
        """
        if x.ndim == 5:
            # Already a video.
            return x
        if x.ndim != 4:
            raise ValueError(f"Dimension incorrect {x.shape}")
        # B, C, H, W -> B, C, T, H, W
        return x.unsqueeze(self.time_dim)
422
+
423
+
424
class PadIm2Video(Im2Video):
    """Turn an image into a video and pad a single frame out to ``ntimes`` frames.

    Inputs that already have more than one frame along ``time_dim`` are
    returned unchanged.
    """

    def __init__(self, ntimes, pad_type, time_dim=2):
        """
        Args:
            ntimes: number of frames after padding (must be positive).
            pad_type: ``"repeat"`` to copy the frame, ``"zero"`` to append
                all-zero frames.
            time_dim: index of the time axis in the 5-D tensor.
        """
        super().__init__(time_dim=time_dim)
        assert ntimes > 0
        assert pad_type in ["zero", "repeat"]
        self.ntimes = ntimes
        self.pad_type = pad_type

    def forward(self, x):
        x = super().forward(x)
        if x.shape[self.time_dim] == 1:
            if self.pad_type == "repeat":
                new_shape = [1] * len(x.shape)
                new_shape[self.time_dim] = self.ntimes
                x = x.repeat(new_shape)
            elif self.pad_type == "zero":
                padarg = [0, 0] * len(x.shape)
                # F.pad lists (left, right) pairs starting from the LAST
                # dimension, so the right-pad slot of axis `d` sits at
                # 2 * (ndim - 1 - d) + 1. The old index 2 * d + 1 was only
                # correct for the default time_dim=2 on 5-D input.
                time_axis = self.time_dim % x.ndim
                padarg[2 * (x.ndim - 1 - time_axis) + 1] = (
                    self.ntimes - x.shape[self.time_dim]
                )
                x = nn.functional.pad(x, padarg)
        return x
444
+
445
+
446
+ # Modified from github.com/openai/CLIP
447
@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE code.

    The BPE merges operate on unicode strings, so every one of the 256 byte
    values needs a character that is neither whitespace nor a control
    character (the bpe code barfs on those). Bytes that are already such
    characters map to themselves; every other byte is assigned the next unused
    code point starting at 256. This avoids spending a significant part of the
    vocabulary on raw unicode coverage and avoids UNKs.

    Returns a dict from byte value (int) to a one-character string.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    taken = set(self_mapped)

    byte_values = list(self_mapped)
    code_points = list(self_mapped)
    # Remaining bytes, in increasing order, get code points 256, 257, ...
    leftovers = [b for b in range(2**8) if b not in taken]
    for offset, b in enumerate(leftovers):
        byte_values.append(b)
        code_points.append(2**8 + offset)

    return dict(zip(byte_values, map(chr, code_points)))
472
+
473
+
474
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).

    Each pair is ``(word[i], word[i + 1])``. Words with fewer than two symbols
    - including the empty word, which used to raise ``IndexError`` - yield an
    empty set.
    """
    return set(zip(word, word[1:]))
484
+
485
+
486
def basic_clean(text):
    """Run ``ftfy.fix_text``, unescape HTML entities twice, then strip the ends."""
    repaired = ftfy.fix_text(text)
    # Two passes so doubly-escaped entities (e.g. "&amp;lt;") are resolved.
    unescaped = html.unescape(html.unescape(repaired))
    return unescaped.strip()
490
+
491
+
492
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip both ends."""
    return re.sub(r"\s+", " ", text).strip()
496
+
497
+
498
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer (modified from github.com/openai/CLIP).

    Calling the instance turns one string or a list of strings into a
    zero-padded ``torch.long`` tensor of token ids framed by the start- and
    end-of-text tokens.
    """

    def __init__(self, bpe_path: str, context_length=77):
        """
        Args:
            bpe_path: path (opened through ``g_pathmgr``) of the gzip-compressed
                BPE merges file.
            context_length: default number of token positions per sequence.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
        # Close the gzip reader deterministically instead of leaking it.
        with gzip.open(bpe_bytes) as gz:
            merges = gz.read().decode("utf-8").split("\n")
        # Drop the first line and keep 49152 - 256 - 2 merges.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.context_length = context_length

    def bpe(self, token):
        """Apply the learned merges to ``token``; return space-joined symbols.

        Results are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Merge the best-ranked (lowest rank) pair present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the tail and stop.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean, lower-case and BPE-encode ``text`` into a list of token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text; word-end markers become spaces."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(self, texts, context_length=None):
        """Tokenize ``texts`` into a zero-padded ``torch.long`` tensor.

        Args:
            texts: a string or a list of strings.
            context_length: row length; falls back to ``self.context_length``
                when falsy.

        Returns:
            A ``(len(texts), context_length)`` tensor, or a 1-D tensor of
            length ``context_length`` when there is exactly one text.
        """
        if not context_length:
            context_length = self.context_length

        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                # Keep the end-of-text marker when truncating; a plain slice
                # used to drop it silently for over-long inputs.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            result[i, : len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result
605
+
606
+
607
class IMUPreprocessor(VerboseNNModule):
    """Cut an IMU signal into fixed-length windows and embed each as a token."""

    def __init__(
        self,
        kernel_size: int,
        imu_stem: PatchEmbedGeneric,
        embed_dim: int,
        img_size: List = (6, 2000),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        init_param_style: str = "openclip",
    ) -> None:
        """
        Args:
            kernel_size: window length (and stride) along the last input axis.
            imu_stem: stem whose ``proj`` / ``norm_layer`` embed each window.
            embed_dim: token width.
            img_size: per-sample input shape; ``img_size[1] // kernel_size``
                is the number of windows the position embedding covers.
            num_cls_tokens: number of learnable CLS tokens prepended.
            pos_embed_fn: only checked for ``None``; when given, the module's
                own ``pos_embed`` parameter is added to the tokens.
            init_param_style: ``"openclip"`` or ``"vit"``.
        """
        super().__init__()
        self.imu_stem = imu_stem
        self.embed_dim = embed_dim
        self.use_pos_embed = pos_embed_fn is not None
        self.num_cls_tokens = num_cls_tokens
        self.kernel_size = kernel_size
        self.pos_embed = nn.Parameter(
            torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
        )

        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialise the position and (if present) CLS embeddings.

        Raises:
            ValueError: if ``init_param_style`` is not ``"openclip"`` or
                ``"vit"``.
        """
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when at least one CLS token was requested.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def tokenize_input_and_cls_pos(self, input, stem):
        """Embed windows with ``stem`` and add CLS and position embeddings."""
        # tokens is of shape B x L x D
        tokens = stem.proj(input)
        # The stem's norm layer is optional.
        if stem.norm_layer is not None:
            tokens = stem.norm_layer(tokens)
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            tokens = tokens + self.pos_embed
        return tokens

    def forward(self, imu):
        """Patchify ``imu`` and return ``{"trunk": {"tokens": ...}, "head": {}}``."""
        # Patchify: non-overlapping windows of `kernel_size` along the last
        # axis, then move the window index ahead of the channel axis and
        # flatten (channel, sample) into one feature vector per window.
        imu = imu.unfold(
            -1,
            self.kernel_size,
            self.kernel_size,
        ).permute(0, 2, 1, 3)
        imu = imu.reshape(imu.size(0), imu.size(1), -1)

        imu_tokens = self.tokenize_input_and_cls_pos(
            imu,
            self.imu_stem,
        )

        return_dict = {
            "trunk": {
                "tokens": imu_tokens,
            },
            "head": {},
        }
        return return_dict
imagebind/models/transformer.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # Code modified from
9
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
10
+ # https://github.com/facebookresearch/deit/blob/main/models.py
11
+ # and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
12
+
13
+
14
+ import copy
15
+ import fnmatch
16
+ import logging
17
+ from functools import partial
18
+ from typing import Callable, List
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.utils.checkpoint as checkpoint
23
+
24
+ from timm.models.layers import DropPath, trunc_normal_
25
+
26
+
27
class Attention(nn.Module):
    """Multi-head self-attention over a ``B x N x C`` token tensor.

    A single linear layer produces queries, keys and values; the attention
    output is passed through a second linear projection.

    Args:
        dim: Channel dimension ``C`` of the input and output.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        qk_scale: Scale applied to the q·k logits. A falsy value (``None``,
            ``0``) selects the default ``head_dim ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the projected output.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        # Fail early with a clear message: a non-divisible dim would floor
        # head_dim, skew the scale, and only blow up later in forward()'s
        # reshape.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version,
        # can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` (``B x N x C``); returns ``B x N x C``."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
70
+
71
+
72
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the hidden layer; a falsy value falls back
            to ``in_features``.
        out_features: Size of the output's last dimension; a falsy value
            falls back to ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability, applied after the activation and again
            after the second linear layer (one shared ``nn.Dropout``).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
96
+
97
+
98
class MultiheadAttention(nn.MultiheadAttention):
    """``nn.MultiheadAttention`` specialised to self-attention.

    ``x`` is used as query, key and value, and only the attention output is
    returned (attention weights are not computed).
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        attended, _ = super().forward(
            x, x, x, need_weights=False, attn_mask=attn_mask
        )
        return attended
101
+
102
+
103
class ViTAttention(Attention):
    """``Attention`` with the ``(x, attn_mask)`` call signature used by the blocks.

    Masking is not supported: ``attn_mask`` must be ``None``.
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        assert attn_mask is None
        attended = super().forward(x)
        return attended
107
+
108
+
109
class BlockWithMasking(nn.Module):
    """Pre-norm transformer block whose attention accepts a mask.

    Computes ``x + drop_path(attn(norm_1(x), attn_mask))`` followed by
    ``x + drop_path(mlp(norm_2(x)))``. When layer scale is enabled, each
    residual branch is multiplied by a learned gamma before being added.

    Args:
        dim: Token embedding dimension.
        attn_target: Zero-argument factory building the attention module.
            It must not be an ``nn.Module`` instance, otherwise the same
            module would be shared across blocks.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        act_layer: Activation factory for the MLP.
        norm_layer: Factory called with ``dim`` to build each norm layer.
        ffn_dropout_rate: Dropout probability inside the MLP.
        drop_path: Drop-path rate; ``0`` disables it (identity).
        layer_scale_type: ``None``, ``"per_channel"`` or ``"scalar"``.
        layer_scale_init_value: Initial value of the layer-scale gammas.
    """

    def __init__(
        self,
        dim: int,
        attn_target: Callable,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        ffn_dropout_rate: float = 0.0,
        drop_path: float = 0.0,
        layer_scale_type: str = None,
        layer_scale_init_value: float = 1e-4,
    ):
        super().__init__()

        assert not isinstance(
            attn_target, nn.Module
        ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
        self.attn = attn_target()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm_1 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(mlp_ratio * dim),
            act_layer=act_layer,
            drop=ffn_dropout_rate,
        )
        self.norm_2 = norm_layer(dim)
        self.layer_scale_type = layer_scale_type

        if self.layer_scale_type is None:
            return

        assert self.layer_scale_type in [
            "per_channel",
            "scalar",
        ], f"Found Layer scale type {self.layer_scale_type}"
        # "per_channel": one gamma per channel; "scalar": one gamma shared
        # by all channels.
        per_channel = self.layer_scale_type == "per_channel"
        gamma_shape = [1, 1, dim] if per_channel else [1, 1, 1]

        # Two gammas, one for each residual branch of the forward pass.
        self.layer_scale_gamma1 = nn.Parameter(
            torch.ones(size=gamma_shape) * layer_scale_init_value,
            requires_grad=True,
        )
        self.layer_scale_gamma2 = nn.Parameter(
            torch.ones(size=gamma_shape) * layer_scale_init_value,
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        """Apply the attention and MLP residual branches to ``x``."""
        use_layer_scale = self.layer_scale_type is not None

        attn_branch = self.drop_path(self.attn(self.norm_1(x), attn_mask))
        if use_layer_scale:
            attn_branch = attn_branch * self.layer_scale_gamma1
        x = x + attn_branch

        mlp_branch = self.drop_path(self.mlp(self.norm_2(x)))
        if use_layer_scale:
            mlp_branch = mlp_branch * self.layer_scale_gamma2
        return x + mlp_branch
175
+
176
+
177
# LayerNorm factory with eps fixed to 1e-6; used as the default `norm_layer`
# of SimpleTransformer.
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
178
+
179
+
180
class SimpleTransformer(nn.Module):
    """Stack of transformer blocks with optional pre/post layers."""

    def __init__(
        self,
        attn_target: Callable,
        embed_dim: int,
        num_blocks: int,
        block: Callable = BlockWithMasking,
        pre_transformer_layer: Callable = None,
        post_transformer_layer: Callable = None,
        drop_path_rate: float = 0.0,
        drop_path_type: str = "progressive",
        norm_layer: Callable = _LAYER_NORM,
        mlp_ratio: int = 4,
        ffn_dropout_rate: float = 0.0,
        layer_scale_type: str = None,  # from cait; possible values are None, "per_channel", "scalar"
        layer_scale_init_value: float = 1e-4,  # from cait; float
        weight_init_style: str = "jax",  # possible values jax or pytorch
    ):
        """
        Simple Transformer with the following features
        1. Supports masked attention
        2. Supports DropPath
        3. Supports LayerScale
        4. Supports Dropout in Attention and FFN
        5. Makes few assumptions about the input except that it is a Tensor

        Args:
            attn_target: Zero-argument factory for each block's attention module.
            embed_dim: Token embedding dimension passed to every block.
            num_blocks: Number of blocks in the stack.
            block: Block class/factory, called with the keyword arguments
                shown in the constructor body.
            pre_transformer_layer: Optional callable applied to the tokens
                before the blocks.
            post_transformer_layer: Optional callable applied to the tokens
                after the blocks.
            drop_path_rate: Maximum ("progressive") or constant ("uniform")
                drop-path rate.
            drop_path_type: "progressive" ramps the rate linearly from 0 to
                ``drop_path_rate`` across the blocks; "uniform" uses
                ``drop_path_rate`` for every block.
            norm_layer: Normalization layer factory handed to each block.
            mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
            ffn_dropout_rate: Dropout probability inside each block's MLP.
            layer_scale_type: None, "per_channel" or "scalar".
            layer_scale_init_value: Initial value of the layer-scale gammas.
            weight_init_style: "jax" (Xavier uniform) or "pytorch"
                (truncated normal, std=0.02) for ``nn.Linear`` weights; any
                other value leaves the Linear weights as constructed.

        Raises:
            ValueError: If ``drop_path_type`` is not recognised.
        """
        super().__init__()
        self.pre_transformer_layer = pre_transformer_layer
        if drop_path_type == "progressive":
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
        elif drop_path_type == "uniform":
            dpr = [drop_path_rate] * num_blocks
        else:
            raise ValueError(f"Unknown drop_path_type: {drop_path_type}")

        self.blocks = nn.Sequential(
            *[
                block(
                    dim=embed_dim,
                    attn_target=attn_target,
                    mlp_ratio=mlp_ratio,
                    ffn_dropout_rate=ffn_dropout_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    layer_scale_type=layer_scale_type,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for i in range(num_blocks)
            ]
        )
        self.post_transformer_layer = post_transformer_layer
        self.weight_init_style = weight_init_style
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one submodule; invoked for every module via ``self.apply``."""
        if isinstance(m, nn.Linear):
            if self.weight_init_style == "jax":
                # Based on MAE and official Jax ViT implementation
                torch.nn.init.xavier_uniform_(m.weight)
            elif self.weight_init_style == "pytorch":
                # PyTorch ViT uses trunc_normal_
                trunc_normal_(m.weight, std=0.02)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False has no
            # weight/bias (both are None), so guard before initializing.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: torch.Tensor = None,
        use_checkpoint: bool = False,
        checkpoint_every_n: int = 1,
        checkpoint_blk_ids: List[int] = None,
    ):
        """
        Inputs
        - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
        - attn_mask: mask of shape L x L
        - use_checkpoint: run the selected blocks under activation checkpointing
        - checkpoint_every_n: when no explicit ids are given, checkpoint every
          block whose index is a multiple of this value
        - checkpoint_blk_ids: explicit indices of the blocks to checkpoint

        Output
        - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
        """
        if self.pre_transformer_layer:
            tokens = self.pre_transformer_layer(tokens)
        if use_checkpoint and checkpoint_blk_ids is None:
            checkpoint_blk_ids = [
                blk_id
                for blk_id in range(len(self.blocks))
                if blk_id % checkpoint_every_n == 0
            ]
        if checkpoint_blk_ids:
            # Set membership keeps the per-block lookup below cheap.
            checkpoint_blk_ids = set(checkpoint_blk_ids)
        for blk_id, blk in enumerate(self.blocks):
            if use_checkpoint and blk_id in checkpoint_blk_ids:
                tokens = checkpoint.checkpoint(
                    blk, tokens, attn_mask, use_reentrant=False
                )
            else:
                tokens = blk(tokens, attn_mask=attn_mask)
        if self.post_transformer_layer:
            tokens = self.post_transformer_layer(tokens)
        return tokens
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers
2
+ torch==1.13
3
+ torchvision==0.14.0
4
+ torchaudio==0.13.0
5
+ pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d
6
+ timm==0.6.7
7
+ ftfy
8
+ regex
9
+ einops
10
+ fvcore
11
+ decord==0.6.0