update export models and scripts (#6)
Browse files
- Use the latest pruned_transducer_stateless7_streaming script in icefall to export the model, and update the export scripts in the repo (427c4b37e1885d621595389453cad0717fbad7bd)
- exp/cpu_jit.pt +2 -2
- exp/decoder_jit_trace.pt +2 -2
- exp/encoder_jit_trace.pt +2 -2
- exp/export-stateless7-streaming-zh.sh +2 -2
- exp/jit_trace_export-zh.py +0 -323
- exp/jit_trace_export-zh.sh +2 -2
- exp/joiner_jit_trace.pt +2 -2
exp/cpu_jit.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3b41fa49583b69438105016d68672604e0359498925dd0c7b5965184a445cc8c
|
3 |
+
size 379196926
|
exp/decoder_jit_trace.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83cc6f5cbf4e3e7a518546c2ee4e8d9c17d479ceebcac3648a031985d44ec89d
|
3 |
+
size 12831333
|
exp/encoder_jit_trace.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83e2ec17083607da4085d9653c7d15c908e28fd50488e62c6f24e11e795a90a4
|
3 |
+
size 330607841
|
exp/export-stateless7-streaming-zh.sh
CHANGED
@@ -2,8 +2,8 @@
|
|
2 |
|
3 |
. path.sh
|
4 |
|
5 |
-
./pruned_transducer_stateless7_streaming/export
|
6 |
-
--
|
7 |
--use-averaged-model 0 \
|
8 |
--epoch 99 \
|
9 |
--avg 1 \
|
|
|
2 |
|
3 |
. path.sh
|
4 |
|
5 |
+
./pruned_transducer_stateless7_streaming/export.py \
|
6 |
+
--tokens ./k2fsa-zipformer-chinese-english-mixed/data/lang_char_bpe/tokens.txt \
|
7 |
--use-averaged-model 0 \
|
8 |
--epoch 99 \
|
9 |
--avg 1 \
|
exp/jit_trace_export-zh.py
DELETED
@@ -1,323 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
|
3 |
-
"""
|
4 |
-
Usage:
|
5 |
-
./pruned_transducer_stateless7_streaming/jit_trace_export-zh.py \
|
6 |
-
--exp-dir $dir/exp \
|
7 |
-
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
8 |
-
--lang-dir ./data/lang_char_bpe \
|
9 |
-
--epoch 99 \
|
10 |
-
--avg 1 \
|
11 |
-
--use-averaged-model 0 \
|
12 |
-
\
|
13 |
-
--decode-chunk-len 32 \
|
14 |
-
--num-encoder-layers "2,4,3,2,4" \
|
15 |
-
--feedforward-dims "1024,1024,1536,1536,1024" \
|
16 |
-
--nhead "8,8,8,8,8" \
|
17 |
-
--encoder-dims "384,384,384,384,384" \
|
18 |
-
--attention-dims "192,192,192,192,192" \
|
19 |
-
--encoder-unmasked-dims "256,256,256,256,256" \
|
20 |
-
--zipformer-downsampling-factors "1,2,4,8,2" \
|
21 |
-
--cnn-module-kernels "31,31,31,31,31" \
|
22 |
-
--decoder-dim 512 \
|
23 |
-
--joiner-dim 512
|
24 |
-
"""
|
25 |
-
|
26 |
-
import argparse
|
27 |
-
import logging
|
28 |
-
from pathlib import Path
|
29 |
-
|
30 |
-
import sentencepiece as spm
|
31 |
-
import torch
|
32 |
-
from scaling_converter import convert_scaled_to_non_scaled
|
33 |
-
from train import add_model_arguments, get_params, get_transducer_model
|
34 |
-
from icefall.lexicon import Lexicon
|
35 |
-
|
36 |
-
from icefall.checkpoint import (
|
37 |
-
average_checkpoints,
|
38 |
-
average_checkpoints_with_averaged_model,
|
39 |
-
find_checkpoints,
|
40 |
-
load_checkpoint,
|
41 |
-
)
|
42 |
-
from icefall.utils import AttributeDict, str2bool
|
43 |
-
|
44 |
-
|
45 |
-
def get_parser():
|
46 |
-
parser = argparse.ArgumentParser(
|
47 |
-
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
48 |
-
)
|
49 |
-
|
50 |
-
parser.add_argument(
|
51 |
-
"--epoch",
|
52 |
-
type=int,
|
53 |
-
default=28,
|
54 |
-
help="""It specifies the checkpoint to use for averaging.
|
55 |
-
Note: Epoch counts from 0.
|
56 |
-
You can specify --avg to use more checkpoints for model averaging.""",
|
57 |
-
)
|
58 |
-
|
59 |
-
parser.add_argument(
|
60 |
-
"--iter",
|
61 |
-
type=int,
|
62 |
-
default=0,
|
63 |
-
help="""If positive, --epoch is ignored and it
|
64 |
-
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
65 |
-
You can specify --avg to use more checkpoints for model averaging.
|
66 |
-
""",
|
67 |
-
)
|
68 |
-
|
69 |
-
parser.add_argument(
|
70 |
-
"--avg",
|
71 |
-
type=int,
|
72 |
-
default=15,
|
73 |
-
help="Number of checkpoints to average. Automatically select "
|
74 |
-
"consecutive checkpoints before the checkpoint specified by "
|
75 |
-
"'--epoch' and '--iter'",
|
76 |
-
)
|
77 |
-
|
78 |
-
parser.add_argument(
|
79 |
-
"--exp-dir",
|
80 |
-
type=str,
|
81 |
-
default="pruned_transducer_stateless2/exp",
|
82 |
-
help="""It specifies the directory where all training related
|
83 |
-
files, e.g., checkpoints, log, etc, are saved
|
84 |
-
""",
|
85 |
-
)
|
86 |
-
|
87 |
-
parser.add_argument(
|
88 |
-
"--lang-dir",
|
89 |
-
type=str,
|
90 |
-
default="data/lang_char",
|
91 |
-
help="The lang dir",
|
92 |
-
)
|
93 |
-
|
94 |
-
parser.add_argument(
|
95 |
-
"--context-size",
|
96 |
-
type=int,
|
97 |
-
default=2,
|
98 |
-
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
99 |
-
)
|
100 |
-
|
101 |
-
parser.add_argument(
|
102 |
-
"--use-averaged-model",
|
103 |
-
type=str2bool,
|
104 |
-
default=True,
|
105 |
-
help="Whether to load averaged model. Currently it only supports "
|
106 |
-
"using --epoch. If True, it would decode with the averaged model "
|
107 |
-
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
108 |
-
"Actually only the models with epoch number of `epoch-avg` and "
|
109 |
-
"`epoch` are loaded for averaging. ",
|
110 |
-
)
|
111 |
-
|
112 |
-
add_model_arguments(parser)
|
113 |
-
|
114 |
-
return parser
|
115 |
-
|
116 |
-
|
117 |
-
def export_encoder_model_jit_trace(
|
118 |
-
encoder_model: torch.nn.Module,
|
119 |
-
encoder_filename: str,
|
120 |
-
params: AttributeDict,
|
121 |
-
) -> None:
|
122 |
-
"""Export the given encoder model with torch.jit.trace()
|
123 |
-
|
124 |
-
Note: The warmup argument is fixed to 1.
|
125 |
-
|
126 |
-
Args:
|
127 |
-
encoder_model:
|
128 |
-
The input encoder model
|
129 |
-
encoder_filename:
|
130 |
-
The filename to save the exported model.
|
131 |
-
"""
|
132 |
-
decode_chunk_len = params.decode_chunk_len # before subsampling
|
133 |
-
pad_length = 7
|
134 |
-
s = f"decode_chunk_len: {decode_chunk_len}"
|
135 |
-
logging.info(s)
|
136 |
-
assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
|
137 |
-
encoder_model.decode_chunk_size,
|
138 |
-
decode_chunk_len,
|
139 |
-
)
|
140 |
-
|
141 |
-
T = decode_chunk_len + pad_length
|
142 |
-
|
143 |
-
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
144 |
-
x_lens = torch.full((1,), T, dtype=torch.int32)
|
145 |
-
states = encoder_model.get_init_state(device=x.device)
|
146 |
-
|
147 |
-
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
148 |
-
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
|
149 |
-
traced_model.save(encoder_filename)
|
150 |
-
logging.info(f"Saved to {encoder_filename}")
|
151 |
-
|
152 |
-
|
153 |
-
def export_decoder_model_jit_trace(
|
154 |
-
decoder_model: torch.nn.Module,
|
155 |
-
decoder_filename: str,
|
156 |
-
) -> None:
|
157 |
-
"""Export the given decoder model with torch.jit.trace()
|
158 |
-
|
159 |
-
Note: The argument need_pad is fixed to False.
|
160 |
-
|
161 |
-
Args:
|
162 |
-
decoder_model:
|
163 |
-
The input decoder model
|
164 |
-
decoder_filename:
|
165 |
-
The filename to save the exported model.
|
166 |
-
"""
|
167 |
-
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
168 |
-
need_pad = torch.tensor([False])
|
169 |
-
|
170 |
-
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
171 |
-
traced_model.save(decoder_filename)
|
172 |
-
logging.info(f"Saved to {decoder_filename}")
|
173 |
-
|
174 |
-
|
175 |
-
def export_joiner_model_jit_trace(
|
176 |
-
joiner_model: torch.nn.Module,
|
177 |
-
joiner_filename: str,
|
178 |
-
) -> None:
|
179 |
-
"""Export the given joiner model with torch.jit.trace()
|
180 |
-
|
181 |
-
Note: The argument project_input is fixed to True. A user should not
|
182 |
-
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
183 |
-
will do that for the user.
|
184 |
-
|
185 |
-
Args:
|
186 |
-
joiner_model:
|
187 |
-
The input joiner model
|
188 |
-
joiner_filename:
|
189 |
-
The filename to save the exported model.
|
190 |
-
|
191 |
-
"""
|
192 |
-
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
193 |
-
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
194 |
-
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
195 |
-
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
196 |
-
|
197 |
-
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
198 |
-
traced_model.save(joiner_filename)
|
199 |
-
logging.info(f"Saved to {joiner_filename}")
|
200 |
-
|
201 |
-
|
202 |
-
@torch.no_grad()
|
203 |
-
def main():
|
204 |
-
args = get_parser().parse_args()
|
205 |
-
args.exp_dir = Path(args.exp_dir)
|
206 |
-
|
207 |
-
params = get_params()
|
208 |
-
params.update(vars(args))
|
209 |
-
|
210 |
-
device = torch.device("cpu")
|
211 |
-
|
212 |
-
logging.info(f"device: {device}")
|
213 |
-
|
214 |
-
lexicon = Lexicon(params.lang_dir)
|
215 |
-
params.blank_id = 0
|
216 |
-
params.vocab_size = max(lexicon.tokens) + 1
|
217 |
-
|
218 |
-
logging.info(params)
|
219 |
-
|
220 |
-
logging.info("About to create model")
|
221 |
-
model = get_transducer_model(params)
|
222 |
-
|
223 |
-
if not params.use_averaged_model:
|
224 |
-
if params.iter > 0:
|
225 |
-
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
226 |
-
: params.avg
|
227 |
-
]
|
228 |
-
if len(filenames) == 0:
|
229 |
-
raise ValueError(
|
230 |
-
f"No checkpoints found for"
|
231 |
-
f" --iter {params.iter}, --avg {params.avg}"
|
232 |
-
)
|
233 |
-
elif len(filenames) < params.avg:
|
234 |
-
raise ValueError(
|
235 |
-
f"Not enough checkpoints ({len(filenames)}) found for"
|
236 |
-
f" --iter {params.iter}, --avg {params.avg}"
|
237 |
-
)
|
238 |
-
logging.info(f"averaging {filenames}")
|
239 |
-
model.to(device)
|
240 |
-
model.load_state_dict(average_checkpoints(filenames, device=device))
|
241 |
-
elif params.avg == 1:
|
242 |
-
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
243 |
-
else:
|
244 |
-
start = params.epoch - params.avg + 1
|
245 |
-
filenames = []
|
246 |
-
for i in range(start, params.epoch + 1):
|
247 |
-
if i >= 1:
|
248 |
-
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
249 |
-
logging.info(f"averaging {filenames}")
|
250 |
-
model.to(device)
|
251 |
-
model.load_state_dict(average_checkpoints(filenames, device=device))
|
252 |
-
else:
|
253 |
-
if params.iter > 0:
|
254 |
-
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
255 |
-
: params.avg + 1
|
256 |
-
]
|
257 |
-
if len(filenames) == 0:
|
258 |
-
raise ValueError(
|
259 |
-
f"No checkpoints found for"
|
260 |
-
f" --iter {params.iter}, --avg {params.avg}"
|
261 |
-
)
|
262 |
-
elif len(filenames) < params.avg + 1:
|
263 |
-
raise ValueError(
|
264 |
-
f"Not enough checkpoints ({len(filenames)}) found for"
|
265 |
-
f" --iter {params.iter}, --avg {params.avg}"
|
266 |
-
)
|
267 |
-
filename_start = filenames[-1]
|
268 |
-
filename_end = filenames[0]
|
269 |
-
logging.info(
|
270 |
-
"Calculating the averaged model over iteration checkpoints"
|
271 |
-
f" from {filename_start} (excluded) to {filename_end}"
|
272 |
-
)
|
273 |
-
model.to(device)
|
274 |
-
model.load_state_dict(
|
275 |
-
average_checkpoints_with_averaged_model(
|
276 |
-
filename_start=filename_start,
|
277 |
-
filename_end=filename_end,
|
278 |
-
device=device,
|
279 |
-
)
|
280 |
-
)
|
281 |
-
else:
|
282 |
-
assert params.avg > 0, params.avg
|
283 |
-
start = params.epoch - params.avg
|
284 |
-
assert start >= 1, start
|
285 |
-
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
286 |
-
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
287 |
-
logging.info(
|
288 |
-
f"Calculating the averaged model over epoch range from "
|
289 |
-
f"{start} (excluded) to {params.epoch}"
|
290 |
-
)
|
291 |
-
model.to(device)
|
292 |
-
model.load_state_dict(
|
293 |
-
average_checkpoints_with_averaged_model(
|
294 |
-
filename_start=filename_start,
|
295 |
-
filename_end=filename_end,
|
296 |
-
device=device,
|
297 |
-
)
|
298 |
-
)
|
299 |
-
|
300 |
-
model.to("cpu")
|
301 |
-
model.eval()
|
302 |
-
|
303 |
-
convert_scaled_to_non_scaled(model, inplace=True)
|
304 |
-
logging.info("Using torch.jit.trace()")
|
305 |
-
|
306 |
-
logging.info("Exporting encoder")
|
307 |
-
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
308 |
-
export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
|
309 |
-
|
310 |
-
logging.info("Exporting decoder")
|
311 |
-
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
312 |
-
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
313 |
-
|
314 |
-
logging.info("Exporting joiner")
|
315 |
-
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
316 |
-
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
317 |
-
|
318 |
-
|
319 |
-
if __name__ == "__main__":
|
320 |
-
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
321 |
-
|
322 |
-
logging.basicConfig(format=formatter, level=logging.INFO)
|
323 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
exp/jit_trace_export-zh.sh
CHANGED
@@ -15,9 +15,9 @@ if [ ! -f $dir/exp/epoch-99.pt ]; then
|
|
15 |
popd
|
16 |
fi
|
17 |
|
18 |
-
./pruned_transducer_stateless7_streaming/jit_trace_export
|
19 |
--exp-dir $dir/exp \
|
20 |
-
--
|
21 |
--epoch 99 \
|
22 |
--avg 1 \
|
23 |
--use-averaged-model 0 \
|
|
|
15 |
popd
|
16 |
fi
|
17 |
|
18 |
+
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
19 |
--exp-dir $dir/exp \
|
20 |
+
--bpe-model $dir/data/lang_char_bpe/bpe.model \
|
21 |
--epoch 99 \
|
22 |
--avg 1 \
|
23 |
--use-averaged-model 0 \
|
exp/joiner_jit_trace.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bdd58624fd2df70b5583c684c6705cd28b27d73bc15ef95484356315d1579043
|
3 |
+
size 14681115
|