#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import types
import torch
import torch.nn as nn
from funasr.register import tables


def export_rebuild_model(model, **kwargs):
    """Patch a SenseVoiceSmall instance in place so it can be traced by torch.onnx.export."""
    model.device = kwargs.get("device")
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    # model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    from funasr.utils.torch_function import sequence_mask

    # model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)

    # Bind the export helpers below as methods of the model instance.
    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)
    # Note: this string assignment overrides the bound export_name method above.
    model.export_name = "model"

    return model


def export_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    language: torch.Tensor,
    textnorm: torch.Tensor,
    **kwargs,
):
    # Move inputs to the device configured in export_rebuild_model.
    speech = speech.to(self.device)
    speech_lengths = speech_lengths.to(self.device)

    # Prepend the text-normalization query embedding, then the language and
    # fixed event/emotion query embeddings, to the speech features.
    language_query = self.embed(language.to(speech.device)).unsqueeze(1)
    textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)
    # print(textnorm_query.shape, speech.shape)  # debug
    speech = torch.cat((textnorm_query, speech), dim=1)
    speech_lengths += 1

    event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
        speech.size(0), 1, 1
    )
    input_query = torch.cat((language_query, event_emo_query), dim=1)
    speech = torch.cat((input_query, speech), dim=1)
    speech_lengths += 3

    # Encoder
    encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    # Project the encoder output to CTC logits (raw, without log_softmax).
    # ctc_logits = self.ctc.log_softmax(encoder_out)
    ctc_logits = self.ctc.ctc_lo(encoder_out)

    return ctc_logits, encoder_out_lens


def export_dummy_inputs(self):
    # Example inputs used only to trace the graph: a batch of two utterances
    # with random features plus per-utterance language / textnorm ids.
    speech = torch.randn(2, 30, 560)
    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
    language = torch.tensor([0, 0], dtype=torch.int32)
    textnorm = torch.tensor([15, 15], dtype=torch.int32)
    return (speech, speech_lengths, language, textnorm)


def export_input_names(self):
    return ["speech", "speech_lengths", "language", "textnorm"]


def export_output_names(self):
    return ["ctc_logits", "encoder_out_lens"]


def export_dynamic_axes(self):
    return {
        "speech": {0: "batch_size", 1: "feats_length"},
        "speech_lengths": {0: "batch_size"},
        "language": {0: "batch_size"},
        "textnorm": {0: "batch_size"},
        "ctc_logits": {0: "batch_size", 1: "logits_length"},
    }


def export_name(self):
    return "model.onnx"


if __name__ == "__main__":
    from model import SenseVoiceSmall

    model_dir = "iic/SenseVoiceSmall"
    # model_dir = "./SenseVoiceSmall"
    model, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir)
    # model = model.to("cpu")
    model.eval()  # switch to inference mode before tracing

    model = export_rebuild_model(model, max_seq_len=512, device="cuda")
    # model.export()

    dummy_inputs = model.export_dummy_inputs()

    # Export the model
    torch.onnx.export(
        model,
        dummy_inputs,
        "model.onnx",
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes(),
        opset_version=18,
    )
    print("Export Done.")

    # Optional: convert the exported model to fp16 with onnxmltools.
    # import os
    # import onnxmltools
    # from onnxmltools.utils.float16_converter import convert_float_to_float16
    # decoder_onnx_model = onnxmltools.utils.load_model("model.onnx")
    # decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
    # decoder_onnx_path = "model_fp16.onnx"
    # onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)
    # print("Model has been successfully exported to model_fp16.onnx")