File size: 2,108 Bytes
c2f4ff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
'''
Transform an augmented model back to the normal HF-supported version,
i.e. remove the first (augmentation) module from the conv feature extractor.
'''
from augmentation import AUG

from torch import nn
import transformers
from transformers import Wav2Vec2ForCTC, AutoModelForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2NoLayerNormConvLayer, Wav2Vec2LayerNormConvLayer, Wav2Vec2GroupNormConvLayer


def patch_init(cls):
    """Monkey-patch ``cls.__init__`` so the conv feature extractor is built
    with an augmentation module (AUG) prepended as its first layer.

    This mirrors the upstream wav2vec2 conv-stack construction and inserts
    AUG at position 0, so a checkpoint saved with the augmented architecture
    can be loaded without key mismatches.

    Args:
        cls: the feature-extractor class to patch (expected to be
            ``Wav2Vec2FeatureExtractor``, an ``nn.Module`` subclass).
    """
    # Provide a closure cell named __class__ so zero-argument super() works
    # inside new_init even though it is defined outside the class body.
    __class__ = cls

    def new_init(self, config):
        # nn.Module state must be initialised before any submodule is
        # assigned; without this, `self.conv_layers = ...` raises
        # AttributeError. NOTE(review): the upstream __init__ may also set
        # extra flags (e.g. `_requires_grad`) — confirm against the installed
        # transformers version if freezing/gradient APIs are used.
        super().__init__()
        if config.feat_extract_norm == "group":
            conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
                Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1)
                for i in range(config.num_feat_extract_layers - 1)
            ]
        elif config.feat_extract_norm == "layer":
            conv_layers = [
                Wav2Vec2LayerNormConvLayer(config, layer_id=i)
                for i in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        # Prepend the augmentation module so it runs before the conv stack.
        conv_layers.insert(0, AUG())
        self.conv_layers = nn.ModuleList(conv_layers)

    cls.__init__ = new_init

# Patch the class BEFORE from_pretrained so the instantiated architecture
# has the extra first module and the augmented checkpoint's weights line up.
patch_init(transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureExtractor)

# Load the augmented checkpoint from the current directory.
model = Wav2Vec2ForCTC.from_pretrained(".")

# Restore the stock architecture: drop the prepended AUG module. Use
# nn.ModuleList to match the type the vanilla HF class uses for
# `conv_layers`, so the saved model is loadable by plain transformers.
model.wav2vec2.feature_extractor.conv_layers = nn.ModuleList(
    list(model.wav2vec2.feature_extractor.conv_layers.children())[1:]
)

# Overwrite the checkpoint in place with the de-augmented weights.
model.save_pretrained(".")

"""
replace with temprarily then save model, patching didn't work, get loaded so far for some reason.
"conv_dim": [
    1,
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    5,
    2,
    2,
    2,
    2,
    2,
    2
"""