jadechoghari committed on
Commit
1752041
1 Parent(s): 893807d

Create tools.py

Files changed (1): tools.py (+566, -0)
tools.py ADDED
# Author: Haohe Liu
# Email: haoheliu@gmail.com
# Date: 11 Feb 2023

import os
import json
import hashlib
import shutil

import torch
import torch.nn.functional as F
import numpy as np

import matplotlib

matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported

from matplotlib import pyplot as plt
from scipy.io import wavfile

import requests
from tqdm import tqdm

URL_MAP = {
    "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt",
    "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt",
}

CKPT_MAP = {
    "vggishish_lpaps": "vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
    "melception": "melception-21-05-10T09-28-40.pt",
}

MD5_MAP = {
    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
    "melception": "a71a41041e945b457c7d3d814bbcf72d",
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def read_list(fname):
    result = []
    with open(fname, "r") as f:
        for each in f.readlines():
            each = each.strip("\n")
            result.append(each)
    return result


def build_dataset_json_from_list(list_path):
    data = []
    for each in read_list(list_path):
        if "|" in each:
            # split on the first "|" only, so captions may contain the character
            wav, caption = each.split("|", 1)
        else:
            caption = each
            wav = ""
        data.append(
            {
                "wav": wav,
                "caption": caption,
            }
        )
    return {"data": data}
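

# --- Illustrative usage (editor's addition, not part of the original file) ---
# A minimal sketch of the list format build_dataset_json_from_list() expects:
# one "wav_path|caption" pair per line, or a bare caption. The file name and
# its contents below are hypothetical.
def _demo_build_dataset_json(tmp_list="demo_list.txt"):
    with open(tmp_list, "w") as f:
        f.write("audio/dog_bark.wav|A dog barks twice\n")
        f.write("A car passes by in the rain\n")
    dataset = build_dataset_json_from_list(tmp_list)
    # -> {"data": [{"wav": "audio/dog_bark.wav", "caption": "A dog barks twice"},
    #              {"wav": "", "caption": "A car passes by in the rain"}]}
    return dataset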


def load_json(fname):
    with open(fname, "r") as f:
        data = json.load(f)
    return data


def read_json(dataset_json_file):
    with open(dataset_json_file, "r") as fp:
        data_json = json.load(fp)
    return data_json["data"]


def copy_test_subset_data(metadata, testset_copy_target_path):
    # metadata = read_json(testset_metadata)
    os.makedirs(testset_copy_target_path, exist_ok=True)
    if len(os.listdir(testset_copy_target_path)) == len(metadata):
        # target already holds one file per metadata entry; assume it is complete
        return
    else:
        # delete files in folder testset_copy_target_path
        for file in os.listdir(testset_copy_target_path):
            try:
                os.remove(os.path.join(testset_copy_target_path, file))
            except Exception as e:
                print(e)

        print("Copying test subset data to {}".format(testset_copy_target_path))
        for each in tqdm(metadata):
            # shutil.copy is safer than shelling out to `cp` (handles spaces in paths)
            shutil.copy(each["wav"], testset_copy_target_path)


def listdir_nohidden(path):
    for f in os.listdir(path):
        if not f.startswith("."):
            yield f


def get_restore_step(path):
    checkpoints = os.listdir(path)
    if os.path.exists(os.path.join(path, "final.ckpt")):
        return "final.ckpt", 0
    elif not os.path.exists(os.path.join(path, "last.ckpt")):
        steps = [int(x.split(".ckpt")[0].split("step=")[1]) for x in checkpoints]
        return checkpoints[np.argmax(steps)], np.max(steps)
    else:
        # pick the newest "last" checkpoint: the highest last-vN.ckpt if
        # versioned copies exist, otherwise plain last.ckpt (the original
        # compared each version against a list it had already been appended
        # to, so the versioned branch could never be selected)
        versions = []
        for x in checkpoints:
            if "last" in x and "-v" in x:
                versions.append(int(x.split(".ckpt")[0].split("-v")[1]))
        if versions:
            fname = "last-v%s.ckpt" % max(versions)
        else:
            fname = "last.ckpt"
        return fname, 0
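

# --- Illustrative usage (editor's addition) --- a small self-contained check
# of the checkpoint-selection logic above, using a temporary directory with
# hypothetical Lightning-style file names.
def _demo_get_restore_step():
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        for name in ["last.ckpt", "last-v1.ckpt", "last-v2.ckpt"]:
            open(os.path.join(d, name), "w").close()
        print(get_restore_step(d))  # -> ("last-v2.ckpt", 0)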


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # the final chunk may be shorter than chunk_size
                        pbar.update(len(data))


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
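

# --- Illustrative usage (editor's addition) --- fetch a checkpoint listed in
# URL_MAP into a local cache directory, verifying its MD5 on the way in. The
# "checkpoints" directory name is an assumption for the example.
def _demo_get_ckpt_path():
    ckpt = get_ckpt_path("vggishish_lpaps", root="checkpoints", check=True)
    print(ckpt)  # -> checkpoints/vggishish16.pt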


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path-like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or, if :attr:`default` is not ``None`` and
    :attr:`key` is not found, ``default``.

    Raises
    ------
    Exception if ``key`` is not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success
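

# --- Illustrative usage (editor's addition) --- retrieve() walks a nested
# structure with a "/"-delimited path, treating numeric segments as list
# indices and expanding callable nodes in place.
def _demo_retrieve():
    config = {"model": {"params": [{"lr": 1e-4}], "name": lambda: "unet"}}
    print(retrieve(config, "model/params/0/lr"))  # -> 0.0001
    print(retrieve(config, "model/name"))  # -> "unet" (callable expanded in place)
    print(retrieve(config, "model/missing", default=42))  # -> 42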


def to_device(data, device):
    if len(data) == 12:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)
        mels = torch.from_numpy(mels).float().to(device)
        mel_lens = torch.from_numpy(mel_lens).to(device)
        pitches = torch.from_numpy(pitches).float().to(device)
        energies = torch.from_numpy(energies).to(device)
        durations = torch.from_numpy(durations).long().to(device)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        )

    if len(data) == 6:
        (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)

        return (ids, raw_texts, speakers, texts, src_lens, max_src_len)
304
+
305
+
306
+ def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
307
+ # if losses is not None:
308
+ # logger.add_scalar("Loss/total_loss", losses[0], step)
309
+ # logger.add_scalar("Loss/mel_loss", losses[1], step)
310
+ # logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
311
+ # logger.add_scalar("Loss/pitch_loss", losses[3], step)
312
+ # logger.add_scalar("Loss/energy_loss", losses[4], step)
313
+ # logger.add_scalar("Loss/duration_loss", losses[5], step)
314
+ # if(len(losses) > 6):
315
+ # logger.add_scalar("Loss/disc_loss", losses[6], step)
316
+ # logger.add_scalar("Loss/fmap_loss", losses[7], step)
317
+ # logger.add_scalar("Loss/r_loss", losses[8], step)
318
+ # logger.add_scalar("Loss/g_loss", losses[9], step)
319
+ # logger.add_scalar("Loss/gen_loss", losses[10], step)
320
+ # logger.add_scalar("Loss/diff_loss", losses[11], step)
321
+
322
+ if fig is not None:
323
+ logger.add_figure(tag, fig)
324
+
325
+ if audio is not None:
326
+ audio = audio / (max(abs(audio)) * 1.1)
327
+ logger.add_audio(
328
+ tag,
329
+ audio,
330
+ sample_rate=sampling_rate,
331
+ )
332
+
333
+
334
+ def get_mask_from_lengths(lengths, max_len=None):
335
+ batch_size = lengths.shape[0]
336
+ if max_len is None:
337
+ max_len = torch.max(lengths).item()
338
+
339
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
340
+ mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
341
+
342
+ return mask
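

# --- Illustrative usage (editor's addition) --- positions at or beyond each
# sequence length come out True, i.e. True marks padding frames.
def _demo_get_mask_from_lengths():
    lengths = torch.tensor([2, 4]).to(device)
    print(get_mask_from_lengths(lengths))
    # tensor([[False, False,  True,  True],
    #         [False, False, False, False]])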
343
+
344
+
345
+ def expand(values, durations):
346
+ out = list()
347
+ for value, d in zip(values, durations):
348
+ out += [value] * max(0, int(d))
349
+ return np.array(out)
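

# --- Illustrative usage (editor's addition) --- repeat each phoneme-level
# value by its duration in frames to obtain a frame-level sequence.
def _demo_expand():
    pitch = np.array([200.0, 180.0, 220.0])
    durations = np.array([2, 0, 3])
    print(expand(pitch, durations))  # -> [200. 200. 220. 220. 220.]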
350
+
351
+
352
+ def synth_one_sample_val(
353
+ targets, predictions, vocoder, model_config, preprocess_config
354
+ ):
355
+ index = np.random.choice(list(np.arange(targets[6].size(0))))
356
+
357
+ basename = targets[0][index]
358
+ src_len = predictions[8][index].item()
359
+ mel_len = predictions[9][index].item()
360
+ mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)
361
+
362
+ mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
363
+ postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)
364
+ duration = targets[11][index, :src_len].detach().cpu().numpy()
365
+
366
+ if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
367
+ pitch = predictions[2][index, :src_len].detach().cpu().numpy()
368
+ pitch = expand(pitch, duration)
369
+ else:
370
+ pitch = predictions[2][index, :mel_len].detach().cpu().numpy()
371
+
372
+ if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
373
+ energy = predictions[3][index, :src_len].detach().cpu().numpy()
374
+ energy = expand(energy, duration)
375
+ else:
376
+ energy = predictions[3][index, :mel_len].detach().cpu().numpy()
377
+
378
+ with open(
379
+ os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
380
+ ) as f:
381
+ stats = json.load(f)
382
+ stats = stats["pitch"] + stats["energy"][:2]
383
+
384
+ # from datetime import datetime
385
+ # now = datetime.now()
386
+ # current_time = now.strftime("%D:%H:%M:%S")
387
+ # np.save(("mel_pred_%s.npy" % current_time).replace("/","-"), mel_prediction.cpu().numpy())
388
+ # np.save(("postnet_mel_prediction_%s.npy" % current_time).replace("/","-"), postnet_mel_prediction.cpu().numpy())
389
+ # np.save(("mel_target_%s.npy" % current_time).replace("/","-"), mel_target.cpu().numpy())
390
+
391
+ fig = plot_mel(
392
+ [
393
+ (mel_prediction.cpu().numpy(), pitch, energy),
394
+ (postnet_mel_prediction.cpu().numpy(), pitch, energy),
395
+ (mel_target.cpu().numpy(), pitch, energy),
396
+ ],
397
+ stats,
398
+ [
399
+ "Raw mel spectrogram prediction",
400
+ "Postnet mel prediction",
401
+ "Ground-Truth Spectrogram",
402
+ ],
403
+ )
404
+
405
+ if vocoder is not None:
406
+ from .model_util import vocoder_infer
407
+
408
+ wav_reconstruction = vocoder_infer(
409
+ mel_target.unsqueeze(0),
410
+ vocoder,
411
+ model_config,
412
+ preprocess_config,
413
+ )[0]
414
+ wav_prediction = vocoder_infer(
415
+ postnet_mel_prediction.unsqueeze(0),
416
+ vocoder,
417
+ model_config,
418
+ preprocess_config,
419
+ )[0]
420
+ else:
421
+ wav_reconstruction = wav_prediction = None
422
+
423
+ return fig, wav_reconstruction, wav_prediction, basename
424
+
425
+
426
+ def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
427
+ if vocoder is not None:
428
+ from .model_util import vocoder_infer
429
+
430
+ wav_reconstruction = vocoder_infer(
431
+ mel_input.permute(0, 2, 1),
432
+ vocoder,
433
+ )
434
+ wav_prediction = vocoder_infer(
435
+ mel_prediction.permute(0, 2, 1),
436
+ vocoder,
437
+ )
438
+ else:
439
+ wav_reconstruction = wav_prediction = None
440
+
441
+ return wav_reconstruction, wav_prediction
442
+
443
+
444
+ def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
445
+ # (diff_output, diff_loss, latent_loss) = diffusion
446
+
447
+ basenames = targets[0]
448
+
449
+ for i in range(len(predictions[1])):
450
+ basename = basenames[i]
451
+ src_len = predictions[8][i].item()
452
+ mel_len = predictions[9][i].item()
453
+ mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
454
+ # diff_output = diff_output[i, :mel_len].detach().transpose(0, 1)
455
+ # duration = predictions[5][i, :src_len].detach().cpu().numpy()
456
+ if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
457
+ pitch = predictions[2][i, :src_len].detach().cpu().numpy()
458
+ # pitch = expand(pitch, duration)
459
+ else:
460
+ pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
461
+ if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
462
+ energy = predictions[3][i, :src_len].detach().cpu().numpy()
463
+ # energy = expand(energy, duration)
464
+ else:
465
+ energy = predictions[3][i, :mel_len].detach().cpu().numpy()
466
+ # import ipdb; ipdb.set_trace()
467
+ with open(
468
+ os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
469
+ ) as f:
470
+ stats = json.load(f)
471
+ stats = stats["pitch"] + stats["energy"][:2]
472
+
473
+ fig = plot_mel(
474
+ [
475
+ (mel_prediction.cpu().numpy(), pitch, energy),
476
+ ],
477
+ stats,
478
+ ["Synthetized Spectrogram by PostNet"],
479
+ )
480
+ # np.save("{}_postnet.npy".format(basename), mel_prediction.cpu().numpy())
481
+ plt.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
482
+ plt.close()
483
+
484
+ from .model_util import vocoder_infer
485
+
486
+ mel_predictions = predictions[1].transpose(1, 2)
487
+ lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
488
+ wav_predictions = vocoder_infer(
489
+ mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
490
+ )
491
+
492
+ sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
493
+ for wav, basename in zip(wav_predictions, basenames):
494
+ wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
495
+
496
+
497
+ def plot_mel(data, titles=None):
498
+ fig, axes = plt.subplots(len(data), 1, squeeze=False)
499
+ if titles is None:
500
+ titles = [None for i in range(len(data))]
501
+
502
+ for i in range(len(data)):
503
+ mel = data[i]
504
+ axes[i][0].imshow(mel, origin="lower", aspect="auto")
505
+ axes[i][0].set_aspect(2.5, adjustable="box")
506
+ axes[i][0].set_ylim(0, mel.shape[0])
507
+ axes[i][0].set_title(titles[i], fontsize="medium")
508
+ axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
509
+ axes[i][0].set_anchor("W")
510
+
511
+ return fig
512
+
513
+
514
+ def pad_1D(inputs, PAD=0):
515
+ def pad_data(x, length, PAD):
516
+ x_padded = np.pad(
517
+ x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
518
+ )
519
+ return x_padded
520
+
521
+ max_len = max((len(x) for x in inputs))
522
+ padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
523
+
524
+ return padded
525
+
526
+
527
+ def pad_2D(inputs, maxlen=None):
528
+ def pad(x, max_len):
529
+ PAD = 0
530
+ if np.shape(x)[0] > max_len:
531
+ raise ValueError("not max_len")
532
+
533
+ s = np.shape(x)[1]
534
+ x_padded = np.pad(
535
+ x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
536
+ )
537
+ return x_padded[:, :s]
538
+
539
+ if maxlen:
540
+ output = np.stack([pad(x, maxlen) for x in inputs])
541
+ else:
542
+ max_len = max(np.shape(x)[0] for x in inputs)
543
+ output = np.stack([pad(x, max_len) for x in inputs])
544
+
545
+ return output
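

# --- Illustrative usage (editor's addition) --- pad_1D right-pads 1-D arrays
# to a common length; pad_2D right-pads 2-D arrays along the time axis.
def _demo_padding():
    print(pad_1D([np.array([1, 2, 3]), np.array([4])]))
    # [[1 2 3]
    #  [4 0 0]]
    batch = pad_2D([np.ones((2, 4)), np.ones((3, 4))])
    print(batch.shape)  # -> (2, 3, 4)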
546
+
547
+
548
+ def pad(input_ele, mel_max_length=None):
549
+ if mel_max_length:
550
+ max_len = mel_max_length
551
+ else:
552
+ max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
553
+
554
+ out_list = list()
555
+ for i, batch in enumerate(input_ele):
556
+ if len(batch.shape) == 1:
557
+ one_batch_padded = F.pad(
558
+ batch, (0, max_len - batch.size(0)), "constant", 0.0
559
+ )
560
+ elif len(batch.shape) == 2:
561
+ one_batch_padded = F.pad(
562
+ batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
563
+ )
564
+ out_list.append(one_batch_padded)
565
+ out_padded = torch.stack(out_list)
566
+ return out_padded
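

# --- Illustrative usage (editor's addition) --- the torch counterpart of the
# numpy helpers above: stack variable-length tensors into one padded batch.
def _demo_pad():
    a = torch.ones(2, 80)  # e.g. a 2-frame mel with 80 bins
    b = torch.ones(5, 80)  # a 5-frame mel
    batch = pad([a, b])
    print(batch.shape)  # -> torch.Size([2, 5, 80])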