'''
Transform an augmented model back to the stock HF-supported version,
i.e., remove the extra first module from the feature extractor.
'''
from augmentation import AUG
from torch import nn
import transformers
from transformers import Wav2Vec2ForCTC
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2NoLayerNormConvLayer, Wav2Vec2LayerNormConvLayer, Wav2Vec2GroupNormConvLayer


def patch_init(cls):
    __class__ = cls  # closure cell so the zero-argument super() below resolves to cls

    def new_init(self, config):
        super().__init__()  # nn.Module.__init__ must run before submodules are assigned
        if config.feat_extract_norm == "group":
            conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
                Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1)
                for i in range(config.num_feat_extract_layers - 1)
            ]
        elif config.feat_extract_norm == "layer":
            conv_layers = [
                Wav2Vec2LayerNormConvLayer(config, layer_id=i)
                for i in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        aug = AUG()
        # from IPython import embed; embed()  # debugging breakpoint, left disabled
        # Prepend the augmentation module so the checkpoint's state-dict keys
        # (conv_layers.0 = AUG, conv_layers.1.. = convs) line up when loading.
        conv_layers.insert(0, aug)
        self.conv_layers = nn.ModuleList(conv_layers)

    cls.__init__ = new_init

# Patch the conv feature extractor module so the augmented checkpoint loads.
# (Newer transformers releases name this module class Wav2Vec2FeatureEncoder.)
patch_init(transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureExtractor)
# model_path = "pytorch_model.bin"
# Wav2Vec2ForPreTraining
model = Wav2Vec2ForCTC.from_pretrained(".")
# from IPython import embed; embed()  # debugging breakpoint, left disabled
# Undo the monkey patching: drop the augmentation module so the model returns
# to its normal state.
model.wav2vec2.feature_extractor.conv_layers = nn.ModuleList(
    list(model.wav2vec2.feature_extractor.conv_layers.children())[1:]
)
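# Optional sanity check, a minimal sketch: after stripping, the first conv
# layer should be a real conv module again, not the AUG wrapper.
assert not isinstance(model.wav2vec2.feature_extractor.conv_layers[0], AUG), (
    "AUG module is still present at position 0"
)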
model.save_pretrained(".")
"""
replace with temprarily then save model, patching didn't work, get loaded so far for some reason.
"conv_dim": [
1,
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
5,
2,
2,
2,
2,
2,
2
""" |