Nick088 committed on
Commit
fa90792
1 Parent(s): 91f7fec

added audio sr files, adapted them to ZeroGPU, and optimized memory usage

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. README.md +1 -1
  2. app.py +46 -6
  3. audiosr/__init__.py +2 -0
  4. audiosr/__main__.py +123 -0
  5. audiosr/__pycache__/__init__.cpython-310.pyc +0 -0
  6. audiosr/__pycache__/lowpass.cpython-310.pyc +0 -0
  7. audiosr/__pycache__/pipeline.cpython-310.pyc +0 -0
  8. audiosr/__pycache__/utils.cpython-310.pyc +0 -0
  9. audiosr/clap/__init__.py +0 -0
  10. audiosr/clap/__pycache__/__init__.cpython-310.pyc +0 -0
  11. audiosr/clap/open_clip/__init__.py +25 -0
  12. audiosr/clap/open_clip/__pycache__/__init__.cpython-310.pyc +0 -0
  13. audiosr/clap/open_clip/__pycache__/factory.cpython-310.pyc +0 -0
  14. audiosr/clap/open_clip/__pycache__/feature_fusion.cpython-310.pyc +0 -0
  15. audiosr/clap/open_clip/__pycache__/htsat.cpython-310.pyc +0 -0
  16. audiosr/clap/open_clip/__pycache__/loss.cpython-310.pyc +0 -0
  17. audiosr/clap/open_clip/__pycache__/model.cpython-310.pyc +0 -0
  18. audiosr/clap/open_clip/__pycache__/openai.cpython-310.pyc +0 -0
  19. audiosr/clap/open_clip/__pycache__/pann_model.cpython-310.pyc +0 -0
  20. audiosr/clap/open_clip/__pycache__/pretrained.cpython-310.pyc +0 -0
  21. audiosr/clap/open_clip/__pycache__/tokenizer.cpython-310.pyc +0 -0
  22. audiosr/clap/open_clip/__pycache__/transform.cpython-310.pyc +0 -0
  23. audiosr/clap/open_clip/__pycache__/utils.cpython-310.pyc +0 -0
  24. audiosr/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  25. audiosr/clap/open_clip/factory.py +276 -0
  26. audiosr/clap/open_clip/feature_fusion.py +192 -0
  27. audiosr/clap/open_clip/htsat.py +1304 -0
  28. audiosr/clap/open_clip/loss.py +397 -0
  29. audiosr/clap/open_clip/model.py +931 -0
  30. audiosr/clap/open_clip/model_configs/HTSAT-base.json +23 -0
  31. audiosr/clap/open_clip/model_configs/HTSAT-large.json +23 -0
  32. audiosr/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +23 -0
  33. audiosr/clap/open_clip/model_configs/HTSAT-tiny.json +23 -0
  34. audiosr/clap/open_clip/model_configs/PANN-10.json +23 -0
  35. audiosr/clap/open_clip/model_configs/PANN-14-fmax-18k.json +23 -0
  36. audiosr/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +23 -0
  37. audiosr/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +23 -0
  38. audiosr/clap/open_clip/model_configs/PANN-14-win-1536.json +23 -0
  39. audiosr/clap/open_clip/model_configs/PANN-14.json +23 -0
  40. audiosr/clap/open_clip/model_configs/PANN-6.json +23 -0
  41. audiosr/clap/open_clip/model_configs/RN101-quickgelu.json +22 -0
  42. audiosr/clap/open_clip/model_configs/RN101.json +21 -0
  43. audiosr/clap/open_clip/model_configs/RN50-quickgelu.json +22 -0
  44. audiosr/clap/open_clip/model_configs/RN50.json +21 -0
  45. audiosr/clap/open_clip/model_configs/RN50x16.json +21 -0
  46. audiosr/clap/open_clip/model_configs/RN50x4.json +21 -0
  47. audiosr/clap/open_clip/model_configs/ViT-B-16.json +16 -0
  48. audiosr/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  49. audiosr/clap/open_clip/model_configs/ViT-B-32.json +16 -0
  50. audiosr/clap/open_clip/model_configs/ViT-L-14.json +16 -0
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: blue
  sdk: gradio
  sdk_version: 4.31.0
  app_file: app.py
- pinned: false
+ pinned: true
  license: mit
  short_description: Fixed fork of the original audio sr!
  ---
app.py CHANGED
@@ -1,10 +1,50 @@
- import os
- import subprocess
-
- os.system("git clone https://github.com/Nick088Official/versatile_audio_super_resolution")
-
- os.chdir("versatile_audio_super_resolution")
-
- subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)
-
- os.system("python app.py")
+ import gradio as gr
+ from audiosr import super_resolution, build_model
+ import torch
+ import gc  # free up memory
+ import spaces
+
+ @spaces.GPU(duration=300)
+ def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
+     audiosr = build_model(model_name=model_name)
+
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()  # empty the CUDA cache
+
+     gc.collect()
+
+     # set a random seed when the seed input value is 0
+     if seed == 0:
+         import random
+         seed = random.randint(1, 2**32 - 1)
+
+     waveform = super_resolution(
+         audiosr,
+         audio_file,
+         seed,
+         guidance_scale=guidance_scale,
+         ddim_steps=ddim_steps
+     )
+
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     gc.collect()
+
+     return (48000, waveform)
+
+ iface = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Audio(type="filepath", label="Input Audio"),
+         gr.Dropdown(["basic", "speech"], value="basic", label="Model"),
+         gr.Slider(1, 10, value=3.5, step=0.1, label="Guidance Scale", info="Guidance scale (large => better quality and relevance to the input; small => better diversity)"),
+         gr.Slider(1, 100, value=50, step=1, label="DDIM Steps", info="The number of sampling steps for DDIM"),
+         gr.Number(value=42, precision=0, label="Seed", info="Changing this value (any integer) gives a different generation result; use 0 for a random seed.")
+     ],
+     outputs=gr.Audio(type="numpy", label="Output Audio"),
+     title="AudioSR",
+     description="Audio Super Resolution with AudioSR"
+ )
+
+ iface.launch(share=False)
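A minimal local sanity check of the inference function above (a sketch, not part of the commit): it assumes the audiosr package is importable, a low-resolution file input.wav exists, and the soundfile package is installed; the @spaces.GPU decorator is designed to have no effect outside a ZeroGPU Space, so the call runs on whatever device build_model picks. The exact array shape returned by super_resolution may need squeezing before writing to disk.

# hypothetical local test of inference(); file names are placeholders
import numpy as np
import soundfile as sf

sample_rate, waveform = inference("input.wav", "basic", 3.5, 50, 42)
sf.write("input_AudioSR_48k.wav", np.squeeze(waveform), sample_rate)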
audiosr/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .utils import seed_everything, save_wave, get_time, get_duration, read_list
+ from .pipeline import *
audiosr/__main__.py ADDED
@@ -0,0 +1,123 @@
+ #!/usr/bin/python3
+ import os
+ import torch
+ import logging
+ from audiosr import super_resolution, build_model, save_wave, get_time, read_list
+ import argparse
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+ matplotlib_logger = logging.getLogger("matplotlib")
+ matplotlib_logger.setLevel(logging.WARNING)
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+     "-i",
+     "--input_audio_file",
+     type=str,
+     required=False,
+     help="Input audio file for audio super resolution",
+ )
+
+ parser.add_argument(
+     "-il",
+     "--input_file_list",
+     type=str,
+     required=False,
+     default="",
+     help="A file that lists all audio files on which to perform audio super resolution",
+ )
+
+ parser.add_argument(
+     "-s",
+     "--save_path",
+     type=str,
+     required=False,
+     help="The path to save model output",
+     default="./output",
+ )
+
+ parser.add_argument(
+     "--model_name",
+     type=str,
+     required=False,
+     help="The checkpoint to use",
+     default="basic",
+     choices=["basic", "speech"],
+ )
+
+ parser.add_argument(
+     "-d",
+     "--device",
+     type=str,
+     required=False,
+     help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
+     default="auto",
+ )
+
+ parser.add_argument(
+     "--ddim_steps",
+     type=int,
+     required=False,
+     default=50,
+     help="The number of sampling steps for DDIM",
+ )
+
+ parser.add_argument(
+     "-gs",
+     "--guidance_scale",
+     type=float,
+     required=False,
+     default=3.5,
+     help="Guidance scale (large => better quality and relevance to the input; small => better diversity)",
+ )
+
+ parser.add_argument(
+     "--seed",
+     type=int,
+     required=False,
+     default=42,
+     help="Changing this value (any integer) will lead to a different generation result.",
+ )
+
+ parser.add_argument(
+     "--suffix",
+     type=str,
+     required=False,
+     help="Suffix for the output file",
+     default="_AudioSR_Processed_48K",
+ )
+
+ args = parser.parse_args()
+ torch.set_float32_matmul_precision("high")
+ save_path = os.path.join(args.save_path, get_time())
+
+ assert args.input_file_list is not None or args.input_audio_file is not None, "Please provide either a list of audio files or a single audio file"
+
+ input_file = args.input_audio_file
+ random_seed = args.seed
+ sample_rate = 48000
+ latent_t_per_second = 12.8
+ guidance_scale = args.guidance_scale
+
+ os.makedirs(save_path, exist_ok=True)
+ audiosr = build_model(model_name=args.model_name, device=args.device)
+
+ if args.input_file_list:
+     print("Performing audio super resolution on the files listed in %s" % args.input_file_list)
+     files_todo = read_list(args.input_file_list)
+ else:
+     files_todo = [input_file]
+
+ for input_file in files_todo:
+     name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
+
+     waveform = super_resolution(
+         audiosr,
+         input_file,
+         seed=random_seed,
+         guidance_scale=guidance_scale,
+         ddim_steps=args.ddim_steps,
+         latent_t_per_second=latent_t_per_second,
+     )
+     save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)
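With the arguments defined above, typical invocations look like this (input paths are placeholders); each run writes its results into a timestamped subfolder of --save_path:

# single file, "basic" checkpoint, default DDIM steps and guidance scale
python3 -m audiosr -i input.wav -s ./output

# batch mode: one audio path per line in files.txt, speech checkpoint, custom sampling settings
python3 -m audiosr -il files.txt -s ./output --model_name speech --ddim_steps 100 -gs 3.5 --seed 1234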
audiosr/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (312 Bytes).

audiosr/__pycache__/lowpass.cpython-310.pyc ADDED
Binary file (5.22 kB).

audiosr/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (4.18 kB).

audiosr/__pycache__/utils.cpython-310.pyc ADDED
Binary file (13 kB).
audiosr/clap/__init__.py ADDED
File without changes
audiosr/clap/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (165 Bytes).
 
audiosr/clap/open_clip/__init__.py ADDED
@@ -0,0 +1,25 @@
+ from .factory import (
+     list_models,
+     create_model,
+     create_model_and_transforms,
+     add_model_config,
+ )
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
+ from .model import (
+     CLAP,
+     CLAPTextCfg,
+     CLAPVisionCfg,
+     CLAPAudioCfp,
+     convert_weights_to_fp16,
+     trace_model,
+ )
+ from .openai import load_openai_model, list_openai_models
+ from .pretrained import (
+     list_pretrained,
+     list_pretrained_tag_models,
+     list_pretrained_model_tags,
+     get_pretrained_url,
+     download_pretrained,
+ )
+ from .tokenizer import SimpleTokenizer, tokenize
+ from .transform import image_transform
audiosr/clap/open_clip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (971 Bytes).

audiosr/clap/open_clip/__pycache__/factory.cpython-310.pyc ADDED
Binary file (6.65 kB).

audiosr/clap/open_clip/__pycache__/feature_fusion.cpython-310.pyc ADDED
Binary file (4.13 kB).

audiosr/clap/open_clip/__pycache__/htsat.cpython-310.pyc ADDED
Binary file (30.8 kB).

audiosr/clap/open_clip/__pycache__/loss.cpython-310.pyc ADDED
Binary file (7.92 kB).

audiosr/clap/open_clip/__pycache__/model.cpython-310.pyc ADDED
Binary file (23.7 kB).

audiosr/clap/open_clip/__pycache__/openai.cpython-310.pyc ADDED
Binary file (4.53 kB).

audiosr/clap/open_clip/__pycache__/pann_model.cpython-310.pyc ADDED
Binary file (13 kB).

audiosr/clap/open_clip/__pycache__/pretrained.cpython-310.pyc ADDED
Binary file (5.04 kB).

audiosr/clap/open_clip/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (7.37 kB).

audiosr/clap/open_clip/__pycache__/transform.cpython-310.pyc ADDED
Binary file (989 Bytes).

audiosr/clap/open_clip/__pycache__/utils.cpython-310.pyc ADDED
Binary file (9.73 kB).
audiosr/clap/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
audiosr/clap/open_clip/factory.py ADDED
@@ -0,0 +1,276 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ from .model import CLAP, convert_weights_to_fp16
11
+ from .openai import load_openai_model
12
+ from .pretrained import get_pretrained_url, download_pretrained
13
+ from .transform import image_transform
14
+
15
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
16
+ _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
17
+
18
+
19
+ def _natural_key(string_):
20
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
21
+
22
+
23
+ def _rescan_model_configs():
24
+ global _MODEL_CONFIGS
25
+
26
+ config_ext = (".json",)
27
+ config_files = []
28
+ for config_path in _MODEL_CONFIG_PATHS:
29
+ if config_path.is_file() and config_path.suffix in config_ext:
30
+ config_files.append(config_path)
31
+ elif config_path.is_dir():
32
+ for ext in config_ext:
33
+ config_files.extend(config_path.glob(f"*{ext}"))
34
+
35
+ for cf in config_files:
36
+ if os.path.basename(cf)[0] == ".":
37
+ continue # Ignore hidden files
38
+
39
+ with open(cf, "r") as f:
40
+ model_cfg = json.load(f)
41
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
42
+ _MODEL_CONFIGS[cf.stem] = model_cfg
43
+
44
+ _MODEL_CONFIGS = {
45
+ k: v
46
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
47
+ }
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
54
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
55
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
56
+ state_dict = checkpoint["state_dict"]
57
+ else:
58
+ state_dict = checkpoint
59
+ if skip_params:
60
+ if next(iter(state_dict.items()))[0].startswith("module"):
61
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
62
+ # for k in state_dict:
63
+ # if k.startswith('transformer'):
64
+ # v = state_dict.pop(k)
65
+ # state_dict['text_branch.' + k[12:]] = v
66
+ return state_dict
67
+
68
+
69
+ def create_model(
70
+ amodel_name: str,
71
+ tmodel_name: str,
72
+ pretrained: str = "",
73
+ precision: str = "fp32",
74
+ device: torch.device = torch.device("cpu"),
75
+ jit: bool = False,
76
+ force_quick_gelu: bool = False,
77
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
78
+ skip_params=True,
79
+ pretrained_audio: str = "",
80
+ pretrained_text: str = "",
81
+ enable_fusion: bool = False,
82
+ fusion_type: str = "None"
83
+ # pretrained_image: bool = False,
84
+ ):
85
+ amodel_name = amodel_name.replace(
86
+ "/", "-"
87
+ ) # for callers using old naming with / in ViT names
88
+ pretrained_orig = pretrained
89
+ pretrained = pretrained.lower()
90
+ if pretrained == "openai":
91
+ if amodel_name in _MODEL_CONFIGS:
92
+ logging.info(f"Loading {amodel_name} model config.")
93
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
94
+ else:
95
+ logging.error(
96
+ f"Model config for {amodel_name} not found; available models {list_models()}."
97
+ )
98
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
99
+
100
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
101
+ # Hard Code in model name
102
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
103
+ model = load_openai_model(
104
+ "ViT-B-16",
105
+ model_cfg,
106
+ device=device,
107
+ jit=jit,
108
+ cache_dir=openai_model_cache_dir,
109
+ enable_fusion=enable_fusion,
110
+ fusion_type=fusion_type,
111
+ )
112
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
113
+ if precision == "amp" or precision == "fp32":
114
+ model = model.float()
115
+ else:
116
+ if amodel_name in _MODEL_CONFIGS:
117
+ logging.info(f"Loading {amodel_name} model config.")
118
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
119
+ else:
120
+ logging.error(
121
+ f"Model config for {amodel_name} not found; available models {list_models()}."
122
+ )
123
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
124
+
125
+ if force_quick_gelu:
126
+ # override for use of QuickGELU on non-OpenAI transformer models
127
+ model_cfg["quick_gelu"] = True
128
+
129
+ # if pretrained_image:
130
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
131
+ # # pretrained weight loading for timm models set via vision_cfg
132
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
133
+ # else:
134
+ # assert False, 'pretrained image towers currently only supported for timm models'
135
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
136
+ model_cfg["enable_fusion"] = enable_fusion
137
+ model_cfg["fusion_type"] = fusion_type
138
+ model = CLAP(**model_cfg)
139
+
140
+ if pretrained:
141
+ checkpoint_path = ""
142
+ url = get_pretrained_url(amodel_name, pretrained)
143
+ if url:
144
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
145
+ elif os.path.exists(pretrained_orig):
146
+ checkpoint_path = pretrained_orig
147
+ if checkpoint_path:
148
+ logging.info(
149
+ f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
150
+ )
151
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
152
+ model.load_state_dict(ckpt)
153
+ param_names = [n for n, p in model.named_parameters()]
154
+ # for n in param_names:
155
+ # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
156
+ else:
157
+ logging.warning(
158
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
159
+ )
160
+ raise RuntimeError(
161
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
162
+ )
163
+
164
+ if pretrained_audio:
165
+ if amodel_name.startswith("PANN"):
166
+ if "Cnn14_mAP" in pretrained_audio: # official checkpoint
167
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
168
+ audio_ckpt = audio_ckpt["model"]
169
+ keys = list(audio_ckpt.keys())
170
+ for key in keys:
171
+ if (
172
+ "spectrogram_extractor" not in key
173
+ and "logmel_extractor" not in key
174
+ ):
175
+ v = audio_ckpt.pop(key)
176
+ audio_ckpt["audio_branch." + key] = v
177
+ elif os.path.basename(pretrained_audio).startswith(
178
+ "PANN"
179
+ ): # checkpoint trained via HTSAT codebase
180
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
181
+ audio_ckpt = audio_ckpt["state_dict"]
182
+ keys = list(audio_ckpt.keys())
183
+ for key in keys:
184
+ if key.startswith("sed_model"):
185
+ v = audio_ckpt.pop(key)
186
+ audio_ckpt["audio_branch." + key[10:]] = v
187
+ elif os.path.basename(pretrained_audio).startswith(
188
+ "finetuned"
189
+ ): # checkpoint trained via linear probe codebase
190
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
191
+ else:
192
+ raise ValueError("Unknown audio checkpoint")
193
+ elif amodel_name.startswith("HTSAT"):
194
+ if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
195
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
196
+ audio_ckpt = audio_ckpt["state_dict"]
197
+ keys = list(audio_ckpt.keys())
198
+ for key in keys:
199
+ if key.startswith("sed_model") and (
200
+ "spectrogram_extractor" not in key
201
+ and "logmel_extractor" not in key
202
+ ):
203
+ v = audio_ckpt.pop(key)
204
+ audio_ckpt["audio_branch." + key[10:]] = v
205
+ elif os.path.basename(pretrained_audio).startswith(
206
+ "HTSAT"
207
+ ): # checkpoint trained via HTSAT codebase
208
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
209
+ audio_ckpt = audio_ckpt["state_dict"]
210
+ keys = list(audio_ckpt.keys())
211
+ for key in keys:
212
+ if key.startswith("sed_model"):
213
+ v = audio_ckpt.pop(key)
214
+ audio_ckpt["audio_branch." + key[10:]] = v
215
+ elif os.path.basename(pretrained_audio).startswith(
216
+ "finetuned"
217
+ ): # checkpoint trained via linear probe codebase
218
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
219
+ else:
220
+ raise ValueError("Unknown audio checkpoint")
221
+ else:
222
+ raise ValueError("this audio encoder pretrained checkpoint is not supported")
223
+
224
+ model.load_state_dict(audio_ckpt, strict=False)
225
+ logging.info(
226
+ f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
227
+ )
228
+ param_names = [n for n, p in model.named_parameters()]
229
+ for n in param_names:
230
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
231
+
232
+ model.to(device=device)
233
+ if precision == "fp16":
234
+ assert device.type != "cpu"
235
+ convert_weights_to_fp16(model)
236
+
237
+ if jit:
238
+ model = torch.jit.script(model)
239
+
240
+ return model, model_cfg
241
+
242
+
243
+ def create_model_and_transforms(
244
+ model_name: str,
245
+ pretrained: str = "",
246
+ precision: str = "fp32",
247
+ device: torch.device = torch.device("cpu"),
248
+ jit: bool = False,
249
+ force_quick_gelu: bool = False,
250
+ # pretrained_image: bool = False,
251
+ ):
252
+ model = create_model(
253
+ model_name,
254
+ pretrained,
255
+ precision,
256
+ device,
257
+ jit,
258
+ force_quick_gelu=force_quick_gelu,
259
+ # pretrained_image=pretrained_image
260
+ )
261
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
262
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
263
+ return model, preprocess_train, preprocess_val
264
+
265
+
266
+ def list_models():
267
+ """enumerate available model architectures based on config files"""
268
+ return list(_MODEL_CONFIGS.keys())
269
+
270
+
271
+ def add_model_config(path):
272
+ """add model config path or file and update registry"""
273
+ if not isinstance(path, Path):
274
+ path = Path(path)
275
+ _MODEL_CONFIG_PATHS.append(path)
276
+ _rescan_model_configs()
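A rough sketch of how this factory is driven (illustrative only: the audio and text model names and the empty pretrained tag below are assumptions, not values used elsewhere in this commit):

import torch
from audiosr.clap.open_clip import list_models, create_model

print(list_models())  # architectures found in model_configs/, e.g. HTSAT-*, PANN-*, RN*, ViT-*

# build an un-pretrained CLAP with an HTSAT-tiny audio branch on CPU
model, model_cfg = create_model(
    amodel_name="HTSAT-tiny",  # must match a JSON file in model_configs/
    tmodel_name="roberta",     # assumed text backbone name; stored in model_cfg["text_cfg"]["model_type"]
    pretrained="",             # "" skips weight loading; a tag/path goes through get_pretrained_url/download_pretrained
    precision="fp32",
    device=torch.device("cpu"),
)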
audiosr/clap/open_clip/feature_fusion.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ Feature Fusion for Variable-Length Data Processing
3
+ AFF/iAFF is adapted and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class DAF(nn.Module):
12
+ """
13
+ Direct addition (DirectAddFuse)
14
+ """
15
+
16
+ def __init__(self):
17
+ super(DAF, self).__init__()
18
+
19
+ def forward(self, x, residual):
20
+ return x + residual
21
+
22
+
23
+ class iAFF(nn.Module):
24
+ """
25
+ Multi-feature fusion (iAFF)
26
+ """
27
+
28
+ def __init__(self, channels=64, r=4, type="2D"):
29
+ super(iAFF, self).__init__()
30
+ inter_channels = int(channels // r)
31
+
32
+ if type == "1D":
33
+ # local attention
34
+ self.local_att = nn.Sequential(
35
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
+ nn.BatchNorm1d(inter_channels),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm1d(channels),
40
+ )
41
+
42
+ # global attention
43
+ self.global_att = nn.Sequential(
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
+ nn.BatchNorm1d(inter_channels),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
+ nn.BatchNorm1d(channels),
50
+ )
51
+
52
+ # second local attention
53
+ self.local_att2 = nn.Sequential(
54
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
+ nn.BatchNorm1d(inter_channels),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
+ nn.BatchNorm1d(channels),
59
+ )
60
+ # second global attention
61
+ self.global_att2 = nn.Sequential(
62
+ nn.AdaptiveAvgPool1d(1),
63
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(inter_channels),
65
+ nn.ReLU(inplace=True),
66
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
+ nn.BatchNorm1d(channels),
68
+ )
69
+ elif type == "2D":
70
+ # local attention
71
+ self.local_att = nn.Sequential(
72
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
+ nn.BatchNorm2d(inter_channels),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
+ nn.BatchNorm2d(channels),
77
+ )
78
+
79
+ # global attention
80
+ self.global_att = nn.Sequential(
81
+ nn.AdaptiveAvgPool2d(1),
82
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm2d(inter_channels),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
+ nn.BatchNorm2d(channels),
87
+ )
88
+
89
+ # second local attention
90
+ self.local_att2 = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+ # second global attention
98
+ self.global_att2 = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+ else:
107
+ raise ValueError("the fusion type is not supported")
108
+
109
+ self.sigmoid = nn.Sigmoid()
110
+
111
+ def forward(self, x, residual):
112
+ flag = False
113
+ xa = x + residual
114
+ if xa.size(0) == 1:
115
+ xa = torch.cat([xa, xa], dim=0)
116
+ flag = True
117
+ xl = self.local_att(xa)
118
+ xg = self.global_att(xa)
119
+ xlg = xl + xg
120
+ wei = self.sigmoid(xlg)
121
+ xi = x * wei + residual * (1 - wei)
122
+
123
+ xl2 = self.local_att2(xi)
124
+ xg2 = self.global_att(xi)
125
+ xlg2 = xl2 + xg2
126
+ wei2 = self.sigmoid(xlg2)
127
+ xo = x * wei2 + residual * (1 - wei2)
128
+ if flag:
129
+ xo = xo[0].unsqueeze(0)
130
+ return xo
131
+
132
+
133
+ class AFF(nn.Module):
134
+ """
135
+ Multi-feature fusion (AFF)
136
+ """
137
+
138
+ def __init__(self, channels=64, r=4, type="2D"):
139
+ super(AFF, self).__init__()
140
+ inter_channels = int(channels // r)
141
+
142
+ if type == "1D":
143
+ self.local_att = nn.Sequential(
144
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
+ nn.BatchNorm1d(inter_channels),
146
+ nn.ReLU(inplace=True),
147
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
+ nn.BatchNorm1d(channels),
149
+ )
150
+ self.global_att = nn.Sequential(
151
+ nn.AdaptiveAvgPool1d(1),
152
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
+ nn.BatchNorm1d(inter_channels),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
+ nn.BatchNorm1d(channels),
157
+ )
158
+ elif type == "2D":
159
+ self.local_att = nn.Sequential(
160
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
+ nn.BatchNorm2d(inter_channels),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
+ nn.BatchNorm2d(channels),
165
+ )
166
+ self.global_att = nn.Sequential(
167
+ nn.AdaptiveAvgPool2d(1),
168
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
+ nn.BatchNorm2d(inter_channels),
170
+ nn.ReLU(inplace=True),
171
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+ else:
175
+ raise ValueError("the fusion type is not supported")
176
+
177
+ self.sigmoid = nn.Sigmoid()
178
+
179
+ def forward(self, x, residual):
180
+ flag = False
181
+ xa = x + residual
182
+ if xa.size(0) == 1:
183
+ xa = torch.cat([xa, xa], dim=0)
184
+ flag = True
185
+ xl = self.local_att(xa)
186
+ xg = self.global_att(xa)
187
+ xlg = xl + xg
188
+ wei = self.sigmoid(xlg)
189
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
190
+ if flag:
191
+ xo = xo[0].unsqueeze(0)
192
+ return xo
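A minimal usage sketch of the three fusion modules defined above (shapes are illustrative; the "2D" variants expect (batch, channels, height, width) feature maps and blend x with residual through a learned attention weight):

import torch
from audiosr.clap.open_clip.feature_fusion import DAF, AFF, iAFF

x = torch.randn(4, 64, 32, 32)         # first feature map
residual = torch.randn(4, 64, 32, 32)  # second feature map to fuse with it

fused_add = DAF()(x, residual)                               # plain element-wise addition
fused_aff = AFF(channels=64, r=4, type="2D")(x, residual)    # attentional feature fusion
fused_iaff = iAFF(channels=64, r=4, type="2D")(x, residual)  # iterative attentional feature fusion
print(fused_aff.shape)  # torch.Size([4, 64, 32, 32])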
audiosr/clap/open_clip/htsat.py ADDED
@@ -0,0 +1,1304 @@
1
+ # Ke Chen
2
+ # knutchen@ucsd.edu
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed on the model
5
+ # the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from itertools import repeat
11
+ import collections.abc
12
+ import math
13
+ import warnings
14
+
15
+ from torch.nn.init import _calculate_fan_in_and_fan_out
16
+ import torch.utils.checkpoint as checkpoint
17
+
18
+ import random
19
+
20
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
21
+ from torchlibrosa.augmentation import SpecAugmentation
22
+
23
+ from itertools import repeat
24
+ from .utils import do_mixup, interpolate
25
+
26
+ from .feature_fusion import iAFF, AFF, DAF
27
+
28
+
29
+ # from PyTorch internals
30
+ def _ntuple(n):
31
+ def parse(x):
32
+ if isinstance(x, collections.abc.Iterable):
33
+ return x
34
+ return tuple(repeat(x, n))
35
+
36
+ return parse
37
+
38
+
39
+ to_1tuple = _ntuple(1)
40
+ to_2tuple = _ntuple(2)
41
+ to_3tuple = _ntuple(3)
42
+ to_4tuple = _ntuple(4)
43
+ to_ntuple = _ntuple
44
+
45
+
46
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
47
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
48
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
49
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
50
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
51
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
52
+ 'survival rate' as the argument.
53
+ """
54
+ if drop_prob == 0.0 or not training:
55
+ return x
56
+ keep_prob = 1 - drop_prob
57
+ shape = (x.shape[0],) + (1,) * (
58
+ x.ndim - 1
59
+ ) # work with diff dim tensors, not just 2D ConvNets
60
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
61
+ random_tensor.floor_() # binarize
62
+ output = x.div(keep_prob) * random_tensor
63
+ return output
64
+
65
+
66
+ class DropPath(nn.Module):
67
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
68
+
69
+ def __init__(self, drop_prob=None):
70
+ super(DropPath, self).__init__()
71
+ self.drop_prob = drop_prob
72
+
73
+ def forward(self, x):
74
+ return drop_path(x, self.drop_prob, self.training)
75
+
76
+
77
+ class PatchEmbed(nn.Module):
78
+ """2D Image to Patch Embedding"""
79
+
80
+ def __init__(
81
+ self,
82
+ img_size=224,
83
+ patch_size=16,
84
+ in_chans=3,
85
+ embed_dim=768,
86
+ norm_layer=None,
87
+ flatten=True,
88
+ patch_stride=16,
89
+ enable_fusion=False,
90
+ fusion_type="None",
91
+ ):
92
+ super().__init__()
93
+ img_size = to_2tuple(img_size)
94
+ patch_size = to_2tuple(patch_size)
95
+ patch_stride = to_2tuple(patch_stride)
96
+ self.img_size = img_size
97
+ self.patch_size = patch_size
98
+ self.patch_stride = patch_stride
99
+ self.grid_size = (
100
+ img_size[0] // patch_stride[0],
101
+ img_size[1] // patch_stride[1],
102
+ )
103
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
104
+ self.flatten = flatten
105
+ self.in_chans = in_chans
106
+ self.embed_dim = embed_dim
107
+
108
+ self.enable_fusion = enable_fusion
109
+ self.fusion_type = fusion_type
110
+
111
+ padding = (
112
+ (patch_size[0] - patch_stride[0]) // 2,
113
+ (patch_size[1] - patch_stride[1]) // 2,
114
+ )
115
+
116
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
117
+ self.proj = nn.Conv2d(
118
+ in_chans * 4,
119
+ embed_dim,
120
+ kernel_size=patch_size,
121
+ stride=patch_stride,
122
+ padding=padding,
123
+ )
124
+ else:
125
+ self.proj = nn.Conv2d(
126
+ in_chans,
127
+ embed_dim,
128
+ kernel_size=patch_size,
129
+ stride=patch_stride,
130
+ padding=padding,
131
+ )
132
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
133
+
134
+ if (self.enable_fusion) and (
135
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
136
+ ):
137
+ self.mel_conv2d = nn.Conv2d(
138
+ in_chans,
139
+ embed_dim,
140
+ kernel_size=(patch_size[0], patch_size[1] * 3),
141
+ stride=(patch_stride[0], patch_stride[1] * 3),
142
+ padding=padding,
143
+ )
144
+ if self.fusion_type == "daf_2d":
145
+ self.fusion_model = DAF()
146
+ elif self.fusion_type == "aff_2d":
147
+ self.fusion_model = AFF(channels=embed_dim, type="2D")
148
+ elif self.fusion_type == "iaff_2d":
149
+ self.fusion_model = iAFF(channels=embed_dim, type="2D")
150
+
151
+ def forward(self, x, longer_idx=None):
152
+ if (self.enable_fusion) and (
153
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
154
+ ):
155
+ global_x = x[:, 0:1, :, :]
156
+
157
+ # global processing
158
+ B, C, H, W = global_x.shape
159
+ assert (
160
+ H == self.img_size[0] and W == self.img_size[1]
161
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
162
+ global_x = self.proj(global_x)
163
+ TW = global_x.size(-1)
164
+ if len(longer_idx) > 0:
165
+ # local processing
166
+ local_x = x[longer_idx, 1:, :, :].contiguous()
167
+ B, C, H, W = local_x.shape
168
+ local_x = local_x.view(B * C, 1, H, W)
169
+ local_x = self.mel_conv2d(local_x)
170
+ local_x = local_x.view(
171
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
172
+ )
173
+ local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
174
+ TB, TC, TH, _ = local_x.size()
175
+ if local_x.size(-1) < TW:
176
+ local_x = torch.cat(
177
+ [
178
+ local_x,
179
+ torch.zeros(
180
+ (TB, TC, TH, TW - local_x.size(-1)),
181
+ device=global_x.device,
182
+ ),
183
+ ],
184
+ dim=-1,
185
+ )
186
+ else:
187
+ local_x = local_x[:, :, :, :TW]
188
+
189
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
190
+ x = global_x
191
+ else:
192
+ B, C, H, W = x.shape
193
+ assert (
194
+ H == self.img_size[0] and W == self.img_size[1]
195
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
196
+ x = self.proj(x)
197
+
198
+ if self.flatten:
199
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
200
+ x = self.norm(x)
201
+ return x
202
+
203
+
204
+ class Mlp(nn.Module):
205
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
206
+
207
+ def __init__(
208
+ self,
209
+ in_features,
210
+ hidden_features=None,
211
+ out_features=None,
212
+ act_layer=nn.GELU,
213
+ drop=0.0,
214
+ ):
215
+ super().__init__()
216
+ out_features = out_features or in_features
217
+ hidden_features = hidden_features or in_features
218
+ self.fc1 = nn.Linear(in_features, hidden_features)
219
+ self.act = act_layer()
220
+ self.fc2 = nn.Linear(hidden_features, out_features)
221
+ self.drop = nn.Dropout(drop)
222
+
223
+ def forward(self, x):
224
+ x = self.fc1(x)
225
+ x = self.act(x)
226
+ x = self.drop(x)
227
+ x = self.fc2(x)
228
+ x = self.drop(x)
229
+ return x
230
+
231
+
232
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
233
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
234
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
235
+ def norm_cdf(x):
236
+ # Computes standard normal cumulative distribution function
237
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
238
+
239
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
240
+ warnings.warn(
241
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
242
+ "The distribution of values may be incorrect.",
243
+ stacklevel=2,
244
+ )
245
+
246
+ with torch.no_grad():
247
+ # Values are generated by using a truncated uniform distribution and
248
+ # then using the inverse CDF for the normal distribution.
249
+ # Get upper and lower cdf values
250
+ l = norm_cdf((a - mean) / std)
251
+ u = norm_cdf((b - mean) / std)
252
+
253
+ # Uniformly fill tensor with values from [l, u], then translate to
254
+ # [2l-1, 2u-1].
255
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
256
+
257
+ # Use inverse cdf transform for normal distribution to get truncated
258
+ # standard normal
259
+ tensor.erfinv_()
260
+
261
+ # Transform to proper mean, std
262
+ tensor.mul_(std * math.sqrt(2.0))
263
+ tensor.add_(mean)
264
+
265
+ # Clamp to ensure it's in the proper range
266
+ tensor.clamp_(min=a, max=b)
267
+ return tensor
268
+
269
+
270
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
271
+ # type: (Tensor, float, float, float, float) -> Tensor
272
+ r"""Fills the input Tensor with values drawn from a truncated
273
+ normal distribution. The values are effectively drawn from the
274
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
275
+ with values outside :math:`[a, b]` redrawn until they are within
276
+ the bounds. The method used for generating the random values works
277
+ best when :math:`a \leq \text{mean} \leq b`.
278
+ Args:
279
+ tensor: an n-dimensional `torch.Tensor`
280
+ mean: the mean of the normal distribution
281
+ std: the standard deviation of the normal distribution
282
+ a: the minimum cutoff value
283
+ b: the maximum cutoff value
284
+ Examples:
285
+ >>> w = torch.empty(3, 5)
286
+ >>> nn.init.trunc_normal_(w)
287
+ """
288
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
289
+
290
+
291
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
292
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
293
+ if mode == "fan_in":
294
+ denom = fan_in
295
+ elif mode == "fan_out":
296
+ denom = fan_out
297
+ elif mode == "fan_avg":
298
+ denom = (fan_in + fan_out) / 2
299
+
300
+ variance = scale / denom
301
+
302
+ if distribution == "truncated_normal":
303
+ # constant is stddev of standard normal truncated to (-2, 2)
304
+ trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
305
+ elif distribution == "normal":
306
+ tensor.normal_(std=math.sqrt(variance))
307
+ elif distribution == "uniform":
308
+ bound = math.sqrt(3 * variance)
309
+ tensor.uniform_(-bound, bound)
310
+ else:
311
+ raise ValueError(f"invalid distribution {distribution}")
312
+
313
+
314
+ def lecun_normal_(tensor):
315
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
316
+
317
+
318
+ def window_partition(x, window_size):
319
+ """
320
+ Args:
321
+ x: (B, H, W, C)
322
+ window_size (int): window size
323
+ Returns:
324
+ windows: (num_windows*B, window_size, window_size, C)
325
+ """
326
+ B, H, W, C = x.shape
327
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
328
+ windows = (
329
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
330
+ )
331
+ return windows
332
+
333
+
334
+ def window_reverse(windows, window_size, H, W):
335
+ """
336
+ Args:
337
+ windows: (num_windows*B, window_size, window_size, C)
338
+ window_size (int): Window size
339
+ H (int): Height of image
340
+ W (int): Width of image
341
+ Returns:
342
+ x: (B, H, W, C)
343
+ """
344
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
345
+ x = windows.view(
346
+ B, H // window_size, W // window_size, window_size, window_size, -1
347
+ )
348
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
349
+ return x
350
+
351
+
352
+ class WindowAttention(nn.Module):
353
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
354
+ It supports both of shifted and non-shifted window.
355
+ Args:
356
+ dim (int): Number of input channels.
357
+ window_size (tuple[int]): The height and width of the window.
358
+ num_heads (int): Number of attention heads.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
361
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
362
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
363
+ """
364
+
365
+ def __init__(
366
+ self,
367
+ dim,
368
+ window_size,
369
+ num_heads,
370
+ qkv_bias=True,
371
+ qk_scale=None,
372
+ attn_drop=0.0,
373
+ proj_drop=0.0,
374
+ ):
375
+ super().__init__()
376
+ self.dim = dim
377
+ self.window_size = window_size # Wh, Ww
378
+ self.num_heads = num_heads
379
+ head_dim = dim // num_heads
380
+ self.scale = qk_scale or head_dim**-0.5
381
+
382
+ # define a parameter table of relative position bias
383
+ self.relative_position_bias_table = nn.Parameter(
384
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
385
+ ) # 2*Wh-1 * 2*Ww-1, nH
386
+
387
+ # get pair-wise relative position index for each token inside the window
388
+ coords_h = torch.arange(self.window_size[0])
389
+ coords_w = torch.arange(self.window_size[1])
390
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
391
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
392
+ relative_coords = (
393
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
394
+ ) # 2, Wh*Ww, Wh*Ww
395
+ relative_coords = relative_coords.permute(
396
+ 1, 2, 0
397
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
398
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
399
+ relative_coords[:, :, 1] += self.window_size[1] - 1
400
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
401
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
402
+ self.register_buffer("relative_position_index", relative_position_index)
403
+
404
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
405
+ self.attn_drop = nn.Dropout(attn_drop)
406
+ self.proj = nn.Linear(dim, dim)
407
+ self.proj_drop = nn.Dropout(proj_drop)
408
+
409
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
410
+ self.softmax = nn.Softmax(dim=-1)
411
+
412
+ def forward(self, x, mask=None):
413
+ """
414
+ Args:
415
+ x: input features with shape of (num_windows*B, N, C)
416
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
417
+ """
418
+ B_, N, C = x.shape
419
+ qkv = (
420
+ self.qkv(x)
421
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
422
+ .permute(2, 0, 3, 1, 4)
423
+ )
424
+ q, k, v = (
425
+ qkv[0],
426
+ qkv[1],
427
+ qkv[2],
428
+ ) # make torchscript happy (cannot use tensor as tuple)
429
+
430
+ q = q * self.scale
431
+ attn = q @ k.transpose(-2, -1)
432
+
433
+ relative_position_bias = self.relative_position_bias_table[
434
+ self.relative_position_index.view(-1)
435
+ ].view(
436
+ self.window_size[0] * self.window_size[1],
437
+ self.window_size[0] * self.window_size[1],
438
+ -1,
439
+ ) # Wh*Ww,Wh*Ww,nH
440
+ relative_position_bias = relative_position_bias.permute(
441
+ 2, 0, 1
442
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
443
+ attn = attn + relative_position_bias.unsqueeze(0)
444
+
445
+ if mask is not None:
446
+ nW = mask.shape[0]
447
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
448
+ 1
449
+ ).unsqueeze(0)
450
+ attn = attn.view(-1, self.num_heads, N, N)
451
+ attn = self.softmax(attn)
452
+ else:
453
+ attn = self.softmax(attn)
454
+
455
+ attn = self.attn_drop(attn)
456
+
457
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
458
+ x = self.proj(x)
459
+ x = self.proj_drop(x)
460
+ return x, attn
461
+
462
+ def extra_repr(self):
463
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
464
+
465
+
466
+ # We build the model on the Swin Transformer block, therefore we can use the Swin Transformer pretrained model
467
+ class SwinTransformerBlock(nn.Module):
468
+ r"""Swin Transformer Block.
469
+ Args:
470
+ dim (int): Number of input channels.
471
+ input_resolution (tuple[int]): Input resolution.
472
+ num_heads (int): Number of attention heads.
473
+ window_size (int): Window size.
474
+ shift_size (int): Shift size for SW-MSA.
475
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
476
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
477
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
478
+ drop (float, optional): Dropout rate. Default: 0.0
479
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
480
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
481
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
482
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
483
+ """
484
+
485
+ def __init__(
486
+ self,
487
+ dim,
488
+ input_resolution,
489
+ num_heads,
490
+ window_size=7,
491
+ shift_size=0,
492
+ mlp_ratio=4.0,
493
+ qkv_bias=True,
494
+ qk_scale=None,
495
+ drop=0.0,
496
+ attn_drop=0.0,
497
+ drop_path=0.0,
498
+ act_layer=nn.GELU,
499
+ norm_layer=nn.LayerNorm,
500
+ norm_before_mlp="ln",
501
+ ):
502
+ super().__init__()
503
+ self.dim = dim
504
+ self.input_resolution = input_resolution
505
+ self.num_heads = num_heads
506
+ self.window_size = window_size
507
+ self.shift_size = shift_size
508
+ self.mlp_ratio = mlp_ratio
509
+ self.norm_before_mlp = norm_before_mlp
510
+ if min(self.input_resolution) <= self.window_size:
511
+ # if window size is larger than input resolution, we don't partition windows
512
+ self.shift_size = 0
513
+ self.window_size = min(self.input_resolution)
514
+ assert (
515
+ 0 <= self.shift_size < self.window_size
516
+ ), "shift_size must in 0-window_size"
517
+
518
+ self.norm1 = norm_layer(dim)
519
+ self.attn = WindowAttention(
520
+ dim,
521
+ window_size=to_2tuple(self.window_size),
522
+ num_heads=num_heads,
523
+ qkv_bias=qkv_bias,
524
+ qk_scale=qk_scale,
525
+ attn_drop=attn_drop,
526
+ proj_drop=drop,
527
+ )
528
+
529
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
530
+ if self.norm_before_mlp == "ln":
531
+ self.norm2 = nn.LayerNorm(dim)
532
+ elif self.norm_before_mlp == "bn":
533
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
534
+ 1, 2
535
+ )
536
+ else:
537
+ raise NotImplementedError
538
+ mlp_hidden_dim = int(dim * mlp_ratio)
539
+ self.mlp = Mlp(
540
+ in_features=dim,
541
+ hidden_features=mlp_hidden_dim,
542
+ act_layer=act_layer,
543
+ drop=drop,
544
+ )
545
+
546
+ if self.shift_size > 0:
547
+ # calculate attention mask for SW-MSA
548
+ H, W = self.input_resolution
549
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
550
+ h_slices = (
551
+ slice(0, -self.window_size),
552
+ slice(-self.window_size, -self.shift_size),
553
+ slice(-self.shift_size, None),
554
+ )
555
+ w_slices = (
556
+ slice(0, -self.window_size),
557
+ slice(-self.window_size, -self.shift_size),
558
+ slice(-self.shift_size, None),
559
+ )
560
+ cnt = 0
561
+ for h in h_slices:
562
+ for w in w_slices:
563
+ img_mask[:, h, w, :] = cnt
564
+ cnt += 1
565
+
566
+ mask_windows = window_partition(
567
+ img_mask, self.window_size
568
+ ) # nW, window_size, window_size, 1
569
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
570
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
571
+ attn_mask = attn_mask.masked_fill(
572
+ attn_mask != 0, float(-100.0)
573
+ ).masked_fill(attn_mask == 0, float(0.0))
574
+ else:
575
+ attn_mask = None
576
+
577
+ self.register_buffer("attn_mask", attn_mask)
578
+
579
+ def forward(self, x):
580
+ # pdb.set_trace()
581
+ H, W = self.input_resolution
582
+ # print("H: ", H)
583
+ # print("W: ", W)
584
+ # pdb.set_trace()
585
+ B, L, C = x.shape
586
+ # assert L == H * W, "input feature has wrong size"
587
+
588
+ shortcut = x
589
+ x = self.norm1(x)
590
+ x = x.view(B, H, W, C)
591
+
592
+ # cyclic shift
593
+ if self.shift_size > 0:
594
+ shifted_x = torch.roll(
595
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
596
+ )
597
+ else:
598
+ shifted_x = x
599
+
600
+ # partition windows
601
+ x_windows = window_partition(
602
+ shifted_x, self.window_size
603
+ ) # nW*B, window_size, window_size, C
604
+ x_windows = x_windows.view(
605
+ -1, self.window_size * self.window_size, C
606
+ ) # nW*B, window_size*window_size, C
607
+
608
+ # W-MSA/SW-MSA
609
+ attn_windows, attn = self.attn(
610
+ x_windows, mask=self.attn_mask
611
+ ) # nW*B, window_size*window_size, C
612
+
613
+ # merge windows
614
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
615
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
616
+
617
+ # reverse cyclic shift
618
+ if self.shift_size > 0:
619
+ x = torch.roll(
620
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
621
+ )
622
+ else:
623
+ x = shifted_x
624
+ x = x.view(B, H * W, C)
625
+
626
+ # FFN
627
+ x = shortcut + self.drop_path(x)
628
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
629
+
630
+ return x, attn
631
+
632
+ def extra_repr(self):
633
+ return (
634
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
635
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
636
+ )
637
+
638
+
639
+ class PatchMerging(nn.Module):
640
+ r"""Patch Merging Layer.
641
+ Args:
642
+ input_resolution (tuple[int]): Resolution of input feature.
643
+ dim (int): Number of input channels.
644
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
645
+ """
646
+
647
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
648
+ super().__init__()
649
+ self.input_resolution = input_resolution
650
+ self.dim = dim
651
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
652
+ self.norm = norm_layer(4 * dim)
653
+
654
+ def forward(self, x):
655
+ """
656
+ x: B, H*W, C
657
+ """
658
+ H, W = self.input_resolution
659
+ B, L, C = x.shape
660
+ assert L == H * W, "input feature has wrong size"
661
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
662
+
663
+ x = x.view(B, H, W, C)
664
+
665
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
666
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
667
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
668
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
669
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
670
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
671
+
672
+ x = self.norm(x)
673
+ x = self.reduction(x)
674
+
675
+ return x
676
+
677
+ def extra_repr(self):
678
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
679
+
680
+
681
+ class BasicLayer(nn.Module):
682
+ """A basic Swin Transformer layer for one stage.
683
+ Args:
684
+ dim (int): Number of input channels.
685
+ input_resolution (tuple[int]): Input resolution.
686
+ depth (int): Number of blocks.
687
+ num_heads (int): Number of attention heads.
688
+ window_size (int): Local window size.
689
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
690
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
691
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
692
+ drop (float, optional): Dropout rate. Default: 0.0
693
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
694
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
695
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
696
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
697
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
698
+ """
699
+
700
+ def __init__(
701
+ self,
702
+ dim,
703
+ input_resolution,
704
+ depth,
705
+ num_heads,
706
+ window_size,
707
+ mlp_ratio=4.0,
708
+ qkv_bias=True,
709
+ qk_scale=None,
710
+ drop=0.0,
711
+ attn_drop=0.0,
712
+ drop_path=0.0,
713
+ norm_layer=nn.LayerNorm,
714
+ downsample=None,
715
+ use_checkpoint=False,
716
+ norm_before_mlp="ln",
717
+ ):
718
+ super().__init__()
719
+ self.dim = dim
720
+ self.input_resolution = input_resolution
721
+ self.depth = depth
722
+ self.use_checkpoint = use_checkpoint
723
+
724
+ # build blocks
725
+ self.blocks = nn.ModuleList(
726
+ [
727
+ SwinTransformerBlock(
728
+ dim=dim,
729
+ input_resolution=input_resolution,
730
+ num_heads=num_heads,
731
+ window_size=window_size,
732
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
733
+ mlp_ratio=mlp_ratio,
734
+ qkv_bias=qkv_bias,
735
+ qk_scale=qk_scale,
736
+ drop=drop,
737
+ attn_drop=attn_drop,
738
+ drop_path=drop_path[i]
739
+ if isinstance(drop_path, list)
740
+ else drop_path,
741
+ norm_layer=norm_layer,
742
+ norm_before_mlp=norm_before_mlp,
743
+ )
744
+ for i in range(depth)
745
+ ]
746
+ )
747
+
748
+ # patch merging layer
749
+ if downsample is not None:
750
+ self.downsample = downsample(
751
+ input_resolution, dim=dim, norm_layer=norm_layer
752
+ )
753
+ else:
754
+ self.downsample = None
755
+
756
+ def forward(self, x):
757
+ attns = []
758
+ for blk in self.blocks:
759
+ if self.use_checkpoint:
760
+ x, attn = checkpoint.checkpoint(blk, x)
761
+ else:
762
+ x, attn = blk(x)
763
+ if not self.training:
764
+ attns.append(attn.unsqueeze(0))
765
+ if self.downsample is not None:
766
+ x = self.downsample(x)
767
+ if not self.training:
768
+ attn = torch.cat(attns, dim=0)
769
+ attn = torch.mean(attn, dim=0)
770
+ return x, attn
771
+
772
+ def extra_repr(self):
773
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
774
+
775
+
776
+ # The Core of HTSAT
777
+ class HTSAT_Swin_Transformer(nn.Module):
778
+ r"""HTSAT based on the Swin Transformer
779
+ Args:
780
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
781
+ patch_size (int | tuple(int)): Patch size. Default: 4
782
+ patch_stride (int | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
783
+ in_chans (int): Number of input image channels. Default: 1 (mono)
784
+ num_classes (int): Number of classes for classification head. Default: 527
785
+ embed_dim (int): Patch embedding dimension. Default: 96
786
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
787
+ num_heads (tuple(int)): Number of attention heads in different layers.
788
+ window_size (int): Window size. Default: 8
789
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
790
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
791
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
792
+ drop_rate (float): Dropout rate. Default: 0
793
+ attn_drop_rate (float): Attention dropout rate. Default: 0
794
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
795
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
796
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
797
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
798
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
799
+ config (module): The configuration Module from config.py
800
+ """
801
+
802
+ def __init__(
803
+ self,
804
+ spec_size=256,
805
+ patch_size=4,
806
+ patch_stride=(4, 4),
807
+ in_chans=1,
808
+ num_classes=527,
809
+ embed_dim=96,
810
+ depths=[2, 2, 6, 2],
811
+ num_heads=[4, 8, 16, 32],
812
+ window_size=8,
813
+ mlp_ratio=4.0,
814
+ qkv_bias=True,
815
+ qk_scale=None,
816
+ drop_rate=0.0,
817
+ attn_drop_rate=0.0,
818
+ drop_path_rate=0.1,
819
+ norm_layer=nn.LayerNorm,
820
+ ape=False,
821
+ patch_norm=True,
822
+ use_checkpoint=False,
823
+ norm_before_mlp="ln",
824
+ config=None,
825
+ enable_fusion=False,
826
+ fusion_type="None",
827
+ **kwargs,
828
+ ):
829
+ super(HTSAT_Swin_Transformer, self).__init__()
830
+
831
+ self.config = config
832
+ self.spec_size = spec_size
833
+ self.patch_stride = patch_stride
834
+ self.patch_size = patch_size
835
+ self.window_size = window_size
836
+ self.embed_dim = embed_dim
837
+ self.depths = depths
838
+ self.ape = ape
839
+ self.in_chans = in_chans
840
+ self.num_classes = num_classes
841
+ self.num_heads = num_heads
842
+ self.num_layers = len(self.depths)
843
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
844
+
845
+ self.drop_rate = drop_rate
846
+ self.attn_drop_rate = attn_drop_rate
847
+ self.drop_path_rate = drop_path_rate
848
+
849
+ self.qkv_bias = qkv_bias
850
+ self.qk_scale = None
851
+
852
+ self.patch_norm = patch_norm
853
+ self.norm_layer = norm_layer if self.patch_norm else None
854
+ self.norm_before_mlp = norm_before_mlp
855
+ self.mlp_ratio = mlp_ratio
856
+
857
+ self.use_checkpoint = use_checkpoint
858
+
859
+ self.enable_fusion = enable_fusion
860
+ self.fusion_type = fusion_type
861
+
862
+ # process mel-spec ; used only once
863
+ self.freq_ratio = self.spec_size // self.config.mel_bins
864
+ window = "hann"
865
+ center = True
866
+ pad_mode = "reflect"
867
+ ref = 1.0
868
+ amin = 1e-10
869
+ top_db = None
870
+ self.interpolate_ratio = 32 # Downsampled ratio
871
+ # Spectrogram extractor
872
+ self.spectrogram_extractor = Spectrogram(
873
+ n_fft=config.window_size,
874
+ hop_length=config.hop_size,
875
+ win_length=config.window_size,
876
+ window=window,
877
+ center=center,
878
+ pad_mode=pad_mode,
879
+ freeze_parameters=True,
880
+ )
881
+ # Logmel feature extractor
882
+ self.logmel_extractor = LogmelFilterBank(
883
+ sr=config.sample_rate,
884
+ n_fft=config.window_size,
885
+ n_mels=config.mel_bins,
886
+ fmin=config.fmin,
887
+ fmax=config.fmax,
888
+ ref=ref,
889
+ amin=amin,
890
+ top_db=top_db,
891
+ freeze_parameters=True,
892
+ )
893
+ # Spec augmenter
894
+ self.spec_augmenter = SpecAugmentation(
895
+ time_drop_width=64,
896
+ time_stripes_num=2,
897
+ freq_drop_width=8,
898
+ freq_stripes_num=2,
899
+ ) # 2 2
900
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
901
+
902
+ # split spectrogram into non-overlapping patches
903
+ self.patch_embed = PatchEmbed(
904
+ img_size=self.spec_size,
905
+ patch_size=self.patch_size,
906
+ in_chans=self.in_chans,
907
+ embed_dim=self.embed_dim,
908
+ norm_layer=self.norm_layer,
909
+ patch_stride=patch_stride,
910
+ enable_fusion=self.enable_fusion,
911
+ fusion_type=self.fusion_type,
912
+ )
913
+
914
+ num_patches = self.patch_embed.num_patches
915
+ patches_resolution = self.patch_embed.grid_size
916
+ self.patches_resolution = patches_resolution
917
+
918
+ # absolute position embedding
919
+ if self.ape:
920
+ self.absolute_pos_embed = nn.Parameter(
921
+ torch.zeros(1, num_patches, self.embed_dim)
922
+ )
923
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
924
+
925
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
926
+
927
+ # stochastic depth
928
+ dpr = [
929
+ x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
930
+ ] # stochastic depth decay rule
931
+
932
+ # build layers
933
+ self.layers = nn.ModuleList()
934
+ for i_layer in range(self.num_layers):
935
+ layer = BasicLayer(
936
+ dim=int(self.embed_dim * 2**i_layer),
937
+ input_resolution=(
938
+ patches_resolution[0] // (2**i_layer),
939
+ patches_resolution[1] // (2**i_layer),
940
+ ),
941
+ depth=self.depths[i_layer],
942
+ num_heads=self.num_heads[i_layer],
943
+ window_size=self.window_size,
944
+ mlp_ratio=self.mlp_ratio,
945
+ qkv_bias=self.qkv_bias,
946
+ qk_scale=self.qk_scale,
947
+ drop=self.drop_rate,
948
+ attn_drop=self.attn_drop_rate,
949
+ drop_path=dpr[
950
+ sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
951
+ ],
952
+ norm_layer=self.norm_layer,
953
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
954
+ use_checkpoint=use_checkpoint,
955
+ norm_before_mlp=self.norm_before_mlp,
956
+ )
957
+ self.layers.append(layer)
958
+
959
+ self.norm = self.norm_layer(self.num_features)
960
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
961
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
962
+
963
+ SF = (
964
+ self.spec_size
965
+ // (2 ** (len(self.depths) - 1))
966
+ // self.patch_stride[0]
967
+ // self.freq_ratio
968
+ )
969
+ self.tscam_conv = nn.Conv2d(
970
+ in_channels=self.num_features,
971
+ out_channels=self.num_classes,
972
+ kernel_size=(SF, 3),
973
+ padding=(0, 1),
974
+ )
975
+ self.head = nn.Linear(num_classes, num_classes)
976
+
977
+ if (self.enable_fusion) and (
978
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
979
+ ):
980
+ self.mel_conv1d = nn.Sequential(
981
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
982
+ nn.BatchNorm1d(64),
983
+ )
984
+ if self.fusion_type == "daf_1d":
985
+ self.fusion_model = DAF()
986
+ elif self.fusion_type == "aff_1d":
987
+ self.fusion_model = AFF(channels=64, type="1D")
988
+ elif self.fusion_type == "iaff_1d":
989
+ self.fusion_model = iAFF(channels=64, type="1D")
990
+
991
+ self.apply(self._init_weights)
992
+
993
+ def _init_weights(self, m):
994
+ if isinstance(m, nn.Linear):
995
+ trunc_normal_(m.weight, std=0.02)
996
+ if isinstance(m, nn.Linear) and m.bias is not None:
997
+ nn.init.constant_(m.bias, 0)
998
+ elif isinstance(m, nn.LayerNorm):
999
+ nn.init.constant_(m.bias, 0)
1000
+ nn.init.constant_(m.weight, 1.0)
1001
+
1002
+ @torch.jit.ignore
1003
+ def no_weight_decay(self):
1004
+ return {"absolute_pos_embed"}
1005
+
1006
+ @torch.jit.ignore
1007
+ def no_weight_decay_keywords(self):
1008
+ return {"relative_position_bias_table"}
1009
+
1010
+ def forward_features(self, x, longer_idx=None):
1011
+ # A deprecated optimization for using a hierarchical output from different blocks
1012
+
1013
+ frames_num = x.shape[2]
1014
+ x = self.patch_embed(x, longer_idx=longer_idx)
1015
+ if self.ape:
1016
+ x = x + self.absolute_pos_embed
1017
+ x = self.pos_drop(x)
1018
+ for i, layer in enumerate(self.layers):
1019
+ x, attn = layer(x)
1020
+ # for x
1021
+ x = self.norm(x)
1022
+ B, N, C = x.shape
1023
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1024
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1025
+ x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
1026
+ B, C, F, T = x.shape
1027
+ # group 2D CNN
1028
+ c_freq_bin = F // self.freq_ratio
1029
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1030
+ x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
1031
+ # get latent_output
1032
+ fine_grained_latent_output = torch.mean(x, dim=2)
1033
+ fine_grained_latent_output = interpolate(
1034
+ fine_grained_latent_output.permute(0, 2, 1).contiguous(),
1035
+ 8 * self.patch_stride[1],
1036
+ )
1037
+
1038
+ latent_output = self.avgpool(torch.flatten(x, 2))
1039
+ latent_output = torch.flatten(latent_output, 1)
1040
+
1041
+ # display the attention map, if needed
1042
+
1043
+ x = self.tscam_conv(x)
1044
+ x = torch.flatten(x, 2) # B, C, T
1045
+
1046
+ fpx = interpolate(
1047
+ torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
1048
+ )
1049
+
1050
+ x = self.avgpool(x)
1051
+ x = torch.flatten(x, 1)
1052
+
1053
+ output_dict = {
1054
+ "framewise_output": fpx, # already sigmoided
1055
+ "clipwise_output": torch.sigmoid(x),
1056
+ "fine_grained_embedding": fine_grained_latent_output,
1057
+ "embedding": latent_output,
1058
+ }
1059
+
1060
+ return output_dict
1061
+
1062
+ def crop_wav(self, x, crop_size, spe_pos=None):
1063
+ time_steps = x.shape[2]
1064
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1065
+ for i in range(len(x)):
1066
+ if spe_pos is None:
1067
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1068
+ else:
1069
+ crop_pos = spe_pos
1070
+ tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
1071
+ return tx
1072
+
1073
+ # Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
1074
+ def reshape_wav2img(self, x):
1075
+ B, C, T, F = x.shape
1076
+ target_T = int(self.spec_size * self.freq_ratio)
1077
+ target_F = self.spec_size // self.freq_ratio
1078
+ assert (
1079
+ T <= target_T and F <= target_F
1080
+ ), "the wav size should less than or equal to the swin input size"
1081
+ # to avoid bicubic zero error
1082
+ if T < target_T:
1083
+ x = nn.functional.interpolate(
1084
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1085
+ )
1086
+ if F < target_F:
1087
+ x = nn.functional.interpolate(
1088
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1089
+ )
1090
+ x = x.permute(0, 1, 3, 2).contiguous()
1091
+ x = x.reshape(
1092
+ x.shape[0],
1093
+ x.shape[1],
1094
+ x.shape[2],
1095
+ self.freq_ratio,
1096
+ x.shape[3] // self.freq_ratio,
1097
+ )
1098
+ # print(x.shape)
1099
+ x = x.permute(0, 1, 3, 2, 4).contiguous()
1100
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1101
+ return x
1102
+
1103
+ # Repeat the waveform to an image size, if you want to use the pretrained swin transformer model
1104
+ def repeat_wat2img(self, x, cur_pos):
1105
+ B, C, T, F = x.shape
1106
+ target_T = int(self.spec_size * self.freq_ratio)
1107
+ target_F = self.spec_size // self.freq_ratio
1108
+ assert (
1109
+ T <= target_T and F <= target_F
1110
+ ), "the wav size should less than or equal to the swin input size"
1111
+ # to avoid bicubic zero error
1112
+ if T < target_T:
1113
+ x = nn.functional.interpolate(
1114
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1115
+ )
1116
+ if F < target_F:
1117
+ x = nn.functional.interpolate(
1118
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1119
+ )
1120
+ x = x.permute(0, 1, 3, 2).contiguous() # B C F T
1121
+ x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
1122
+ x = x.repeat(repeats=(1, 1, 4, 1))
1123
+ return x
1124
+
1125
+ def forward(
1126
+ self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
1127
+ ): # out_feat_keys: List[str] = None):
1128
+ if self.enable_fusion and x["longer"].sum() == 0:
1129
+ # if no audio is longer than 10s, then randomly select one audio to be longer
1130
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
1131
+
1132
+ if not self.enable_fusion:
1133
+ x = x["waveform"].to(device=device, non_blocking=True)
1134
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1135
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1136
+ x = x.transpose(1, 3)
1137
+ x = self.bn0(x)
1138
+ x = x.transpose(1, 3)
1139
+ if self.training:
1140
+ x = self.spec_augmenter(x)
1141
+
1142
+ if self.training and mixup_lambda is not None:
1143
+ x = do_mixup(x, mixup_lambda)
1144
+
1145
+ x = self.reshape_wav2img(x)
1146
+ output_dict = self.forward_features(x)
1147
+ else:
1148
+ longer_list = x["longer"].to(device=device, non_blocking=True)
1149
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
1150
+ x = x.transpose(1, 3)
1151
+ x = self.bn0(x)
1152
+ x = x.transpose(1, 3)
1153
+ longer_list_idx = torch.where(longer_list)[0]
1154
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
1155
+ new_x = x[:, 0:1, :, :].clone().contiguous()
1156
+ if len(longer_list_idx) > 0:
1157
+ # local processing
1158
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
1159
+ FB, FC, FT, FF = fusion_x_local.size()
1160
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
1161
+ fusion_x_local = torch.permute(
1162
+ fusion_x_local, (0, 2, 1)
1163
+ ).contiguous()
1164
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
1165
+ fusion_x_local = fusion_x_local.view(
1166
+ FB, FC, FF, fusion_x_local.size(-1)
1167
+ )
1168
+ fusion_x_local = (
1169
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
1170
+ .contiguous()
1171
+ .flatten(2)
1172
+ )
1173
+ if fusion_x_local.size(-1) < FT:
1174
+ fusion_x_local = torch.cat(
1175
+ [
1176
+ fusion_x_local,
1177
+ torch.zeros(
1178
+ (FB, FF, FT - fusion_x_local.size(-1)),
1179
+ device=device,
1180
+ ),
1181
+ ],
1182
+ dim=-1,
1183
+ )
1184
+ else:
1185
+ fusion_x_local = fusion_x_local[:, :, :FT]
1186
+ # 1D fusion
1187
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
1188
+ new_x[longer_list_idx] = self.fusion_model(
1189
+ new_x[longer_list_idx], fusion_x_local
1190
+ )
1191
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
1192
+ else:
1193
+ x = new_x
1194
+
1195
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
1196
+ x = x # no change
1197
+
1198
+ if self.training:
1199
+ x = self.spec_augmenter(x)
1200
+ if self.training and mixup_lambda is not None:
1201
+ x = do_mixup(x, mixup_lambda)
1202
+
1203
+ x = self.reshape_wav2img(x)
1204
+ output_dict = self.forward_features(x, longer_idx=longer_list_idx)
1205
+
1206
+ # if infer_mode:
1207
+ # # in infer mode. we need to handle different length audio input
1208
+ # frame_num = x.shape[2]
1209
+ # target_T = int(self.spec_size * self.freq_ratio)
1210
+ # repeat_ratio = math.floor(target_T / frame_num)
1211
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
1212
+ # x = self.reshape_wav2img(x)
1213
+ # output_dict = self.forward_features(x)
1214
+ # else:
1215
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
1216
+ # if self.training:
1217
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1218
+ # x = self.reshape_wav2img(x)
1219
+ # output_dict = self.forward_features(x)
1220
+ # else:
1221
+ # # Change: Hard code here
1222
+ # overlap_size = (x.shape[2] - 1) // 4
1223
+ # output_dicts = []
1224
+ # crop_size = (x.shape[2] - 1) // 2
1225
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1226
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1227
+ # tx = self.reshape_wav2img(tx)
1228
+ # output_dicts.append(self.forward_features(tx))
1229
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1230
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1231
+ # for d in output_dicts:
1232
+ # clipwise_output += d["clipwise_output"]
1233
+ # framewise_output += d["framewise_output"]
1234
+ # clipwise_output = clipwise_output / len(output_dicts)
1235
+ # framewise_output = framewise_output / len(output_dicts)
1236
+ # output_dict = {
1237
+ # 'framewise_output': framewise_output,
1238
+ # 'clipwise_output': clipwise_output
1239
+ # }
1240
+ # else: # this part is typically used, and most easy one
1241
+ # x = self.reshape_wav2img(x)
1242
+ # output_dict = self.forward_features(x)
1243
+ # x = self.head(x)
1244
+
1245
+ # We process the data in the dataloader, so here we only need to consider the case input_T < fixed_T
1246
+
1247
+ return output_dict
1248
+
1249
+
1250
+ def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
1251
+ try:
1252
+ assert audio_cfg.model_name in [
1253
+ "tiny",
1254
+ "base",
1255
+ "large",
1256
+ ], "model name for HTS-AT is wrong!"
1257
+ if audio_cfg.model_name == "tiny":
1258
+ model = HTSAT_Swin_Transformer(
1259
+ spec_size=256,
1260
+ patch_size=4,
1261
+ patch_stride=(4, 4),
1262
+ num_classes=audio_cfg.class_num,
1263
+ embed_dim=96,
1264
+ depths=[2, 2, 6, 2],
1265
+ num_heads=[4, 8, 16, 32],
1266
+ window_size=8,
1267
+ config=audio_cfg,
1268
+ enable_fusion=enable_fusion,
1269
+ fusion_type=fusion_type,
1270
+ )
1271
+ elif audio_cfg.model_name == "base":
1272
+ model = HTSAT_Swin_Transformer(
1273
+ spec_size=256,
1274
+ patch_size=4,
1275
+ patch_stride=(4, 4),
1276
+ num_classes=audio_cfg.class_num,
1277
+ embed_dim=128,
1278
+ depths=[2, 2, 12, 2],
1279
+ num_heads=[4, 8, 16, 32],
1280
+ window_size=8,
1281
+ config=audio_cfg,
1282
+ enable_fusion=enable_fusion,
1283
+ fusion_type=fusion_type,
1284
+ )
1285
+ elif audio_cfg.model_name == "large":
1286
+ model = HTSAT_Swin_Transformer(
1287
+ spec_size=256,
1288
+ patch_size=4,
1289
+ patch_stride=(4, 4),
1290
+ num_classes=audio_cfg.class_num,
1291
+ embed_dim=256,
1292
+ depths=[2, 2, 12, 2],
1293
+ num_heads=[4, 8, 16, 32],
1294
+ window_size=8,
1295
+ config=audio_cfg,
1296
+ enable_fusion=enable_fusion,
1297
+ fusion_type=fusion_type,
1298
+ )
1299
+
1300
+ return model
1301
+ except Exception:
1302
+ raise RuntimeError(
1303
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
1304
+ )
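# Illustrative sketch, not part of the committed file: in the code above,
# create_htsat_model only reads a handful of fields from the audio config
# (model_name, class_num, sample_rate, window_size, hop_size, mel_bins, fmin,
# fmax). The values below are placeholders for illustration, not the values
# this Space actually uses.
from types import SimpleNamespace

audio_cfg = SimpleNamespace(
    model_name="tiny",     # one of "tiny" / "base" / "large"
    class_num=527,         # AudioSet label count
    sample_rate=48000,
    window_size=1024,      # STFT window / n_fft
    hop_size=480,
    mel_bins=64,
    fmin=50,
    fmax=14000,
)
htsat = create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None")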
audiosr/clap/open_clip/loss.py ADDED
@@ -0,0 +1,397 @@
1
+ import torch
2
+ import torch.distributed.nn
3
+ from torch import distributed as dist, nn as nn
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
7
+
8
+ try:
9
+ import horovod.torch as hvd
10
+ except ImportError:
11
+ hvd = None
12
+
13
+
14
+ def gather_features(
15
+ audio_features,
16
+ text_features,
17
+ audio_features_mlp=None,
18
+ text_features_mlp=None,
19
+ local_loss=False,
20
+ gather_with_grad=False,
21
+ rank=0,
22
+ world_size=1,
23
+ use_horovod=False,
24
+ mlp_loss=False,
25
+ ):
26
+ if use_horovod:
27
+ assert hvd is not None, "Please install horovod"
28
+ if gather_with_grad:
29
+ all_audio_features = hvd.allgather(audio_features)
30
+ all_text_features = hvd.allgather(text_features)
31
+ if mlp_loss:
32
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
33
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
34
+ else:
35
+ with torch.no_grad():
36
+ all_audio_features = hvd.allgather(audio_features)
37
+ all_text_features = hvd.allgather(text_features)
38
+ if mlp_loss:
39
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
40
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
41
+ if not local_loss:
42
+ # ensure grads for local rank when all_* features don't have a gradient
43
+ gathered_audio_features = list(
44
+ all_audio_features.chunk(world_size, dim=0)
45
+ )
46
+ gathered_text_features = list(
47
+ all_text_features.chunk(world_size, dim=0)
48
+ )
49
+ gathered_audio_features[rank] = audio_features
50
+ gathered_text_features[rank] = text_features
51
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
52
+ all_text_features = torch.cat(gathered_text_features, dim=0)
53
+ if mlp_loss:
54
+ gathered_audio_features_mlp = list(
55
+ all_audio_features_mlp.chunk(world_size, dim=0)
56
+ )
57
+ gathered_text_features_mlp = list(
58
+ all_text_features_mlp.chunk(world_size, dim=0)
59
+ )
60
+ gathered_audio_features_mlp[rank] = audio_features_mlp
61
+ gathered_text_features_mlp[rank] = text_features_mlp
62
+ all_audio_features_mlp = torch.cat(
63
+ gathered_audio_features_mlp, dim=0
64
+ )
65
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
66
+ else:
67
+ # We gather tensors from all gpus
68
+ if gather_with_grad:
69
+ all_audio_features = torch.cat(
70
+ torch.distributed.nn.all_gather(audio_features), dim=0
71
+ )
72
+ all_text_features = torch.cat(
73
+ torch.distributed.nn.all_gather(text_features), dim=0
74
+ )
75
+ if mlp_loss:
76
+ all_audio_features_mlp = torch.cat(
77
+ torch.distributed.nn.all_gather(audio_features_mlp), dim=0
78
+ )
79
+ all_text_features_mlp = torch.cat(
80
+ torch.distributed.nn.all_gather(text_features_mlp), dim=0
81
+ )
82
+ else:
83
+ gathered_audio_features = [
84
+ torch.zeros_like(audio_features) for _ in range(world_size)
85
+ ]
86
+ gathered_text_features = [
87
+ torch.zeros_like(text_features) for _ in range(world_size)
88
+ ]
89
+ dist.all_gather(gathered_audio_features, audio_features)
90
+ dist.all_gather(gathered_text_features, text_features)
91
+ if mlp_loss:
92
+ gathered_audio_features_mlp = [
93
+ torch.zeros_like(audio_features_mlp) for _ in range(world_size)
94
+ ]
95
+ gathered_text_features_mlp = [
96
+ torch.zeros_like(text_features_mlp) for _ in range(world_size)
97
+ ]
98
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
99
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
100
+ if not local_loss:
101
+ # ensure grads for local rank when all_* features don't have a gradient
102
+ gathered_audio_features[rank] = audio_features
103
+ gathered_text_features[rank] = text_features
104
+ if mlp_loss:
105
+ gathered_audio_features_mlp[rank] = audio_features_mlp
106
+ gathered_text_features_mlp[rank] = text_features_mlp
107
+
108
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
109
+ all_text_features = torch.cat(gathered_text_features, dim=0)
110
+ if mlp_loss:
111
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
112
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
113
+ if mlp_loss:
114
+ return (
115
+ all_audio_features,
116
+ all_text_features,
117
+ all_audio_features_mlp,
118
+ all_text_features_mlp,
119
+ )
120
+ else:
121
+ return all_audio_features, all_text_features
122
+
123
+
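# Illustrative sketch, not part of the committed file: with world_size == 1 the
# gather step above is skipped entirely and the contrastive objective in
# ClipLoss (below) reduces to a plain symmetric in-batch InfoNCE loss:
import torch
import torch.nn.functional as F

B, D = 8, 512
audio = F.normalize(torch.randn(B, D), dim=-1)
text = F.normalize(torch.randn(B, D), dim=-1)
logit_scale = torch.tensor(1 / 0.07)              # exp(logit_scale_a) at init
logits_per_audio = logit_scale * audio @ text.T   # (B, B) similarity matrix
labels = torch.arange(B)                          # matching pairs on the diagonal
loss = (F.cross_entropy(logits_per_audio, labels)
        + F.cross_entropy(logits_per_audio.T, labels)) / 2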
124
+ class ClipLoss(nn.Module):
125
+ def __init__(
126
+ self,
127
+ local_loss=False,
128
+ gather_with_grad=False,
129
+ cache_labels=False,
130
+ rank=0,
131
+ world_size=1,
132
+ use_horovod=False,
133
+ mlp_loss=False,
134
+ weight_loss_kappa=0,
135
+ ):
136
+ super().__init__()
137
+ self.local_loss = local_loss
138
+ self.gather_with_grad = gather_with_grad
139
+ self.cache_labels = cache_labels
140
+ self.rank = rank
141
+ self.world_size = world_size
142
+ self.use_horovod = use_horovod
143
+ self.mlp_loss = mlp_loss
144
+ self.weighted_loss = bool(weight_loss_kappa != 0)
145
+ self.weight_loss_kappa = weight_loss_kappa
146
+ # cache state
147
+ self.prev_num_logits = 0
148
+ self.labels = {}
149
+
150
+ def forward(
151
+ self,
152
+ audio_features,
153
+ text_features,
154
+ logit_scale_a,
155
+ logit_scale_t=None,
156
+ audio_features_mlp=None,
157
+ text_features_mlp=None,
158
+ ):
159
+ device = audio_features.device
160
+ if self.mlp_loss:
161
+ if self.world_size > 1:
162
+ (
163
+ all_audio_features,
164
+ all_text_features,
165
+ all_audio_features_mlp,
166
+ all_text_features_mlp,
167
+ ) = gather_features(
168
+ audio_features=audio_features,
169
+ text_features=text_features,
170
+ audio_features_mlp=audio_features_mlp,
171
+ text_features_mlp=text_features_mlp,
172
+ local_loss=self.local_loss,
173
+ gather_with_grad=self.gather_with_grad,
174
+ rank=self.rank,
175
+ world_size=self.world_size,
176
+ use_horovod=self.use_horovod,
177
+ mlp_loss=self.mlp_loss,
178
+ )
179
+ if self.local_loss:
180
+ a_logits_per_audio = (
181
+ logit_scale_a * audio_features @ all_text_features_mlp.T
182
+ )
183
+ a_logits_per_text = (
184
+ logit_scale_a * text_features_mlp @ all_audio_features.T
185
+ )
186
+ t_logits_per_audio = (
187
+ logit_scale_t * audio_features_mlp @ all_text_features.T
188
+ )
189
+ t_logits_per_text = (
190
+ logit_scale_t * text_features @ all_audio_features_mlp.T
191
+ )
192
+ else:
193
+ a_logits_per_audio = (
194
+ logit_scale_a * all_audio_features @ all_text_features_mlp.T
195
+ )
196
+ a_logits_per_text = a_logits_per_audio.T
197
+ t_logits_per_audio = (
198
+ logit_scale_t * all_audio_features_mlp @ all_text_features.T
199
+ )
200
+ t_logits_per_text = t_logits_per_audio.T
201
+ else:
202
+ a_logits_per_audio = (
203
+ logit_scale_a * audio_features @ text_features_mlp.T
204
+ )
205
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
206
+ t_logits_per_audio = (
207
+ logit_scale_t * audio_features_mlp @ text_features.T
208
+ )
209
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
210
+
211
+ # calculate ground-truth and cache if enabled
212
+ num_logits = a_logits_per_audio.shape[0]
213
+ if self.prev_num_logits != num_logits or device not in self.labels:
214
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
215
+ if self.world_size > 1 and self.local_loss:
216
+ labels = labels + num_logits * self.rank
217
+ if self.cache_labels:
218
+ self.labels[device] = labels
219
+ self.prev_num_logits = num_logits
220
+ else:
221
+ labels = self.labels[device]
222
+
223
+ if not self.weighted_loss:
224
+ total_loss = (
225
+ F.cross_entropy(a_logits_per_audio, labels)
226
+ + F.cross_entropy(a_logits_per_text, labels)
227
+ + F.cross_entropy(t_logits_per_audio, labels)
228
+ + F.cross_entropy(t_logits_per_text, labels)
229
+ ) / 4
230
+ else:
231
+ audio_weight = (audio_features @ audio_features.T).detach()
232
+ audio_weight = (
233
+ torch.exp(
234
+ torch.sum(audio_weight, axis=1)
235
+ / (self.weight_loss_kappa * len(audio_weight))
236
+ )
237
+ ).detach()
238
+ text_weight = (text_features @ text_features.T).detach()
239
+ text_weight = (
240
+ torch.exp(
241
+ torch.sum(text_weight, axis=1)
242
+ / (self.weight_loss_kappa * len(text_features))
243
+ )
244
+ ).detach()
245
+ total_loss = (
246
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
247
+ + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
248
+ + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
249
+ + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
250
+ ) / 4
251
+ else:
252
+ if self.world_size > 1:
253
+ all_audio_features, all_text_features = gather_features(
254
+ audio_features=audio_features,
255
+ text_features=text_features,
256
+ local_loss=self.local_loss,
257
+ gather_with_grad=self.gather_with_grad,
258
+ rank=self.rank,
259
+ world_size=self.world_size,
260
+ use_horovod=self.use_horovod,
261
+ mlp_loss=self.mlp_loss,
262
+ )
263
+
264
+ if self.local_loss:
265
+ logits_per_audio = (
266
+ logit_scale_a * audio_features @ all_text_features.T
267
+ )
268
+ logits_per_text = (
269
+ logit_scale_a * text_features @ all_audio_features.T
270
+ )
271
+ else:
272
+ logits_per_audio = (
273
+ logit_scale_a * all_audio_features @ all_text_features.T
274
+ )
275
+ logits_per_text = logits_per_audio.T
276
+ else:
277
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
278
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
279
+
280
+ # calculate ground-truth and cache if enabled
281
+ num_logits = logits_per_audio.shape[0]
282
+ if self.prev_num_logits != num_logits or device not in self.labels:
283
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
284
+ if self.world_size > 1 and self.local_loss:
285
+ labels = labels + num_logits * self.rank
286
+ if self.cache_labels:
287
+ self.labels[device] = labels
288
+ self.prev_num_logits = num_logits
289
+ else:
290
+ labels = self.labels[device]
291
+ if not self.weighted_loss:
292
+ total_loss = (
293
+ F.cross_entropy(logits_per_audio, labels)
294
+ + F.cross_entropy(logits_per_text, labels)
295
+ ) / 2
296
+ else:
297
+ audio_weight = (all_audio_features @ all_audio_features.T).detach()
298
+ audio_weight = (
299
+ torch.exp(
300
+ torch.sum(audio_weight, axis=1)
301
+ / (self.weight_loss_kappa * len(all_audio_features))
302
+ )
303
+ ).detach()
304
+ text_weight = (all_text_features @ all_text_features.T).detach()
305
+ text_weight = (
306
+ torch.exp(
307
+ torch.sum(text_weight, axis=1)
308
+ / (self.weight_loss_kappa * len(all_text_features))
309
+ )
310
+ ).detach()
311
+ total_loss = (
312
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight)
313
+ + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
314
+ ) / 2
315
+ return total_loss
316
+
317
+
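# Illustrative sketch, not part of the committed file: single-process usage of
# the ClipLoss module above (mlp_loss disabled); shapes are placeholders.
import torch
import torch.nn.functional as F

loss_fn = ClipLoss(local_loss=False, world_size=1, mlp_loss=False)
audio_features = F.normalize(torch.randn(8, 512), dim=-1)
text_features = F.normalize(torch.randn(8, 512), dim=-1)
logit_scale_a = torch.tensor(1 / 0.07)
total_loss = loss_fn(audio_features, text_features, logit_scale_a)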
318
+ def lp_gather_features(pred, target, world_size=1, use_horovod=False):
319
+ if use_horovod:
320
+ assert hvd is not None, "Please install horovod"
321
+ with torch.no_grad():
322
+ all_preds = hvd.allgather(pred)
323
+ all_targets = hvd.allgather(target)
324
+ else:
325
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
326
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
327
+
328
+ dist.all_gather(gathered_preds, pred)
329
+ dist.all_gather(gathered_targets, target)
330
+ all_preds = torch.cat(gathered_preds, dim=0)
331
+ all_targets = torch.cat(gathered_targets, dim=0)
332
+
333
+ return all_preds, all_targets
334
+
335
+
336
+ def get_map(pred, target):
337
+ pred = torch.sigmoid(pred).numpy()
338
+ target = target.numpy()
339
+ return np.mean(average_precision_score(target, pred, average=None))
340
+
341
+
342
+ def get_acc(pred, target):
343
+ pred = torch.argmax(pred, 1).numpy()
344
+ target = torch.argmax(target, 1).numpy()
345
+ return accuracy_score(target, pred)
346
+
347
+
348
+ def get_mauc(pred, target):
349
+ pred = torch.sigmoid(pred).numpy()
350
+ target = target.numpy()
351
+ return np.mean(roc_auc_score(target, pred, average=None))
352
+
353
+
354
+ class LPMetrics(object):
355
+ def __init__(self, metric_names=["map", "acc", "mauc"]):
356
+ self.metrics = []
357
+ for name in metric_names:
358
+ self.metrics.append(self.get_metric(name))
359
+ self.metric_names = metric_names
360
+
361
+ def get_metric(self, name):
362
+ if name == "map":
363
+ return get_map
364
+ elif name == "acc":
365
+ return get_acc
366
+ elif name == "mauc":
367
+ return get_mauc
368
+ else:
369
+ raise ValueError("the metric should be one of [map, acc, mauc]")
370
+
371
+ def evaluate_mertics(self, pred, target):
372
+ metric_dict = {}
373
+ for i in range(len(self.metric_names)):
374
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
375
+ return metric_dict
376
+
377
+
378
+ def calc_celoss(pred, target):
379
+ target = torch.argmax(target, 1).long()
380
+ return nn.CrossEntropyLoss()(pred, target)
381
+
382
+
383
+ class LPLoss(nn.Module):
384
+ def __init__(self, loss_name):
385
+ super().__init__()
386
+ if loss_name == "bce":
387
+ self.loss_func = nn.BCEWithLogitsLoss()
388
+ elif loss_name == "ce":
389
+ self.loss_func = calc_celoss
390
+ elif loss_name == "mse":
391
+ self.loss_func = nn.MSELoss()
392
+ else:
393
+ raise ValueError("the loss func should be one of [bce, ce, mse]")
394
+
395
+ def forward(self, pred, target):
396
+ loss = self.loss_func(pred, target)
397
+ return loss
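# Illustrative sketch, not part of the committed file: the linear-probe helpers
# above expect CPU tensors, with raw logits in `pred` and multi-hot labels in
# `target` (every class should have at least one positive for the mAP metric).
import torch

num_classes, n = 10, 40
pred = torch.randn(n, num_classes)                     # raw logits
target = torch.zeros(n, num_classes)
target[torch.arange(n), torch.arange(n) % num_classes] = 1.0

metrics = LPMetrics(metric_names=["map", "acc"])
print(metrics.evaluate_mertics(pred, target))          # {'map': ..., 'acc': ...}
print(LPLoss("bce")(pred, target))                     # BCEWithLogitsLoss on the logits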
audiosr/clap/open_clip/model.py ADDED
@@ -0,0 +1,931 @@
1
+ """ CLAP Model
2
+
3
+ Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ Adapted to the Audio Task.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+ from dataclasses import dataclass
9
+ from typing import Tuple, Union, Callable, Optional
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+ import logging
17
+ from .utils import freeze_batch_norm_2d
18
+
19
+ from .pann_model import create_pann_model
20
+ from .htsat import create_htsat_model
21
+ from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
22
+
23
+
24
+ class MLPLayers(nn.Module):
25
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
26
+ super(MLPLayers, self).__init__()
27
+ self.nonlin = nonlin
28
+ self.dropout = dropout
29
+
30
+ sequence = []
31
+ for u0, u1 in zip(units[:-1], units[1:]):
32
+ sequence.append(nn.Linear(u0, u1))
33
+ sequence.append(self.nonlin)
34
+ sequence.append(nn.Dropout(self.dropout))
35
+ sequence = sequence[:-2]
36
+
37
+ self.sequential = nn.Sequential(*sequence)
38
+
39
+ def forward(self, X):
40
+ X = self.sequential(X)
41
+ return X
42
+
43
+
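# Illustrative sketch, not part of the committed file: MLPLayers stacks
# Linear -> nonlin -> Dropout for every consecutive pair in `units`, then trims
# the trailing nonlin/dropout so the final layer is a bare Linear projection.
import torch

mlp = MLPLayers(units=[512, 512, 512], dropout=0.1)
x = torch.randn(4, 512)
print(mlp(x).shape)   # torch.Size([4, 512]); modules: Linear, ReLU, Dropout, Linear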
44
+ class Bottleneck(nn.Module):
45
+ expansion = 4
46
+
47
+ def __init__(self, inplanes, planes, stride=1):
48
+ super().__init__()
49
+
50
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
51
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
52
+ self.bn1 = nn.BatchNorm2d(planes)
53
+
54
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
55
+ self.bn2 = nn.BatchNorm2d(planes)
56
+
57
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
58
+
59
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
60
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
61
+
62
+ self.relu = nn.ReLU(inplace=True)
63
+ self.downsample = None
64
+ self.stride = stride
65
+
66
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
67
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
68
+ self.downsample = nn.Sequential(
69
+ OrderedDict(
70
+ [
71
+ ("-1", nn.AvgPool2d(stride)),
72
+ (
73
+ "0",
74
+ nn.Conv2d(
75
+ inplanes,
76
+ planes * self.expansion,
77
+ 1,
78
+ stride=1,
79
+ bias=False,
80
+ ),
81
+ ),
82
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
83
+ ]
84
+ )
85
+ )
86
+
87
+ def forward(self, x: torch.Tensor):
88
+ identity = x
89
+
90
+ out = self.relu(self.bn1(self.conv1(x)))
91
+ out = self.relu(self.bn2(self.conv2(out)))
92
+ out = self.avgpool(out)
93
+ out = self.bn3(self.conv3(out))
94
+
95
+ if self.downsample is not None:
96
+ identity = self.downsample(x)
97
+
98
+ out += identity
99
+ out = self.relu(out)
100
+ return out
101
+
102
+
103
+ class AttentionPool2d(nn.Module):
104
+ def __init__(
105
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
106
+ ):
107
+ super().__init__()
108
+ self.positional_embedding = nn.Parameter(
109
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
110
+ )
111
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
112
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
113
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
114
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
115
+ self.num_heads = num_heads
116
+
117
+ def forward(self, x):
118
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
119
+ 2, 0, 1
120
+ ) # NCHW -> (HW)NC
121
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
122
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
123
+ x, _ = F.multi_head_attention_forward(
124
+ query=x,
125
+ key=x,
126
+ value=x,
127
+ embed_dim_to_check=x.shape[-1],
128
+ num_heads=self.num_heads,
129
+ q_proj_weight=self.q_proj.weight,
130
+ k_proj_weight=self.k_proj.weight,
131
+ v_proj_weight=self.v_proj.weight,
132
+ in_proj_weight=None,
133
+ in_proj_bias=torch.cat(
134
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
135
+ ),
136
+ bias_k=None,
137
+ bias_v=None,
138
+ add_zero_attn=False,
139
+ dropout_p=0,
140
+ out_proj_weight=self.c_proj.weight,
141
+ out_proj_bias=self.c_proj.bias,
142
+ use_separate_proj_weight=True,
143
+ training=self.training,
144
+ need_weights=False,
145
+ )
146
+
147
+ return x[0]
148
+
149
+
150
+ class ModifiedResNet(nn.Module):
151
+ """
152
+ A ResNet class that is similar to torchvision's but contains the following changes:
153
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
154
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
155
+ - The final pooling layer is a QKV attention instead of an average pool
156
+ """
157
+
158
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
159
+ super().__init__()
160
+ self.output_dim = output_dim
161
+ self.image_size = image_size
162
+
163
+ # the 3-layer stem
164
+ self.conv1 = nn.Conv2d(
165
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
166
+ )
167
+ self.bn1 = nn.BatchNorm2d(width // 2)
168
+ self.conv2 = nn.Conv2d(
169
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
170
+ )
171
+ self.bn2 = nn.BatchNorm2d(width // 2)
172
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
173
+ self.bn3 = nn.BatchNorm2d(width)
174
+ self.avgpool = nn.AvgPool2d(2)
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ # residual layers
178
+ self._inplanes = width # this is a *mutable* variable used during construction
179
+ self.layer1 = self._make_layer(width, layers[0])
180
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
181
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
182
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
183
+
184
+ embed_dim = width * 32 # the ResNet feature dimension
185
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
186
+
187
+ self.init_parameters()
188
+
189
+ def _make_layer(self, planes, blocks, stride=1):
190
+ layers = [Bottleneck(self._inplanes, planes, stride)]
191
+
192
+ self._inplanes = planes * Bottleneck.expansion
193
+ for _ in range(1, blocks):
194
+ layers.append(Bottleneck(self._inplanes, planes))
195
+
196
+ return nn.Sequential(*layers)
197
+
198
+ def init_parameters(self):
199
+ if self.attnpool is not None:
200
+ std = self.attnpool.c_proj.in_features**-0.5
201
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
202
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
203
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
204
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
205
+
206
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
207
+ for name, param in resnet_block.named_parameters():
208
+ if name.endswith("bn3.weight"):
209
+ nn.init.zeros_(param)
210
+
211
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
212
+ assert (
213
+ unlocked_groups == 0
214
+ ), "partial locking not currently supported for this model"
215
+ for param in self.parameters():
216
+ param.requires_grad = False
217
+ if freeze_bn_stats:
218
+ freeze_batch_norm_2d(self)
219
+
220
+ def stem(self, x):
221
+ for conv, bn in [
222
+ (self.conv1, self.bn1),
223
+ (self.conv2, self.bn2),
224
+ (self.conv3, self.bn3),
225
+ ]:
226
+ x = self.relu(bn(conv(x)))
227
+ x = self.avgpool(x)
228
+ return x
229
+
230
+ def forward(self, x):
231
+ x = self.stem(x)
232
+ x = self.layer1(x)
233
+ x = self.layer2(x)
234
+ x = self.layer3(x)
235
+ x = self.layer4(x)
236
+ x = self.attnpool(x)
237
+
238
+ return x
239
+
240
+
241
+ class LayerNorm(nn.LayerNorm):
242
+ """Subclass torch's LayerNorm to handle fp16."""
243
+
244
+ def forward(self, x: torch.Tensor):
245
+ orig_type = x.dtype
246
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
247
+ return x.to(orig_type)
248
+
249
+
250
+ class QuickGELU(nn.Module):
251
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
252
+ def forward(self, x: torch.Tensor):
253
+ return x * torch.sigmoid(1.702 * x)
254
+
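# Illustrative sketch, not part of the committed file: QuickGELU is the sigmoid
# approximation x * sigmoid(1.702 * x) used by the original CLIP weights; it is
# close to, but not identical to, the exact nn.GELU activation.
import torch
from torch import nn

x = torch.linspace(-3.0, 3.0, 7)
print(QuickGELU()(x))
print(nn.GELU()(x))   # small numerical differences from the approximation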
255
+
256
+ class ResidualAttentionBlock(nn.Module):
257
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
258
+ super().__init__()
259
+
260
+ self.attn = nn.MultiheadAttention(d_model, n_head)
261
+ self.ln_1 = LayerNorm(d_model)
262
+ self.mlp = nn.Sequential(
263
+ OrderedDict(
264
+ [
265
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
266
+ ("gelu", act_layer()),
267
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
268
+ ]
269
+ )
270
+ )
271
+ self.ln_2 = LayerNorm(d_model)
272
+
273
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
274
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
275
+
276
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
278
+ x = x + self.mlp(self.ln_2(x))
279
+ return x
280
+
281
+
282
+ class Transformer(nn.Module):
283
+ def __init__(
284
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
285
+ ):
286
+ super().__init__()
287
+ self.width = width
288
+ self.layers = layers
289
+ self.resblocks = nn.ModuleList(
290
+ [
291
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
292
+ for _ in range(layers)
293
+ ]
294
+ )
295
+
296
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
297
+ for r in self.resblocks:
298
+ x = r(x, attn_mask=attn_mask)
299
+ return x
300
+
301
+
302
+ class VisualTransformer(nn.Module):
303
+ def __init__(
304
+ self,
305
+ image_size: int,
306
+ patch_size: int,
307
+ width: int,
308
+ layers: int,
309
+ heads: int,
310
+ output_dim: int,
311
+ act_layer: Callable = nn.GELU,
312
+ ):
313
+ super().__init__()
314
+ self.image_size = image_size
315
+ self.output_dim = output_dim
316
+ self.conv1 = nn.Conv2d(
317
+ in_channels=3,
318
+ out_channels=width,
319
+ kernel_size=patch_size,
320
+ stride=patch_size,
321
+ bias=False,
322
+ )
323
+
324
+ scale = width**-0.5
325
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
326
+ self.positional_embedding = nn.Parameter(
327
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
328
+ )
329
+ self.ln_pre = LayerNorm(width)
330
+
331
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
332
+
333
+ self.ln_post = LayerNorm(width)
334
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
335
+
336
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
337
+ assert (
338
+ unlocked_groups == 0
339
+ ), "partial locking not currently supported for this model"
340
+ for param in self.parameters():
341
+ param.requires_grad = False
342
+
343
+ def forward(self, x: torch.Tensor):
344
+ x = self.conv1(x) # shape = [*, width, grid, grid]
345
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
346
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
347
+ x = torch.cat(
348
+ [
349
+ self.class_embedding.to(x.dtype)
350
+ + torch.zeros(
351
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
352
+ ),
353
+ x,
354
+ ],
355
+ dim=1,
356
+ ) # shape = [*, grid ** 2 + 1, width]
357
+ x = x + self.positional_embedding.to(x.dtype)
358
+ x = self.ln_pre(x)
359
+
360
+ x = x.permute(1, 0, 2) # NLD -> LND
361
+ x = self.text_branch(x)
362
+ x = x.permute(1, 0, 2) # LND -> NLD
363
+
364
+ x = self.ln_post(x[:, 0, :])
365
+
366
+ if self.proj is not None:
367
+ x = x @ self.proj
368
+
369
+ return x
370
+
371
+
372
+ @dataclass
373
+ class CLAPVisionCfg:
374
+ layers: Union[Tuple[int, int, int, int], int] = 12
375
+ width: int = 768
376
+ patch_size: int = 16
377
+ image_size: Union[Tuple[int, int], int] = 224
378
+ timm_model_name: str = (
379
+ None # a valid model name overrides layers, width, patch_size
380
+ )
381
+ timm_model_pretrained: bool = (
382
+ False # use (imagenet) pretrained weights for named model
383
+ )
384
+ timm_pool: str = (
385
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
386
+ )
387
+ timm_proj: str = (
388
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
389
+ )
390
+
391
+
392
+ # Audio Config Class
393
+ @dataclass
394
+ class CLAPAudioCfp:
395
+ model_type: str = "PANN"
396
+ model_name: str = "Cnn14"
397
+ sample_rate: int = 48000
398
+ # Param
399
+ audio_length: int = 1024
400
+ window_size: int = 1024
401
+ hop_size: int = 1024
402
+ fmin: int = 50
403
+ fmax: int = 14000
404
+ class_num: int = 527
405
+ mel_bins: int = 64
406
+ clip_samples: int = 480000
407
+
408
+
409
+ @dataclass
410
+ class CLAPTextCfg:
411
+ context_length: int
412
+ vocab_size: int
413
+ width: int
414
+ heads: int
415
+ layers: int
416
+ model_type: str
417
+
418
+
419
+ class CLAP(nn.Module):
420
+ def __init__(
421
+ self,
422
+ embed_dim: int,
423
+ audio_cfg: CLAPAudioCfp,
424
+ text_cfg: CLAPTextCfg,
425
+ quick_gelu: bool = False,
426
+ enable_fusion: bool = False,
427
+ fusion_type: str = "None",
428
+ joint_embed_shape: int = 512,
429
+ mlp_act: str = "relu",
430
+ ):
431
+ super().__init__()
432
+ if isinstance(audio_cfg, dict):
433
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
434
+ if isinstance(text_cfg, dict):
435
+ text_cfg = CLAPTextCfg(**text_cfg)
436
+
437
+ self.audio_cfg = audio_cfg
438
+ self.text_cfg = text_cfg
439
+ self.enable_fusion = enable_fusion
440
+ self.fusion_type = fusion_type
441
+ self.joint_embed_shape = joint_embed_shape
442
+ self.mlp_act = mlp_act
443
+
444
+ self.context_length = text_cfg.context_length
445
+
446
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
447
+ # memory efficient in recent PyTorch releases (>= 1.10).
448
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
449
+ act_layer = QuickGELU if quick_gelu else nn.GELU
450
+
451
+ if mlp_act == "relu":
452
+ mlp_act_layer = nn.ReLU()
453
+ elif mlp_act == "gelu":
454
+ mlp_act_layer = nn.GELU()
455
+ else:
456
+ raise NotImplementedError
457
+
458
+ # audio branch
459
+ # audio branch parameters
460
+ if audio_cfg.model_type == "PANN":
461
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
462
+ elif audio_cfg.model_type == "HTSAT":
463
+ self.audio_branch = create_htsat_model(
464
+ audio_cfg, enable_fusion, fusion_type
465
+ )
466
+ else:
467
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
468
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
469
+
470
+ # text branch
471
+ # text branch parameters
472
+ if text_cfg.model_type == "transformer":
473
+ self.text_branch = Transformer(
474
+ width=text_cfg.width,
475
+ layers=text_cfg.layers,
476
+ heads=text_cfg.heads,
477
+ act_layer=act_layer,
478
+ )
479
+ self.vocab_size = text_cfg.vocab_size
480
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
481
+ self.positional_embedding = nn.Parameter(
482
+ torch.empty(self.context_length, text_cfg.width)
483
+ )
484
+ self.ln_final = LayerNorm(text_cfg.width)
485
+ self.text_transform = MLPLayers(
486
+ units=[
487
+ self.joint_embed_shape,
488
+ self.joint_embed_shape,
489
+ self.joint_embed_shape,
490
+ ],
491
+ dropout=0.1,
492
+ )
493
+ self.text_projection = nn.Sequential(
494
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
495
+ mlp_act_layer,
496
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
497
+ )
498
+ elif text_cfg.model_type == "bert":
499
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
500
+ self.text_transform = MLPLayers(
501
+ units=[
502
+ self.joint_embed_shape,
503
+ self.joint_embed_shape,
504
+ self.joint_embed_shape,
505
+ ],
506
+ dropout=0.1,
507
+ )
508
+ self.text_projection = nn.Sequential(
509
+ nn.Linear(768, self.joint_embed_shape),
510
+ mlp_act_layer,
511
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
512
+ )
513
+ elif text_cfg.model_type == "roberta":
514
+ self.text_branch = RobertaModel(
515
+ RobertaConfig.from_pretrained("roberta-base")
516
+ )
517
+ self.text_transform = MLPLayers(
518
+ units=[
519
+ self.joint_embed_shape,
520
+ self.joint_embed_shape,
521
+ self.joint_embed_shape,
522
+ ],
523
+ dropout=0.1,
524
+ )
525
+ self.text_projection = nn.Sequential(
526
+ nn.Linear(768, self.joint_embed_shape),
527
+ mlp_act_layer,
528
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
529
+ )
530
+ elif text_cfg.model_type == "bart":
531
+ self.text_branch = BartModel.from_pretrained("facebook/bart-base")
532
+ self.text_transform = MLPLayers(
533
+ units=[
534
+ self.joint_embed_shape,
535
+ self.joint_embed_shape,
536
+ self.joint_embed_shape,
537
+ ],
538
+ dropout=0.1,
539
+ )
540
+ self.text_projection = nn.Sequential(
541
+ nn.Linear(768, self.joint_embed_shape),
542
+ mlp_act_layer,
543
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
544
+ )
545
+ else:
546
+ logging.error(f"Model config for {text_cfg.model_type} not found")
547
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
548
+ self.text_branch_type = text_cfg.model_type
549
+ # text branch parameters
550
+
551
+ # audio branch parameters
552
+ self.audio_transform = MLPLayers(
553
+ units=[
554
+ self.joint_embed_shape,
555
+ self.joint_embed_shape,
556
+ self.joint_embed_shape,
557
+ ],
558
+ dropout=0.1,
559
+ )
560
+
561
+ # below here is text branch parameters
562
+
563
+ # ============================================================================================================
564
+ self.audio_projection = nn.Sequential(
565
+ nn.Linear(embed_dim, self.joint_embed_shape),
566
+ mlp_act_layer,
567
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
568
+ )
569
+
570
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
571
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
572
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
573
+
574
+ self.init_text_branch_parameters()
575
+
576
+ def init_text_branch_parameters(self):
577
+ if self.text_branch_type == "transformer":
578
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
579
+ nn.init.normal_(self.positional_embedding, std=0.01)
580
+ proj_std = (self.text_branch.width**-0.5) * (
581
+ (2 * self.text_branch.layers) ** -0.5
582
+ )
583
+ attn_std = self.text_branch.width**-0.5
584
+ fc_std = (2 * self.text_branch.width) ** -0.5
585
+ for block in self.text_branch.resblocks:
586
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
587
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
588
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
589
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
590
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
591
+ self.text_branch.embeddings.word_embeddings.weight.shape[-1]
592
+ elif self.text_branch_type == "bart":
593
+ self.text_branch.shared.weight.shape[-1]
594
+ else:
595
+ self.text_branch.width
596
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
597
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
598
+
599
+ # deprecated
600
+ # if hasattr(self.visual, 'init_parameters'):
601
+ # self.visual.init_parameters()
602
+
603
+ # if self.text_projection is not None:
604
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
605
+
606
+ def build_attention_mask(self):
607
+ # lazily create causal attention mask, with full attention between the vision tokens
608
+ # pytorch uses additive attention mask; fill with -inf
609
+ mask = torch.empty(self.context_length, self.context_length)
610
+ mask.fill_(float("-inf"))
611
+ mask.triu_(1) # zero out the lower diagonal
612
+ return mask
613
+
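# Illustrative sketch, not part of the committed file: the additive causal mask
# built above keeps 0 on and below the diagonal and -inf strictly above it,
# e.g. for a context length of 4:
import torch

mask = torch.empty(4, 4).fill_(float("-inf")).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])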
614
+ def encode_audio(self, audio, device):
615
+ return self.audio_branch(
616
+ audio, mixup_lambda=None, device=device
617
+ ) # mix lambda needs to add
618
+
619
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
620
+ # tmp = {}
621
+ # for k in x[0].keys():
622
+ # tmp[k] = []
623
+ # for i in range(len(x)):
624
+ # tmp[k].append(x[i][k][:77])
625
+ # for k in x[0].keys():
626
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
627
+ # return tmp
628
+
629
+ def encode_text(self, text, device):
630
+ if self.text_branch_type == "transformer":
631
+ text = text.to(device=device, non_blocking=True)
632
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
633
+
634
+ x = x + self.positional_embedding
635
+ x = x.permute(1, 0, 2) # NLD -> LND
636
+ x = self.text_branch(x, attn_mask=self.attn_mask)
637
+ x = x.permute(1, 0, 2) # LND -> NLD
638
+ x = self.ln_final(x)
639
+
640
+ # x.shape = [batch_size, n_ctx, transformer.width]
641
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
642
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
643
+ elif self.text_branch_type == "bert":
644
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
645
+ # text = BatchEncoding(text)
646
+ x = self.text_branch(
647
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
648
+ attention_mask=text["attention_mask"].to(
649
+ device=device, non_blocking=True
650
+ ),
651
+ token_type_ids=text["token_type_ids"].to(
652
+ device=device, non_blocking=True
653
+ ),
654
+ )["pooler_output"]
655
+ x = self.text_projection(x)
656
+ elif self.text_branch_type == "roberta":
657
+ x = self.text_branch(
658
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
659
+ attention_mask=text["attention_mask"].to(
660
+ device=device, non_blocking=True
661
+ ),
662
+ )["pooler_output"]
663
+ x = self.text_projection(x)
664
+ elif self.text_branch_type == "bart":
665
+ x = torch.mean(
666
+ self.text_branch(
667
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
668
+ attention_mask=text["attention_mask"].to(
669
+ device=device, non_blocking=True
670
+ ),
671
+ )["encoder_last_hidden_state"],
672
+ axis=1,
673
+ )
674
+ x = self.text_projection(x)
675
+ else:
676
+ logging.error(f"Model type {self.text_branch_type} not found")
677
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
678
+ return x
679
+
+    def forward(self, audio, text, device=None):
+        """Forward audio and text into the CLAP
+
+        Parameters
+        ----------
+        audio: torch.Tensor (batch_size, audio_length)
+            the time-domain audio input / the batch of mel_spec and longer list.
+        text: torch.Tensor () // need to add
+            the text token input
+        """
+        if device is None:
+            if audio is not None:
+                device = audio.device
+            elif text is not None:
+                device = text.device
+        if audio is None and text is None:
+            # a hack to get the logit scale
+            return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+        elif audio is None:
+            return self.encode_text(text, device=device)
+        elif text is None:
+            return self.audio_projection(
+                self.encode_audio(audio, device=device)["embedding"]
+            )
+        audio_features = self.audio_projection(
+            self.encode_audio(audio, device=device)["embedding"]
+        )
+        audio_features = F.normalize(audio_features, dim=-1)
+
+        text_features = self.encode_text(text, device=device)
+        # print("text_features", text_features)
+        # print("text_features.shape", text_features.shape)
+        # print("text_features.type", type(text_features))
+        text_features = F.normalize(text_features, dim=-1)
+
+        audio_features_mlp = self.audio_transform(audio_features)
+        text_features_mlp = self.text_transform(text_features)
+        # Four outputs: audio features (basic & MLP), text features (basic & MLP)
+        return (
+            audio_features,
+            text_features,
+            audio_features_mlp,
+            text_features_mlp,
+            self.logit_scale_a.exp(),
+            self.logit_scale_t.exp(),
+        )
+
+    def get_logit_scale(self):
+        return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+
+    def get_text_embedding(self, data):
+        """Get the text embedding from the model
+
+        Parameters
+        ----------
+        data: torch.Tensor
+            a tensor of text embedding
+
+        Returns
+        ----------
+        text_embed: torch.Tensor
+            a tensor of text_embeds (N, D)
+
+        """
+        device = next(self.parameters()).device
+        for k in data:
+            data[k] = data[k].to(device)
+        text_embeds = self.encode_text(data, device=device)
+        text_embeds = F.normalize(text_embeds, dim=-1)
+
+        return text_embeds
+
+    def get_audio_embedding(self, data):
+        """Get the audio embedding from the model
+
+        Parameters
+        ----------
+        data: a list of dict
+            the audio input dict list from 'get_audio_feature' method
+
+        Returns
+        ----------
+        audio_embed: torch.Tensor
+            a tensor of audio_embeds (N, D)
+
+        """
+        device = next(self.parameters()).device
+        # input_dict = {}
+        # keys = data[0].keys()
+        # for k in keys:
+        #     input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
+        #         device
+        #     )
+        audio_embeds = self.audio_projection(
+            self.encode_audio(data, device=device)["embedding"]
+        )
+        audio_embeds = F.normalize(audio_embeds, dim=-1)
+
+        return audio_embeds
+
+    def audio_infer(self, audio, hopsize=None, device=None):
+        """Forward one audio and produce the audio embedding
+
+        Parameters
+        ----------
+        audio: (audio_length)
+            the time-domain audio input, notice that it must be only one input
+        hopsize: int
+            the overlap hopsize as the sliding window
+
+        Returns
+        ----------
+        output_dict: {
+            key: [n, (embedding_shape)] if "HTS-AT"
+            or
+            key: [(embedding_shape)] if "PANN"
+        }
+            the list of key values of the audio branch
+
+        """
+
+        assert not self.training, "the inference mode must be run at eval stage"
+        output_dict = {}
+        # PANN
+        if self.audio_cfg.model_type == "PANN":
+            audio_input = audio.unsqueeze(dim=0)
+            output_dict[key] = self.encode_audio(audio_input, device=device)[
+                key
+            ].squeeze(dim=0)
+        elif self.audio_cfg.model_type == "HTSAT":
+            # repeat
+            audio_len = len(audio)
+            k = self.audio_cfg.clip_samples // audio_len
+            if k > 1:
+                audio = audio.repeat(k)
+                audio_len = len(audio)
+
+            if hopsize is None:
+                # default to one full-length hop when no hopsize is given
+                hopsize = audio_len
+            hopsize = min(hopsize, audio_len)
+
+            if audio_len > self.audio_cfg.clip_samples:
+                audio_input = [
+                    audio[pos : pos + self.audio_cfg.clip_samples].clone()
+                    for pos in range(
+                        0, audio_len - self.audio_cfg.clip_samples, hopsize
+                    )
+                ]
+                audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
+                audio_input = torch.stack(audio_input)
+                output_dict[key] = self.encode_audio(audio_input, device=device)[key]
+            else:
+                audio_input = audio.unsqueeze(dim=0)
+                output_dict[key] = self.encode_audio(audio_input, device=device)[
+                    key
+                ].squeeze(dim=0)
+
+        return output_dict
+
+
+def convert_weights_to_fp16(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+        if isinstance(l, nn.MultiheadAttention):
+            for attr in [
+                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+                "in_proj_bias",
+                "bias_k",
+                "bias_v",
+            ]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.half()
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+# Ignore the state dict of the vision part
+def build_model_from_openai_state_dict(
+    state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
+):
+    embed_dim = model_cfg["embed_dim"]
+    audio_cfg = model_cfg["audio_cfg"]
+    text_cfg = model_cfg["text_cfg"]
+    state_dict["positional_embedding"].shape[0]
+    state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_width // 64
+    transformer_layers = len(
+        set(
+            k.split(".")[2]
+            for k in state_dict
+            if k.startswith(f"transformer.resblocks")
+        )
+    )
+
+    audio_cfg = CLAPAudioCfp(**audio_cfg)
+    text_cfg = CLAPTextCfg(**text_cfg)
+
+    model = CLAP(
+        embed_dim,
+        audio_cfg=audio_cfg,
+        text_cfg=text_cfg,
+        quick_gelu=True,  # OpenAI models were trained with QuickGELU
+        enable_fusion=enable_fusion,
+        fusion_type=fusion_type,
+    )
+    state_dict["logit_scale_a"] = state_dict["logit_scale"]
+    state_dict["logit_scale_t"] = state_dict["logit_scale"]
+    pop_keys = list(state_dict.keys())[::]
+    # pop the visual branch saved weights
+    for key in pop_keys:
+        if key.startswith("visual."):
+            state_dict.pop(key, None)
+
+    for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
+        state_dict.pop(key, None)
+
+    # not use fp16
+    # convert_weights_to_fp16(model)
+    model.load_state_dict(state_dict, strict=False)
+    return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device("cpu")):
+    model.eval()
+    audio_length = model.audio_cfg.audio_length
+    example_audio = torch.ones((batch_size, audio_length), device=device)
+    example_text = torch.zeros(
+        (batch_size, model.context_length), dtype=torch.int, device=device
+    )
+    model = torch.jit.trace_module(
+        model,
+        inputs=dict(
+            forward=(example_audio, example_text),
+            encode_text=(example_text,),
+            encode_image=(example_audio,),
+        ),
+    )
+    model.audio_cfg.audio_length = audio_length  # Question: what does this do?
+    return model
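
A minimal usage sketch of the embedding API defined above (illustrative only, not part of the committed file): it assumes an already-built, eval-mode `CLAP` instance with a RoBERTa text branch and an HTSAT audio branch, and the placeholder tensors plus the `"waveform"` key are assumptions about the expected input dicts rather than a confirmed API.

```python
import torch

# `clap` is assumed to be an already-constructed CLAP instance (built elsewhere),
# configured with a RoBERTa text branch and an HTSAT audio branch.
clap = clap.eval()

with torch.no_grad():
    # Text side: a tokenizer output dict; the shapes here are placeholders.
    text_batch = {
        "input_ids": torch.zeros(2, 77, dtype=torch.long),
        "attention_mask": torch.ones(2, 77, dtype=torch.long),
    }
    text_embed = clap.get_text_embedding(text_batch)     # (N, joint_embed_shape), L2-normalized

    # Audio side: dict consumed by the HTSAT branch; "waveform" is an assumed key.
    audio_batch = {"waveform": torch.zeros(2, 480000)}   # 10 s at 48 kHz, matching the configs below
    audio_embed = clap.get_audio_embedding(audio_batch)  # (N, joint_embed_shape), L2-normalized

    # Temperature-scaled cosine similarities between every clip and every caption.
    logit_scale_a, _ = clap.get_logit_scale()
    logits = logit_scale_a * audio_embed @ text_embed.t()
```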
audiosr/clap/open_clip/model_configs/HTSAT-base.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 1024,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "base"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
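
The two nested blocks in the config above are what `model.py` unpacks into `CLAPAudioCfp(**audio_cfg)` and `CLAPTextCfg(**text_cfg)` (see `build_model_from_openai_state_dict`). As a quick, non-authoritative reading of the audio numbers, a sketch like this (file path assumed relative to the repository root) recovers the implied clip length and frame count:

```python
import json

with open("audiosr/clap/open_clip/model_configs/HTSAT-base.json") as f:
    model_cfg = json.load(f)

audio_cfg = model_cfg["audio_cfg"]
clip_seconds = audio_cfg["clip_samples"] / audio_cfg["sample_rate"]   # 480000 / 48000 = 10.0 s per clip
hop_seconds = audio_cfg["hop_size"] / audio_cfg["sample_rate"]        # 480 / 48000 = 10 ms between frames
frames_per_clip = audio_cfg["clip_samples"] // audio_cfg["hop_size"]  # 1000 frames of 64 mel bins each
print(clip_seconds, hop_seconds, frames_per_clip)
```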
audiosr/clap/open_clip/model_configs/HTSAT-large.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "large"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 768,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1536,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "tiny"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/HTSAT-tiny.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 768,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "tiny"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/PANN-10.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 1024,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn10"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/PANN-14-fmax-18k.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 18000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 960000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 360,
+        "fmin": 50,
+        "fmax": 8000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/PANN-14-tiny-transformer.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 4
+    }
+}
audiosr/clap/open_clip/model_configs/PANN-14-win-1536.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1536,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/PANN-14.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/PANN-6.json ADDED
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 512,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn6"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 512,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            23,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            23,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 1024,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            6,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            6,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 384,
+        "layers": [
+            6,
+            8,
+            18,
+            8
+        ],
+        "width": 96,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 288,
+        "layers": [
+            4,
+            6,
+            10,
+            6
+        ],
+        "width": 80,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 512,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
audiosr/clap/open_clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
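
All of the configs above share one top-level schema: `embed_dim`, a `text_cfg`, and either an `audio_cfg` (the HTSAT/PANN audio towers) or a `vision_cfg` (the ResNet/ViT image towers apparently carried over from open_clip). A small sketch for enumerating them, assuming the repository root as the working directory:

```python
import glob
import json
import os

# List every model config shipped under model_configs/ and report which tower it describes.
for path in sorted(glob.glob("audiosr/clap/open_clip/model_configs/*.json")):
    with open(path) as f:
        cfg = json.load(f)
    tower = "audio" if "audio_cfg" in cfg else "vision"
    print(f"{os.path.basename(path):32s} {tower:6s} embed_dim={cfg['embed_dim']}")
```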