Edoardo-BS committed
Commit a6e10d3
1 Parent(s): 493b96d

Upload HuBERTECG

Files changed (2):
  1. config.json +4 -0
  2. hubert_ecg.py +47 -0
config.json CHANGED
@@ -5,6 +5,10 @@
     "HuBERTECG"
   ],
   "attention_dropout": 0.1,
+  "auto_map": {
+    "AutoConfig": "hubert_ecg.HuBERTECGConfig",
+    "AutoModel": "hubert_ecg.HuBERTECG"
+  },
   "bos_token_id": 1,
   "classifier_proj_size": 256,
   "conv_bias": false,
hubert_ecg.py ADDED
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+from transformers import HubertConfig, HubertModel
+from typing import List
+
+class HuBERTECGConfig(HubertConfig):
+
+    model_type = "hubert_ecg"
+
+    def __init__(self, ensemble_length: int = 1, vocab_sizes: List[int] = [100], **kwargs):
+        super().__init__(**kwargs)
+        self.ensemble_length = ensemble_length
+        self.vocab_sizes = vocab_sizes if isinstance(vocab_sizes, list) else [vocab_sizes]
+
+class HuBERTECG(HubertModel):
+
+    config_class = HuBERTECGConfig
+
+    def __init__(self, config: HuBERTECGConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.pretraining_vocab_sizes = config.vocab_sizes
+
+        assert config.ensemble_length > 0 and config.ensemble_length == len(config.vocab_sizes), f"ensemble_length {config.ensemble_length} must be equal to len(vocab_sizes) {len(config.vocab_sizes)}"
+
+        # final projection layer to map encodings into the space of the codebook
+        self.final_proj = nn.ModuleList([nn.Linear(config.hidden_size, config.classifier_proj_size) for _ in range(config.ensemble_length)])
+
+        # embedding for codebooks
+        self.label_embedding = nn.ModuleList([nn.Embedding(vocab_size, config.classifier_proj_size) for vocab_size in config.vocab_sizes])
+
+        assert len(self.final_proj) == len(self.label_embedding), f"final_proj and label_embedding must have the same length"
+
+    def logits(self, transformer_output: torch.Tensor) -> torch.Tensor:
+        # takes (B, T, D)
+
+        # compute a projected output for each ensemble
+        projected_outputs = [final_projection(transformer_output) for final_projection in self.final_proj]
+
+        ensemble_logits = [torch.cosine_similarity(
+            projected_output.unsqueeze(2),
+            label_emb.weight.unsqueeze(0).unsqueeze(0),
+            dim=-1,
+        ) / 0.1 for projected_output, label_emb in zip(projected_outputs, self.label_embedding)]
+
+        return ensemble_logits # returns [(BS, T, V)] * ensemble_length
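
HuBERTECG.logits does not run the encoder itself: it projects an encoder output of shape (B, T, D) with one head per ensemble member and scores each frame against every codebook entry via cosine similarity divided by a fixed temperature of 0.1, mirroring the temperature-scaled logits used in HuBERT pretraining. A minimal sketch of exercising the classes end-to-end; the ensemble size, vocabulary sizes, batch size, and signal length below are illustrative assumptions, not values from this commit:

import torch
from hubert_ecg import HuBERTECG, HuBERTECGConfig

# Illustrative hyperparameters (assumptions, not from the commit).
config = HuBERTECGConfig(ensemble_length=2, vocab_sizes=[100, 200])
model = HuBERTECG(config)
model.eval()

# Dummy batch of 2 one-dimensional ECG signals, 5000 samples each.
dummy_signal = torch.randn(2, 5000)

with torch.no_grad():
    encoder_out = model(dummy_signal).last_hidden_state  # (B, T, hidden_size)
    ensemble_logits = model.logits(encoder_out)          # list of (B, T, vocab_size) tensors

print([tuple(t.shape) for t in ensemble_logits])  # e.g. [(2, T, 100), (2, T, 200)]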
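
The auto_map block added to config.json registers these classes with the transformers Auto API, so the repository can be loaded without importing hubert_ecg.py by hand. A minimal loading sketch, assuming the files above are pushed to a Hub repository; the repository id below is a placeholder, and trust_remote_code=True is required because the model code is executed from the repo:

from transformers import AutoConfig, AutoModel

repo_id = "Edoardo-BS/HuBERT-ECG"  # placeholder: substitute the actual Hub repository id

# auto_map resolves AutoConfig -> hubert_ecg.HuBERTECGConfig and
# AutoModel -> hubert_ecg.HuBERTECG, both defined in hubert_ecg.py.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)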