crystal-technologies
committed on
Commit 2d8da09
1 parent: c1bb68d
Upload 1287 files
This view is limited to 50 files because the commit contains too many changes.
- SoundScribe/SpeakerID/Dockerfile +140 -0
- SoundScribe/SpeakerID/Jenkinsfile +0 -0
- SoundScribe/SpeakerID/LICENSE +201 -0
- SoundScribe/SpeakerID/ci.groovy +119 -0
- SoundScribe/SpeakerID/external/get_collections.py +90 -0
- SoundScribe/SpeakerID/external/get_modules.py +159 -0
- SoundScribe/SpeakerID/nemo/README.md +9 -0
- SoundScribe/SpeakerID/nemo/__init__.py +28 -0
- SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-39.pyc +0 -0
- SoundScribe/SpeakerID/nemo/__pycache__/constants.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/__pycache__/package_info.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/__init__.py +13 -0
- SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-39.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/__init__.py +25 -0
- SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-39.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__init__.py +13 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/__init__.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio_dataset.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_diar_label.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label_dataset.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dali.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dataset.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label_dataset.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio.py +1136 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio_dataset.py +95 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_ctm_dataset.py +95 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_diar_label.py +853 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label.py +1294 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label_dataset.py +304 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text.py +1366 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dali.py +772 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dataset.py +950 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/data_simulation.py +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label.py +497 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label_dataset.py +68 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text.py +488 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text_dataset.py +94 -0
- SoundScribe/SpeakerID/nemo/collections/asr/data/text_to_text.py +482 -0
- SoundScribe/SpeakerID/nemo/collections/asr/losses/__init__.py +22 -0
- SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/__init__.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/angularloss.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/audio_losses.cpython-310.pyc +0 -0
- SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/ctc.cpython-310.pyc +0 -0
SoundScribe/SpeakerID/Dockerfile
ADDED
@@ -0,0 +1,140 @@
# syntax=docker/dockerfile:experimental

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.08-py3

# build an image that includes only the nemo dependencies, ensures that dependencies
# are included first for optimal caching, and useful for building a development
# image (by specifying build target as `nemo-deps`)
FROM ${BASE_IMAGE} as nemo-deps

# dependency flags; should be declared after FROM
# torchaudio: not required by default
ARG REQUIRE_TORCHAUDIO=false
# k2: not required by default
ARG REQUIRE_K2=false
# ais cli: not required by default, install only if required
ARG REQUIRE_AIS_CLI=false

# Ensure apt-get won't prompt for selecting options
ENV DEBIAN_FRONTEND=noninteractive
# libavdevice-dev required for latest torchaudio
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y \
    libsndfile1 sox \
    libfreetype6 \
    swig \
    ffmpeg \
    libavdevice-dev && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /workspace/
# install megatron core, this can be removed once 0.3 pip package is released
RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
    cd Megatron-LM && \
    git checkout ab0336a5c8eab77aa74ae604ba1e73decbf6d560 && \
    pip install -e .

WORKDIR /tmp/

# Distributed Adam support for multiple dtypes
RUN git clone https://github.com/NVIDIA/apex.git && \
    cd apex && \
    git checkout 52e18c894223800cb611682dce27d88050edf1de && \
    pip3 install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./

# uninstall stuff from base container
RUN pip3 uninstall -y sacrebleu torchtext

# build torchaudio
WORKDIR /tmp/torchaudio_build
COPY scripts/installers /tmp/torchaudio_build/scripts/installers/
RUN INSTALL_MSG=$(/bin/bash /tmp/torchaudio_build/scripts/installers/install_torchaudio_latest.sh); INSTALL_CODE=$?; \
    echo ${INSTALL_MSG}; \
    if [ ${INSTALL_CODE} -ne 0 ]; then \
        echo "torchaudio installation failed"; \
        if [ "${REQUIRE_TORCHAUDIO}" = true ]; then \
            exit ${INSTALL_CODE}; \
        else echo "Skipping failed torchaudio installation"; fi \
    else echo "torchaudio installed successfully"; fi

# install nemo dependencies
WORKDIR /tmp/nemo
COPY requirements .
RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done

# install flash attention dependencies
RUN pip install flash-attn
# pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
RUN pip install triton==2.0.0.dev20221202
# install numba for latest containers
RUN pip install "numba>=0.57.1"

# install k2, skip if installation fails
COPY scripts /tmp/nemo/scripts/
RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/speech_recognition/k2/setup.sh); INSTALL_CODE=$?; \
    echo ${INSTALL_MSG}; \
    if [ ${INSTALL_CODE} -ne 0 ]; then \
        echo "k2 installation failed"; \
        if [ "${REQUIRE_K2}" = true ]; then \
            exit ${INSTALL_CODE}; \
        else echo "Skipping failed k2 installation"; fi \
    else echo "k2 installed successfully"; fi

# copy nemo source into a scratch image
FROM scratch as nemo-src
COPY . .

# start building the final container
FROM nemo-deps as nemo
ARG NEMO_VERSION=1.21.0

# Check that NEMO_VERSION is set. Build will fail without this. Expose NEMO and base container
# version information as runtime environment variable for introspection purposes
RUN /usr/bin/test -n "$NEMO_VERSION" && \
    /bin/echo "export NEMO_VERSION=${NEMO_VERSION}" >> /root/.bashrc && \
    /bin/echo "export BASE_IMAGE=${BASE_IMAGE}" >> /root/.bashrc

# Install NeMo
RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]"

# Check install
RUN python -c "import nemo.collections.nlp as nemo_nlp" && \
    python -c "import nemo.collections.tts as nemo_tts" && \
    python -c "import nemo_text_processing.text_normalization as text_normalization"


# copy scripts/examples/tests into container for end user
WORKDIR /workspace/nemo
COPY scripts /workspace/nemo/scripts
COPY examples /workspace/nemo/examples
COPY tests /workspace/nemo/tests
COPY tutorials /workspace/nemo/tutorials
# COPY README.rst LICENSE /workspace/nemo/

RUN printf "#!/bin/bash\njupyter lab --no-browser --allow-root --ip=0.0.0.0" >> start-jupyter.sh && \
    chmod +x start-jupyter.sh

# If required, install AIS CLI
RUN if [ "${REQUIRE_AIS_CLI}" = true ]; then \
    INSTALL_MSG=$(/bin/bash scripts/installers/install_ais_cli_latest.sh); INSTALL_CODE=$?; \
    echo ${INSTALL_MSG}; \
    if [ ${INSTALL_CODE} -ne 0 ]; then \
        echo "AIS CLI installation failed"; \
        exit ${INSTALL_CODE}; \
    else echo "AIS CLI installed successfully"; fi \
    else echo "Skipping AIS CLI installation"; fi
SoundScribe/SpeakerID/Jenkinsfile
ADDED
The diff for this file is too large to render.
SoundScribe/SpeakerID/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
SoundScribe/SpeakerID/ci.groovy
ADDED
@@ -0,0 +1,119 @@
@Library('blossom-github-lib@master')
import ipp.blossom.*

podTemplate(cloud:'sc-ipp-blossom-prod', yaml : """
apiVersion: v1
kind: Pod
metadata:
  labels:
    some-label: some-label-value
spec:
  volumes:
  - name: scratch
    nfs:
      server: ipp1-cdot01-col01
      path: /vol/scratch1/scratch.okuchaiev_blossom
  containers:
  - name: latestdlfw
    image: nvcr.io/nvidia/pytorch:23.02-py3
    command:
    - cat
    volumeMounts:
    - name: scratch
      mountPath: /testdata
    resources:
      limits:
        nvidia.com/gpu: 2
  restartPolicy: Never
  backoffLimit: 4
  tty: true
  shm-size: 32g
  nodeSelector:
    kubernetes.io/os: linux
    nvidia.com/gpu_type: "Tesla_T4x4"
    nvidia.com/node_type: gpu_tester
    nvidia.com/driver_version: "510.20"
"""
) {
  node(POD_LABEL) {
    def githubHelper
    stage('Get Token') {
      withCredentials([usernamePassword(credentialsId: 'GHAtoken', passwordVariable: 'GIT_PASSWORD', usernameVariable: 'GIT_USERNAME')]) {
        // create new instance of helper object
        githubHelper = GithubHelper.getInstance("${GIT_PASSWORD}", githubData)
      }

    }
    def stageName = ''
    try {
      currentBuild.description = githubHelper.getBuildDescription()
      container('latestdlfw') {
        stage('Code checkout') {
          // update status on github
          githubHelper.updateCommitStatus("$BUILD_URL", "$stageName Running", GitHubCommitState.PENDING)
          checkout changelog: true, poll: true, scm: [$class: 'GitSCM', branches: [[name: "pr/"+githubHelper.getPRNumber()]],
            doGenerateSubmoduleConfigurations: false,
            submoduleCfg: [],
            userRemoteConfigs: [[credentialsId: 'github-token', url: githubHelper.getCloneUrl(), refspec: '+refs/pull/*/head:refs/remotes/origin/pr/*']]]
        }

        stage('Code Style') {
          sh "apt-get update && \
            apt-get install -y bc && \
            nvidia-smi && \
            pip install -r requirements/requirements_test.txt && \
            python setup.py style && ls -l /testdata/TestData && ln -s /testdata/TestData /home/TestData && \
            ls -l /home && ls -l /home/TestData"
        }

        stage('Installation') {
          sh "git config --global --add safe.directory '*' && nvidia-smi && ./reinstall.sh release"
        }

        stage('L0: GPU unit tests') {
          sh "NEMO_NUMBA_MINVER=0.53 pytest -m 'not pleasefixme'"
        }

        parallel( //USE CUDA_VISIBLE_DEVICES to execute 2 single GPU tests in parallel here
          [
            "L1: NMT Training Pre-LN": { sh 'CUDA_VISIBLE_DEVICES=0 python examples/nlp/machine_translation/enc_dec_nmt.py \
              --config-path=conf \
              --config-name=aayn_base \
              do_testing=true \
              model.train_ds.src_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
              model.train_ds.tgt_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \
              model.validation_ds.src_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
              model.validation_ds.tgt_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
              model.test_ds.src_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
              model.test_ds.tgt_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
              model.encoder_tokenizer.tokenizer_model=/testdata/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
              model.decoder_tokenizer.tokenizer_model=/testdata/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
              model.encoder.pre_ln=true \
              model.decoder.pre_ln=true \
              trainer.devices=[0] \
              trainer.accelerator="gpu" \
              +trainer.fast_dev_run=true \
              +trainer.limit_test_batches=2 \
              exp_manager=null \
              '},
            "L1: Speech to text": { sh 'CUDA_VISIBLE_DEVICES=1 python examples/asr/asr_ctc/speech_to_text_ctc.py \
              model.train_ds.manifest_filepath=/testdata/TestData/an4_dataset/an4_train.json \
              model.validation_ds.manifest_filepath=/testdata/TestData/an4_dataset/an4_val.json \
              trainer.devices=[0] \
              trainer.accelerator="gpu" \
              +trainer.fast_dev_run=True \
              exp_manager=null \
              '}
          ]
        )//end of parallel
      }
      githubHelper.updateCommitStatus("$BUILD_URL", "Complete", GitHubCommitState.SUCCESS)
    }
    catch (Exception ex){
      currentBuild.result = 'FAILURE'
      println ex
      githubHelper.updateCommitStatus("$BUILD_URL", "$stageName Failed", GitHubCommitState.FAILURE)
    }

  }
}
SoundScribe/SpeakerID/external/get_collections.py
ADDED
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Script responsible for generating a JSON file with the list of NeMo collections. """

import argparse
import importlib
import json
import os

import nemo
from nemo.utils import logging


def process_collection(id, col):
    """ Helper function processing the collection.

    Args:
        id: (short) name of the collection.
        col: a collection (python module).
    """
    return {
        "id": id,
        "name": col.__name__,
        "description": col.__description__,
        "version": col.__version__,
        "author": col.__author__,
    }


def main():
    """ Main function generating a JSON file with the list of NeMo collections. """
    # Parse filename.
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', help='Name of the output JSON file', type=str, default="collections.json")
    args = parser.parse_args()

    # Get collections directory.
    collections_dir = os.path.dirname(nemo.collections.__file__)
    logging.info('Analysing collections in `{}`'.format(collections_dir))

    # Generate list of NeMo collections - from the list of collection subfolders.
    collections = {}
    for sub_dir in os.listdir(collections_dir):
        # Skip cache.
        if sub_dir == "__pycache__":
            continue
        # Check if it is a directory.
        if os.path.isdir(os.path.join(collections_dir, sub_dir)):
            collections[sub_dir] = "nemo.collections." + sub_dir

    output_list = []
    # Iterate over all collections.
    for key, val in collections.items():
        # Try to get module specification.
        module_spec = importlib.util.find_spec(val)
        if module_spec is None:
            logging.warning(" * Failed to process `{}`".format(val))
        else:
            try:
                # Import the module from the module specification.
                module = importlib.util.module_from_spec(module_spec)
                module_spec.loader.exec_module(module)
                # Add to list.
                output_list.append(process_collection(key, module))
                logging.info(" * Processed `{}`".format(val))
            except AttributeError:
                logging.warning(" * Failed to process `{}`".format(val))

    # Export to JSON.
    with open(args.filename, 'w', encoding='utf-8') as outfile:
        json.dump(output_list, outfile)

    logging.info('Finished the analysis, results exported to `{}`.'.format(args.filename))


if __name__ == '__main__':
    main()
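The script's core mechanism is the standard `importlib` spec machinery: resolve a dotted module name to a spec, build a module object from it, and execute it before reading its metadata attributes. A minimal, self-contained sketch of that stdlib pattern follows; the module name below is a stand-in chosen purely for illustration, not a NeMo collection.

```python
import importlib.util

# Illustrative module name; any importable dotted name works the same way.
module_name = "json"

# Resolve the module spec without importing the module yet.
spec = importlib.util.find_spec(module_name)
if spec is None:
    print(f"Module `{module_name}` not found")
else:
    # Build the module from its spec and execute it, as get_collections.py does.
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    print(module.__name__)  # -> json
```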
SoundScribe/SpeakerID/external/get_modules.py
ADDED
@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Script responsible for generating a JSON file containing the list of modules of a given collection. """

import argparse
import importlib
import inspect
import json
import os

import nemo
from nemo.utils import logging


def process_member(name, obj, module_list):
    """ Helper function processing the passed object and, if it qualifies, adding a record to the module list.

    Args:
        name: name of the member
        obj: member (class/function etc.)
        module_list: list of module records that will (possibly) be extended.
    """
    # It is not a class - skip it.
    if not inspect.isclass(obj):
        return

    # Check inheritance - all our datasets/modules/losses inherit from Serialization,
    # which is also required by this script.
    if not issubclass(obj, nemo.core.Serialization):
        return

    logging.info(" * Processing `{}`".format(str(obj)))

    module_list.append(
        {
            "name": name,
            "cls": str(obj),
            # Temporary solution: mockup arguments.
            "arguments": [
                "jasper",
                "activation",
                "feat_in",
                "normalization_mode",
                "residual_mode",
                "norm_groups",
                "conv_mask",
                "frame_splicing",
                "init_mode",
            ],
            # Temporary solution: mockup input types.
            "input_types": {
                "audio_signal": "axes: (batch, dimension, time); elements_type: MelSpectrogramType",
                "length": "axes: (batch,); elements_type: LengthType",
            },
            # Temporary solution: mockup output types.
            "output_types": {
                "encoder_output": "axes: (batch, dimension, time); elements_type: AcousticEncodedRepresentation"
            },
        }
    )


def main():
    """ Main function analysing the indicated NeMo collection and generating a JSON file with module descriptions. """
    # Parse filename.
    parser = argparse.ArgumentParser()
    parser.add_argument('--collection', help='ID of the collection', type=str)
    parser.add_argument('--filename', help='Name of the output JSON file', type=str, default="modules.json")
    args = parser.parse_args()

    # Get collections directory.
    collections_dir = os.path.dirname(nemo.collections.__file__)
    logging.info('Analysing collections in `{}`'.format(collections_dir))

    # Generate list of NeMo collections - from the list of collection subfolders.
    collections = {}
    for sub_dir in os.listdir(collections_dir):
        # Skip cache.
        if sub_dir == "__pycache__":
            continue
        # Check if it is a directory.
        if os.path.isdir(os.path.join(collections_dir, sub_dir)):
            collections[sub_dir] = "nemo.collections." + sub_dir

    # Check the collection.
    if args.collection not in collections.keys():
        logging.error("Couldn't process the indicated `{}` collection".format(args.collection))
        logging.info(
            "Please select one of the existing collections using `--collection [{}]`".format("|".join(collections))
        )
        exit(-1)

    # Load the collection specification.
    collection_spec = importlib.util.find_spec(collections[args.collection])
    if collection_spec is None:
        logging.error("Failed to load the `{}` collection".format(args.collection))

    # Import the module from the module specification.
    collection = importlib.util.module_from_spec(collection_spec)
    collection_spec.loader.exec_module(collection)

    module_list = []
    # Iterate over the packages in the indicated collection.
    logging.info("Analysing the `{}` collection".format(args.collection))

    try:  # Datasets in dataset folder
        logging.info("Analysing the 'data' package")
        for name, obj in inspect.getmembers(collection.data):
            process_member(name, obj, module_list)
    except AttributeError as e:
        logging.info(" * No datasets found")

    try:  # Datasets in dataset folder
        logging.info("Analysing the 'datasets' package")
        for name, obj in inspect.getmembers(collection.datasets):
            process_member(name, obj, module_list)
    except AttributeError as e:
        logging.info(" * No datasets found")

    try:  # Modules
        logging.info("Analysing the 'modules' package")
        for name, obj in inspect.getmembers(collection.modules):
            process_member(name, obj, module_list)
    except AttributeError as e:
        logging.info(" * No modules found")

    try:  # Losses
        logging.info("Analysing the 'losses' package")
        for name, obj in inspect.getmembers(collection.losses):
            process_member(name, obj, module_list)
    except AttributeError as e:
        logging.info(" * No losses found")

    # Add prefix - only for default name.
    filename = args.filename if args.filename != "modules.json" else args.collection + "_" + args.filename
    # Export to JSON.
    with open(filename, 'w', encoding='utf-8') as outfile:
        json.dump(module_list, outfile)

    logging.info(
        'Finished analysis of the `{}` collection, results exported to `{}`.'.format(args.collection, filename)
    )


if __name__ == '__main__':
    main()
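The filtering in `process_member` is ordinary `inspect` reflection: keep only members that are classes and that subclass a known base. A short sketch of the same idea against a toy base class; all names here are illustrative, not NeMo APIs.

```python
import inspect
import sys


class Base:
    """Toy stand-in for the shared base class (e.g. nemo.core.Serialization)."""


class ToyDataset(Base):
    """A member that should be kept."""


def helper_function():
    """A member that should be skipped (not a class)."""


def collect_subclasses(namespace, base):
    # Mirror of the process_member check: classes only, inheriting from `base`.
    return [
        name
        for name, obj in inspect.getmembers(namespace)
        if inspect.isclass(obj) and issubclass(obj, base)
    ]


print(collect_subclasses(sys.modules[__name__], Base))  # ['Base', 'ToyDataset']
```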
SoundScribe/SpeakerID/nemo/README.md
ADDED
@@ -0,0 +1,9 @@
NeMo (**Ne**ural **Mo**dules) is a toolkit for creating AI applications built around **neural modules**, conceptual blocks of neural networks that take *typed* inputs and produce *typed* outputs.

**NeMo Core** provides common APIs all modules and models have to implement.

**NeMo Collections**

* ASR - collection of modules and models for building speech recognition networks
* TTS - collection of modules and models for building speech synthesis networks
* NLP - collection of modules and models for building NLP networks
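As a concrete illustration of using a collection, the speaker-ID side of this repository would typically load a pretrained speaker-embedding model from the ASR collection. The snippet below is only a hedged sketch: it assumes NeMo's usual `from_pretrained`/`get_embedding` API surface and the publicly listed `titanet_large` checkpoint name, none of which are specified by this README, and the audio path is a placeholder.

```python
# Sketch only: assumes `nemo_toolkit` is installed and that the
# `titanet_large` checkpoint can be downloaded (assumption, not verified here).
import nemo.collections.asr as nemo_asr

speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("titanet_large")

# `get_embedding` is assumed from the NeMo speaker-recognition documentation.
embedding = speaker_model.get_embedding("/path/to/audio.wav")
print(embedding.shape)
```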
SoundScribe/SpeakerID/nemo/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nemo.package_info import (
    __contact_emails__,
    __contact_names__,
    __description__,
    __download_url__,
    __homepage__,
    __keywords__,
    __license__,
    __package_name__,
    __repository_url__,
    __shortversion__,
    __version__,
)
SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (454 Bytes).
SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (452 Bytes).
SoundScribe/SpeakerID/nemo/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (549 Bytes).
SoundScribe/SpeakerID/nemo/__pycache__/package_info.cpython-310.pyc
ADDED
Binary file (909 Bytes).
SoundScribe/SpeakerID/nemo/collections/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes).
SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (150 Bytes).
SoundScribe/SpeakerID/nemo/collections/asr/__init__.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.asr import data, losses, models, modules
from nemo.package_info import __version__

# Set collection version equal to NeMo version.
__version = __version__

# Authorship.
__author__ = "NVIDIA Corporation"

# Set collection name.
__description__ = "Automatic Speech Recognition collection"
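Because the collection's `__init__.py` above assigns `__author__` and `__description__` at import time, the metadata can be read straight off the imported package. A minimal check, sketched under the assumption that NeMo is installed:

```python
import nemo.collections.asr as nemo_asr

# Both attributes are set in the __init__.py shown above.
print(nemo_asr.__description__)  # Automatic Speech Recognition collection
print(nemo_asr.__author__)       # NVIDIA Corporation
```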
SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (429 Bytes).
SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (427 Bytes).
SoundScribe/SpeakerID/nemo/collections/asr/data/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (161 Bytes).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio.cpython-310.pyc
ADDED
Binary file (37.9 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio_dataset.cpython-310.pyc
ADDED
Binary file (2.42 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_diar_label.cpython-310.pyc
ADDED
Binary file (34.5 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label.cpython-310.pyc
ADDED
Binary file (50.4 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label_dataset.cpython-310.pyc
ADDED
Binary file (7.75 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text.cpython-310.pyc
ADDED
Binary file (50.8 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dali.cpython-310.pyc
ADDED
Binary file (24.9 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dataset.cpython-310.pyc
ADDED
Binary file (23.8 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label.cpython-310.pyc
ADDED
Binary file (16.1 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label_dataset.cpython-310.pyc
ADDED
Binary file (1.78 kB).
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio.py
ADDED
@@ -0,0 +1,1136 @@
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import abc
|
16 |
+
import math
|
17 |
+
import random
|
18 |
+
from collections import OrderedDict, namedtuple
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
21 |
+
|
22 |
+
import librosa
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
|
26 |
+
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
|
27 |
+
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
|
28 |
+
from nemo.collections.common.parts.preprocessing import collections
|
29 |
+
from nemo.collections.common.parts.utils import flatten
|
30 |
+
from nemo.core.classes import Dataset
|
31 |
+
from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType
|
32 |
+
from nemo.utils import logging
|
33 |
+
|
34 |
+
__all__ = [
|
35 |
+
'AudioToTargetDataset',
|
36 |
+
'AudioToTargetWithReferenceDataset',
|
37 |
+
'AudioToTargetWithEmbeddingDataset',
|
38 |
+
]
|
39 |
+
|
40 |
+
|
41 |
+
def _audio_collate_fn(batch: List[dict]) -> Tuple[torch.Tensor]:
    """Collate a batch of items returned by __getitem__.
    Examples for each signal are zero padded to the same length
    (batch_length), which is determined by the longest example.
    Lengths of the original signals are returned in the output.

    Args:
        batch: List of dictionaries. Each element of the list
               has the following format
               ```
               {
                   'signal_0': 1D or 2D tensor,
                   'signal_1': 1D or 2D tensor,
                   ...
                   'signal_N': 1D or 2D tensor,
               }
               ```
               1D tensors have shape (num_samples,) and 2D tensors
               have shape (num_channels, num_samples)

    Returns:
        A tuple containing signal tensor and signal length tensor (in samples)
        for each signal.
        The output has the following format:
        ```
        (signal_0, signal_0_length, signal_1, signal_1_length, ..., signal_N, signal_N_length)
        ```
        Note that the output format is obtained by interleaving signals and their length.
    """
    signals = batch[0].keys()

    batched = tuple()

    for signal in signals:
        signal_length = [b[signal].shape[-1] for b in batch]
        # Batch length is determined by the longest signal in the batch
        batch_length = max(signal_length)
        b_signal = []
        for s_len, b in zip(signal_length, batch):
            # check if padding is necessary
            if s_len < batch_length:
                if b[signal].ndim == 1:
                    # single-channel signal
                    pad = (0, batch_length - s_len)
                elif b[signal].ndim == 2:
                    # multi-channel signal
                    pad = (0, batch_length - s_len, 0, 0)
                else:
                    raise RuntimeError(
                        f'Signal {signal} has unsupported dimensions {b[signal].shape}. Currently, only 1D and 2D arrays are supported.'
                    )
                b[signal] = torch.nn.functional.pad(b[signal], pad)
            # append the current padded signal
            b_signal.append(b[signal])
        # (signal_batched, signal_length)
        batched += (torch.stack(b_signal), torch.tensor(signal_length, dtype=torch.int32))

    # Currently, outputs are expected to be in a tuple, where each element must correspond
    # to the output type in the OrderedDict returned by output_types.
    #
    # Therefore, we return batched signals by interleaving signals and their length:
    # (signal_0, signal_0_length, signal_1, signal_1_length, ...)
    return batched
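

# Illustrative sketch (not part of the original file): how `_audio_collate_fn` pads and
# interleaves a small batch. The signal names and lengths below are hypothetical.
def _example_audio_collate():
    batch = [
        {'input_signal': torch.rand(16000), 'target_signal': torch.rand(16000)},
        {'input_signal': torch.rand(12000), 'target_signal': torch.rand(12000)},
    ]
    input_signal, input_length, target_signal, target_length = _audio_collate_fn(batch)
    # The shorter example is zero padded to the longest one in the batch,
    # while the original lengths are preserved in the length tensors.
    assert input_signal.shape == (2, 16000)
    assert input_length.tolist() == [16000, 12000]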


@dataclass
class SignalSetup:
    signals: List[str]  # signal names
    duration: Optional[Union[float, list]] = None  # duration for each signal
    channel_selectors: Optional[List[ChannelSelectorType]] = None  # channel selector for loading each signal
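

# Illustrative sketch (not part of the original file): a SignalSetup describing two signals
# loaded synchronously, 4 seconds per example, with all channels kept (selectors are None).
_example_sync_setup = SignalSetup(
    signals=['input_signal', 'target_signal'], duration=4.0, channel_selectors=[None, None],
)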


class ASRAudioProcessor:
    """Class that processes an example from Audio collection and returns
    a dictionary with prepared signals.

    For example, the output dictionary may be the following
    ```
    {
        'input_signal': input_signal_tensor,
        'target_signal': target_signal_tensor,
        'reference_signal': reference_signal_tensor,
        'embedding_vector': embedding_vector
    }
    ```
    Keys in the output dictionary are ordered with synchronous signals given first,
    followed by asynchronous signals and embedding.

    Args:
        sample_rate: sample rate used for all audio signals
        random_offset: If `True`, offset will be randomized when loading a subsegment
                       from a file.
    """

    def __init__(
        self, sample_rate: float, random_offset: bool,
    ):
        self.sample_rate = sample_rate
        self.random_offset = random_offset

        self.sync_setup = None
        self.async_setup = None
        self.embedding_setup = None

    @property
    def sample_rate(self) -> float:
        return self._sample_rate

    @sample_rate.setter
    def sample_rate(self, value: float):
        if value <= 0:
            raise ValueError(f'Sample rate must be positive, received {value}')

        self._sample_rate = value

    @property
    def random_offset(self) -> bool:
        return self._random_offset

    @random_offset.setter
    def random_offset(self, value: bool):
        self._random_offset = value

    @property
    def sync_setup(self) -> SignalSetup:
        """Return the current setup for synchronous signals.

        Returns:
            A dataclass containing the list of signals, their
            duration and channel selectors.
        """
        return self._sync_setup

    @sync_setup.setter
    def sync_setup(self, value: Optional[SignalSetup]):
        """Setup signals to be loaded synchronously.

        Args:
            value: An instance of SignalSetup with the following fields
                - signals: list of signals (keys of example.audio_signals) which will be loaded
                           synchronously with the same start time and duration.
                - duration: Duration for each signal to be loaded.
                            If duration is set to None, the whole file will be loaded.
                - channel_selectors: A list of channel selector for each signal. If channel selector
                                     is None, all channels in the audio file will be loaded.
        """
        if value is None or isinstance(value, SignalSetup):
            self._sync_setup = value
        else:
            raise ValueError(f'Unexpected type {type(value)} for value {value}.')

    @property
    def async_setup(self) -> SignalSetup:
        """Return the current setup for asynchronous signals.

        Returns:
            A dataclass containing the list of signals, their
            duration and channel selectors.
        """
        return self._async_setup

    @async_setup.setter
    def async_setup(self, value: Optional[SignalSetup]):
        """Setup signals to be loaded asynchronously.

        Args:
            value: An instance of SignalSetup with the following fields
                - signals: list of signals (keys of example.audio_signals) which will be loaded
                           asynchronously with signals possibly having different start and duration
                - duration: Duration for each signal to be loaded.
                            If duration is set to None, the whole file will be loaded.
                - channel_selectors: A list of channel selector for each signal. If channel selector
                                     is None, all channels in the audio file will be loaded.
        """
        if value is None or isinstance(value, SignalSetup):
            self._async_setup = value
        else:
            raise ValueError(f'Unexpected type {type(value)} for value {value}.')

    @property
    def embedding_setup(self) -> SignalSetup:
        """Setup signals corresponding to an embedding vector.
        """
        return self._embedding_setup

    @embedding_setup.setter
    def embedding_setup(self, value: SignalSetup):
        """Setup signals corresponding to an embedding vector.

        Args:
            value: An instance of SignalSetup with the following fields
                - signals: list of signals (keys of example.audio_signals) which will be loaded
                           as embedding vectors.
        """
        if value is None or isinstance(value, SignalSetup):
            self._embedding_setup = value
        else:
            raise ValueError(f'Unexpected type {type(value)} for value {value}.')

    def process(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
        """Process an example from a collection of audio examples.

        Args:
            example: an example from Audio collection.

        Returns:
            An ordered dictionary of signals and their tensors.
            For example, the output dictionary may be the following
            ```
            {
                'input_signal': input_signal_tensor,
                'target_signal': target_signal_tensor,
                'reference_signal': reference_signal_tensor,
                'embedding_vector': embedding_vector
            }
            ```
            Keys in the output dictionary are ordered with synchronous signals given first,
            followed by asynchronous signals and embedding.
        """
        audio = self.load_audio(example=example)
        audio = self.process_audio(audio=audio)
        return audio

    def load_audio(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
        """Given an example, load audio from `example.audio_files` and prepare
        the output dictionary.

        Args:
            example: An example from an audio collection

        Returns:
            An ordered dictionary of signals and their tensors.
            For example, the output dictionary may be the following
            ```
            {
                'input_signal': input_signal_tensor,
                'target_signal': target_signal_tensor,
                'reference_signal': reference_signal_tensor,
                'embedding_vector': embedding_vector
            }
            ```
            Keys in the output dictionary are ordered with synchronous signals given first,
            followed by asynchronous signals and embedding.
        """
        output = OrderedDict()

        if self.sync_setup is not None:
            # Load all signals with the same start and duration
            sync_signals = self.load_sync_signals(example)
            output.update(sync_signals)

        if self.async_setup is not None:
            # Load each signal independently
            async_signals = self.load_async_signals(example)
            output.update(async_signals)

        # Load embedding vector
        if self.embedding_setup is not None:
            embedding = self.load_embedding(example)
            output.update(embedding)

        if not output:
            raise RuntimeError('Output dictionary is empty. Please use `_setup` methods to setup signals to be loaded')

        return output

    def process_audio(self, audio: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Process audio signals available in the input dictionary.

        Args:
            audio: A dictionary containing loaded signals `signal: tensor`

        Returns:
            An ordered dictionary of signals and their tensors.
        """
        # Currently, not doing any processing of the loaded signals.
        return audio

    def load_sync_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
        """Load signals with the same start and duration.

        Args:
            example: an example from audio collection

        Returns:
            An ordered dictionary of signals and their tensors.
        """
        output = OrderedDict()
        sync_audio_files = [example.audio_files[s] for s in self.sync_setup.signals]

        sync_samples = self.get_samples_synchronized(
            audio_files=sync_audio_files,
            channel_selectors=self.sync_setup.channel_selectors,
            sample_rate=self.sample_rate,
            duration=self.sync_setup.duration,
            fixed_offset=example.offset,
            random_offset=self.random_offset,
        )

        for signal, samples in zip(self.sync_setup.signals, sync_samples):
            output[signal] = torch.tensor(samples)

        return output

    def load_async_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
        """Load each async signal independently, no constraints on starting
        from the same time.

        Args:
            example: an example from audio collection

        Returns:
            An ordered dictionary of signals and their tensors.
        """
        output = OrderedDict()
        for idx, signal in enumerate(self.async_setup.signals):
            samples = self.get_samples(
                audio_file=example.audio_files[signal],
                sample_rate=self.sample_rate,
                duration=self.async_setup.duration[idx],
                channel_selector=self.async_setup.channel_selectors[idx],
                fixed_offset=example.offset,
                random_offset=self.random_offset,
            )
            output[signal] = torch.tensor(samples)
        return output

    @classmethod
    def get_samples(
        cls,
        audio_file: str,
        sample_rate: int,
        duration: Optional[float] = None,
        channel_selector: ChannelSelectorType = None,
        fixed_offset: float = 0,
        random_offset: bool = False,
    ) -> np.ndarray:
        """Get samples from an audio file.
        For a single-channel signal, the output is shape (num_samples,).
        For a multi-channel signal, the output is shape (num_channels, num_samples).

        Args:
            audio_file: path to an audio file
            sample_rate: desired sample rate for output samples
            duration: Optional desired duration of output samples.
                      If `None`, the complete file will be loaded.
                      If set, a segment of `duration` seconds will be loaded.
            channel_selector: Optional channel selector, for selecting a subset of channels.
            fixed_offset: Optional fixed offset when loading samples.
            random_offset: If `True`, offset will be randomized when loading a short segment
                           from a file. The value is randomized between fixed_offset and
                           max_offset (set depending on the duration and fixed_offset).

        Returns:
            Numpy array with samples from audio file.
            The array has shape (num_samples,) for a single-channel signal
            or (num_channels, num_samples) for a multi-channel signal.
        """
        output = cls.get_samples_synchronized(
            audio_files=[audio_file],
            sample_rate=sample_rate,
            duration=duration,
            channel_selectors=[channel_selector],
            fixed_offset=fixed_offset,
            random_offset=random_offset,
        )

        return output[0]

    @classmethod
    def get_samples_synchronized(
        cls,
        audio_files: List[str],
        sample_rate: int,
        duration: Optional[float] = None,
        channel_selectors: Optional[List[ChannelSelectorType]] = None,
        fixed_offset: float = 0,
        random_offset: bool = False,
    ) -> List[np.ndarray]:
        """Get samples from multiple files with the same start and end point.

        Args:
            audio_files: list of paths to audio files
            sample_rate: desired sample rate for output samples
            duration: Optional desired duration of output samples.
                      If `None`, the complete files will be loaded.
                      If set, a segment of `duration` seconds will be loaded from
                      all files. Segment is synchronized across files, so that
                      start and end points are the same.
            channel_selectors: Optional channel selector for each signal, for selecting
                               a subset of channels.
            fixed_offset: Optional fixed offset when loading samples.
            random_offset: If `True`, offset will be randomized when loading a short segment
                           from a file. The value is randomized between fixed_offset and
                           max_offset (set depending on the duration and fixed_offset).

        Returns:
            List with the same size as `audio_files` but containing numpy arrays
            with samples from each audio file.
            Each array has shape (num_samples,) or (num_channels, num_samples), for single-
            or multi-channel signal, respectively.
            For example, if `audio_files = [path/to/file_1.wav, path/to/file_2.wav]`,
            the output will be a list `output = [samples_file_1, samples_file_2]`.
        """
        if channel_selectors is None:
            channel_selectors = [None] * len(audio_files)

        if duration is None:
            # Load complete files starting from a fixed offset
            offset = fixed_offset  # fixed offset
            num_samples = None  # no constraint on the number of samples

        else:
            # Fixed duration of the output
            audio_durations = cls.get_duration(audio_files)
            min_audio_duration = min(audio_durations)
            available_duration = min_audio_duration - fixed_offset

            if available_duration <= 0:
                raise ValueError(f'Fixed offset {fixed_offset}s is larger than shortest file {min_audio_duration}s.')

            if duration + fixed_offset > min_audio_duration:
                # The shortest file is shorter than the requested duration
                logging.debug(
                    f'Shortest file ({min_audio_duration}s) is less than the desired duration {duration}s + fixed offset {fixed_offset}s. Returned signals will be shortened to {available_duration} seconds.'
                )
                offset = fixed_offset
                duration = available_duration
            elif random_offset:
                # Randomize offset based on the shortest file
                max_offset = min_audio_duration - duration
                offset = random.uniform(fixed_offset, max_offset)
            else:
                # Fixed offset
                offset = fixed_offset

            # Fixed number of samples
            num_samples = math.floor(duration * sample_rate)

        output = []

        # Prepare segments
        for idx, audio_file in enumerate(audio_files):
            segment_samples = cls.get_samples_from_file(
                audio_file=audio_file,
                sample_rate=sample_rate,
                offset=offset,
                num_samples=num_samples,
                channel_selector=channel_selectors[idx],
            )
            output.append(segment_samples)

        return output

    @classmethod
    def get_samples_from_file(
        cls,
        audio_file: Union[str, List[str]],
        sample_rate: int,
        offset: float,
        num_samples: Optional[int] = None,
        channel_selector: Optional[ChannelSelectorType] = None,
    ) -> np.ndarray:
        """Get samples from a single or multiple files.
        If loading samples from multiple files, they will
        be concatenated along the channel dimension.

        Args:
            audio_file: path or a list of paths.
            sample_rate: sample rate of the loaded samples
            offset: fixed offset in seconds
            num_samples: Optional, number of samples to load.
                         If `None`, all available samples will be loaded.
            channel_selector: Select a subset of available channels.

        Returns:
            An array with shape (samples,) or (channels, samples)
        """
        if isinstance(audio_file, str):
            # Load samples from a single file
            segment_samples = cls.get_segment_from_file(
                audio_file=audio_file,
                sample_rate=sample_rate,
                offset=offset,
                num_samples=num_samples,
                channel_selector=channel_selector,
            )
        elif isinstance(audio_file, list):
            # Load samples from multiple files and form a multi-channel signal
            segment_samples = []
            for a_file in audio_file:
                a_file_samples = cls.get_segment_from_file(
                    audio_file=a_file,
                    sample_rate=sample_rate,
                    offset=offset,
                    num_samples=num_samples,
                    channel_selector=channel_selector,
                )
                segment_samples.append(a_file_samples)
            segment_samples = cls.list_to_multichannel(segment_samples)
        elif audio_file is None:
            # Support for inference, when the target signal is `None`
            segment_samples = []
        else:
            raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}')
        return segment_samples

    @staticmethod
    def get_segment_from_file(
        audio_file: str,
        sample_rate: int,
        offset: float,
        num_samples: Optional[int] = None,
        channel_selector: Optional[ChannelSelectorType] = None,
    ) -> np.ndarray:
        """Get a segment of samples from a single audio file.

        Args:
            audio_file: path to an audio file
            sample_rate: sample rate of the loaded samples
            offset: fixed offset in seconds
            num_samples: Optional, number of samples to load.
                         If `None`, all available samples will be loaded.
            channel_selector: Select a subset of available channels.

        Returns:
            An array with shape (samples,) or (channels, samples)
        """
        if num_samples is None:
            segment = AudioSegment.from_file(
                audio_file=audio_file, target_sr=sample_rate, offset=offset, channel_selector=channel_selector,
            )

        else:
            segment = AudioSegment.segment_from_file(
                audio_file=audio_file,
                target_sr=sample_rate,
                n_segments=num_samples,
                offset=offset,
                channel_selector=channel_selector,
            )

        if segment.samples.ndim == 1:
            # Single-channel signal
            return segment.samples
        elif segment.samples.ndim == 2:
            # Use multi-channel format as (channels, samples)
            return segment.samples.T
        else:
            raise RuntimeError(f'Unexpected samples shape: {segment.samples.shape}')

    @staticmethod
    def list_to_multichannel(signal: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
        """Convert a list of signals into a multi-channel signal by concatenating
        the elements of the list along the channel dimension.

        If input is not a list, it is returned unmodified.

        Args:
            signal: list of arrays

        Returns:
            Numpy array obtained by concatenating the elements of the list
            along the channel dimension (axis=0).
        """
        if not isinstance(signal, list):
            # Nothing to do there
            return signal
        elif len(signal) == 0:
            # Nothing to do, return as is
            return signal
        elif len(signal) == 1:
            # Nothing to concatenate, return the original format
            return signal[0]

        # If multiple signals are provided in a list, we concatenate them along the channel dimension
        if signal[0].ndim == 1:
            # Single-channel individual files
            mc_signal = np.stack(signal, axis=0)
        elif signal[0].ndim == 2:
            # Multi-channel individual files
            mc_signal = np.concatenate(signal, axis=0)
        else:
            raise RuntimeError(f'Unexpected target with {signal[0].ndim} dimensions.')

        return mc_signal

    @staticmethod
    def get_duration(audio_files: List[str]) -> List[float]:
        """Get duration for each audio file in `audio_files`.

        Args:
            audio_files: list of paths to audio files

        Returns:
            List of durations in seconds.
        """
        duration = [librosa.get_duration(path=f) for f in flatten(audio_files)]
        return duration

    def load_embedding(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
        """Given an example, load embedding from `example.audio_files[embedding]`
        and return it in a dictionary.

        Args:
            example: An example from audio collection

        Returns:
            A dictionary of embedding keys and their tensors.
        """
        output = OrderedDict()
        for idx, signal in enumerate(self.embedding_setup.signals):
            embedding_file = example.audio_files[signal]
            embedding = self.load_embedding_vector(embedding_file)
            output[signal] = torch.tensor(embedding)
        return output

    @staticmethod
    def load_embedding_vector(filepath: str) -> np.ndarray:
        """Load an embedding vector from a file.

        Args:
            filepath: path to a file storing a vector.
                      Currently, it is assumed the file is a npy file.

        Returns:
            Array loaded from filepath.
        """
        if filepath.endswith('.npy'):
            with open(filepath, 'rb') as f:
                embedding = np.load(f)
        else:
            raise RuntimeError(f'Unknown embedding file format in file: {filepath}')

        return embedding
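

# Illustrative sketch (not part of the original file): wiring an ASRAudioProcessor so that
# input and target are cut synchronously while a reference signal is loaded independently.
# The signal names and the 4 s / 2 s durations are hypothetical.
def _example_audio_processor() -> ASRAudioProcessor:
    processor = ASRAudioProcessor(sample_rate=16000, random_offset=True)
    processor.sync_setup = SignalSetup(
        signals=['input_signal', 'target_signal'], duration=4.0, channel_selectors=[None, None],
    )
    processor.async_setup = SignalSetup(
        signals=['reference_signal'], duration=[2.0], channel_selectors=[None],
    )
    # processor.process(example) then returns an OrderedDict with the three signals, where
    # `example` is an item of a collections.AudioCollection built from a manifest.
    return processor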


class BaseAudioDataset(Dataset):
    """Base class of audio datasets, providing common functionality
    for other audio datasets.

    Args:
        collection: Collection of audio examples prepared from manifest files.
        audio_processor: Used to process every example from the collection.
                         A callable with `process` method. For reference,
                         please check ASRAudioProcessor.
    """

    @property
    @abc.abstractmethod
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
        """

    def __init__(self, collection: collections.Audio, audio_processor: Callable, output_type: Type[namedtuple]):
        """Instantiates an audio dataset.
        """
        super().__init__()

        self.collection = collection
        self.audio_processor = audio_processor
        self.output_type = output_type

    def num_channels(self, signal_key) -> int:
        """Returns the number of channels for a particular signal in
        items prepared by this dataset.

        More specifically, this will get the tensor from the first
        item in the dataset, check if it's a one- or two-dimensional
        tensor, and return the number of channels based on the size
        of the first axis (shape[0]).

        NOTE:
        This assumes that all examples have the same number of channels.

        Args:
            signal_key: string, used to select a signal from the dictionary
                        output by __getitem__

        Returns:
            Number of channels for the selected signal.
        """
        # Assumption: whole dataset has the same number of channels
        item = self.__getitem__(0)

        if item[signal_key].ndim == 1:
            return 1
        elif item[signal_key].ndim == 2:
            return item[signal_key].shape[0]
        else:
            raise RuntimeError(
                f'Unexpected number of dimensions for signal {signal_key} with shape {item[signal_key].shape}'
            )

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """Return a single example from the dataset.

        Args:
            index: integer index of an example in the collection

        Returns:
            Dictionary providing mapping from signal to its tensor.
            For example:
            ```
            {
                'input_signal': input_signal_tensor,
                'target_signal': target_signal_tensor,
            }
            ```
        """
        example = self.collection[index]
        output = self.audio_processor.process(example=example)

        return output

    def __len__(self) -> int:
        """Return the number of examples in the dataset.
        """
        return len(self.collection)

    def _collate_fn(self, batch) -> Tuple[torch.Tensor]:
        """Collate items in a batch.
        """
        return self.output_type(*_audio_collate_fn(batch))
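

# Illustrative sketch (not part of the original file): the minimal pieces a subclass has to
# provide are a collection, an audio processor and a namedtuple output type. The names below
# are hypothetical and mirror what the concrete datasets later in this file do.
_ExampleOutput = namedtuple('_ExampleOutput', 'input_signal input_length')


class _ExampleAudioDataset(BaseAudioDataset):
    def __init__(self, collection: collections.Audio, sample_rate: int):
        processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=False)
        processor.sync_setup = SignalSetup(signals=['input_signal'], duration=None, channel_selectors=[None])
        super().__init__(collection=collection, audio_processor=processor, output_type=_ExampleOutput)

    @property
    def output_types(self):
        return OrderedDict(
            input_signal=NeuralType(('B', 'T'), AudioSignal()),
            input_length=NeuralType(('B',), LengthsType()),
        )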


AudioToTargetExample = namedtuple(
    typename='AudioToTargetExample', field_names='input_signal input_length target_signal target_length'
)


class AudioToTargetDataset(BaseAudioDataset):
    """A dataset for audio-to-audio tasks where the goal is to use
    an input signal to recover the corresponding target signal.

    Each line of the manifest file is expected to have the following format
    ```
    {
        'input_key': 'path/to/input.wav',
        'target_key': 'path/to/path_to_target.wav',
        'duration': duration_of_input,
    }
    ```

    Additionally, multiple audio files may be provided for each key in the manifest, for example,
    ```
    {
        'input_key': 'path/to/input.wav',
        'target_key': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'],
        'duration': duration_of_input,
    }
    ```

    Keys for input and target signals can be configured in the constructor (`input_key` and `target_key`).

    Args:
        manifest_filepath: Path to manifest file in a format described above.
        sample_rate: Sample rate for loaded audio signals.
        input_key: Key pointing to input audio files in the manifest
        target_key: Key pointing to target audio files in manifest
        audio_duration: Optional duration of each item returned by __getitem__.
                        If `None`, complete audio will be loaded.
                        If set, a random subsegment will be loaded synchronously from
                        target and audio, i.e., with the same start and end point.
        random_offset: If `True`, offset will be randomized when loading a subsegment
                       from a file.
        max_duration: If audio exceeds this length, do not include in dataset.
        min_duration: If audio is less than this length, do not include in dataset.
        max_utts: Limit number of utterances.
        input_channel_selector: Optional, select subset of channels from each input audio file.
                                If `None`, all channels will be loaded.
        target_channel_selector: Optional, select subset of channels from each input audio file.
                                 If `None`, all channels will be loaded.
    """

    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        input_key: str,
        target_key: str,
        audio_duration: Optional[float] = None,
        random_offset: bool = False,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        max_utts: Optional[int] = None,
        input_channel_selector: Optional[int] = None,
        target_channel_selector: Optional[int] = None,
    ):
        audio_to_manifest_key = {
            'input_signal': input_key,
            'target_signal': target_key,
        }

        collection = collections.AudioCollection(
            manifest_files=manifest_filepath,
            audio_to_manifest_key=audio_to_manifest_key,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )

        audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
        audio_processor.sync_setup = SignalSetup(
            signals=['input_signal', 'target_signal'],
            duration=audio_duration,
            channel_selectors=[input_channel_selector, target_channel_selector],
        )

        super().__init__(collection=collection, audio_processor=audio_processor, output_type=AudioToTargetExample)

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.

        Returns:
            Ordered dictionary in the following form:
            ```
            {
                'input_signal': batched single- or multi-channel format,
                'input_length': batched original length of each input signal
                'target_signal': batched single- or multi-channel format,
                'target_length': batched original length of each target signal
            }
            ```
        """
        sc_audio_type = NeuralType(('B', 'T'), AudioSignal())
        mc_audio_type = NeuralType(('B', 'C', 'T'), AudioSignal())

        return OrderedDict(
            input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
            input_length=NeuralType(('B',), LengthsType()),
            target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
            target_length=NeuralType(('B',), LengthsType()),
        )
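

# Illustrative sketch (not part of the original file): a hypothetical manifest line and a
# DataLoader built on top of AudioToTargetDataset. Paths, keys and durations are placeholders.
#
#   {"input_filepath": "noisy/utt1.wav", "target_filepath": "clean/utt1.wav", "duration": 5.2}
def _example_audio_to_target_loader(manifest_filepath: str):
    from torch.utils.data import DataLoader

    dataset = AudioToTargetDataset(
        manifest_filepath=manifest_filepath,
        sample_rate=16000,
        input_key='input_filepath',
        target_key='target_filepath',
        audio_duration=4.0,
        random_offset=True,
    )
    # The dataset's own collate function zero pads signals and interleaves them with lengths.
    return DataLoader(dataset, batch_size=8, collate_fn=dataset._collate_fn)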


AudioToTargetWithReferenceExample = namedtuple(
    typename='AudioToTargetWithReferenceExample',
    field_names='input_signal input_length target_signal target_length reference_signal reference_length',
)


class AudioToTargetWithReferenceDataset(BaseAudioDataset):
    """A dataset for audio-to-audio tasks where the goal is to use
    an input signal to recover the corresponding target signal and an
    additional reference signal is available.

    This can be used, for example, when a reference signal is
    available from
    - enrollment utterance for the target signal
    - echo reference from playback
    - reference from another sensor that correlates with the target signal

    Each line of the manifest file is expected to have the following format
    ```
    {
        'input_key': 'path/to/input.wav',
        'target_key': 'path/to/path_to_target.wav',
        'reference_key': 'path/to/path_to_reference.wav',
        'duration': duration_of_input,
    }
    ```

    Keys for input, target and reference signals can be configured in the constructor.

    Args:
        manifest_filepath: Path to manifest file in a format described above.
        sample_rate: Sample rate for loaded audio signals.
        input_key: Key pointing to input audio files in the manifest
        target_key: Key pointing to target audio files in manifest
        reference_key: Key pointing to reference audio files in manifest
        audio_duration: Optional duration of each item returned by __getitem__.
                        If `None`, complete audio will be loaded.
                        If set, a random subsegment will be loaded synchronously from
                        target and audio, i.e., with the same start and end point.
        random_offset: If `True`, offset will be randomized when loading a subsegment
                       from a file.
        max_duration: If audio exceeds this length, do not include in dataset.
        min_duration: If audio is less than this length, do not include in dataset.
        max_utts: Limit number of utterances.
        input_channel_selector: Optional, select subset of channels from each input audio file.
                                If `None`, all channels will be loaded.
        target_channel_selector: Optional, select subset of channels from each input audio file.
                                 If `None`, all channels will be loaded.
        reference_channel_selector: Optional, select subset of channels from each input audio file.
                                    If `None`, all channels will be loaded.
        reference_is_synchronized: If True, it is assumed that the reference signal is synchronized
                                   with the input signal, so the same subsegment will be loaded as for
                                   input and target. If False, reference signal will be loaded independently
                                   from input and target.
        reference_duration: Optional, can be used to set a fixed duration of the reference utterance. If `None`,
                            complete audio file will be loaded.
    """

    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        input_key: str,
        target_key: str,
        reference_key: str,
        audio_duration: Optional[float] = None,
        random_offset: bool = False,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        max_utts: Optional[int] = None,
        input_channel_selector: Optional[int] = None,
        target_channel_selector: Optional[int] = None,
        reference_channel_selector: Optional[int] = None,
        reference_is_synchronized: bool = True,
        reference_duration: Optional[float] = None,
    ):
        audio_to_manifest_key = {
            'input_signal': input_key,
            'target_signal': target_key,
            'reference_signal': reference_key,
        }

        collection = collections.AudioCollection(
            manifest_files=manifest_filepath,
            audio_to_manifest_key=audio_to_manifest_key,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )

        audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)

        if reference_is_synchronized:
            audio_processor.sync_setup = SignalSetup(
                signals=['input_signal', 'target_signal', 'reference_signal'],
                duration=audio_duration,
                channel_selectors=[input_channel_selector, target_channel_selector, reference_channel_selector],
            )
        else:
            audio_processor.sync_setup = SignalSetup(
                signals=['input_signal', 'target_signal'],
                duration=audio_duration,
                channel_selectors=[input_channel_selector, target_channel_selector],
            )
            audio_processor.async_setup = SignalSetup(
                signals=['reference_signal'],
                duration=[reference_duration],
                channel_selectors=[reference_channel_selector],
            )

        super().__init__(
            collection=collection, audio_processor=audio_processor, output_type=AudioToTargetWithReferenceExample
        )

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.

        Returns:
            Ordered dictionary in the following form:
            ```
            {
                'input_signal': batched single- or multi-channel format,
                'input_length': batched original length of each input signal
                'target_signal': batched single- or multi-channel format,
                'target_length': batched original length of each target signal
                'reference_signal': single- or multi-channel format,
                'reference_length': original length of each reference signal
            }
            ```
        """
        sc_audio_type = NeuralType(('B', 'T'), AudioSignal())
        mc_audio_type = NeuralType(('B', 'C', 'T'), AudioSignal())

        return OrderedDict(
            input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
            input_length=NeuralType(('B',), LengthsType()),
            target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
            target_length=NeuralType(('B',), LengthsType()),
            reference_signal=sc_audio_type if self.num_channels('reference_signal') == 1 else mc_audio_type,
            reference_length=NeuralType(('B',), LengthsType()),
        )
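

# Illustrative sketch (not part of the original file): a hypothetical manifest line for the
# dataset above, with an enrollment utterance as the reference. Since an enrollment is not
# time aligned with the input, reference_is_synchronized=False loads it independently.
#
#   {"input_filepath": "mix/utt1.wav", "target_filepath": "clean/utt1.wav",
#    "reference_filepath": "enroll/spk1.wav", "duration": 6.0}
def _example_audio_to_target_with_reference(manifest_filepath: str) -> AudioToTargetWithReferenceDataset:
    return AudioToTargetWithReferenceDataset(
        manifest_filepath=manifest_filepath,
        sample_rate=16000,
        input_key='input_filepath',
        target_key='target_filepath',
        reference_key='reference_filepath',
        audio_duration=4.0,
        reference_is_synchronized=False,
        reference_duration=2.0,
    )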


AudioToTargetWithEmbeddingExample = namedtuple(
    typename='AudioToTargetWithEmbeddingExample',
    field_names='input_signal input_length target_signal target_length embedding_vector embedding_length',
)


class AudioToTargetWithEmbeddingDataset(BaseAudioDataset):
    """A dataset for audio-to-audio tasks where the goal is to use
    an input signal to recover the corresponding target signal and an
    additional embedding signal. It is assumed that the embedding
    is in a form of a vector.

    Each line of the manifest file is expected to have the following format
    ```
    {
        input_key: 'path/to/input.wav',
        target_key: 'path/to/path_to_target.wav',
        embedding_key: 'path/to/path_to_reference.npy',
        'duration': duration_of_input,
    }
    ```

    Keys for input, target and embedding signals can be configured in the constructor.

    Args:
        manifest_filepath: Path to manifest file in a format described above.
        sample_rate: Sample rate for loaded audio signals.
        input_key: Key pointing to input audio files in the manifest
        target_key: Key pointing to target audio files in manifest
        embedding_key: Key pointing to embedding files in manifest
        audio_duration: Optional duration of each item returned by __getitem__.
                        If `None`, complete audio will be loaded.
                        If set, a random subsegment will be loaded synchronously from
                        target and audio, i.e., with the same start and end point.
        random_offset: If `True`, offset will be randomized when loading a subsegment
                       from a file.
        max_duration: If audio exceeds this length, do not include in dataset.
        min_duration: If audio is less than this length, do not include in dataset.
        max_utts: Limit number of utterances.
        input_channel_selector: Optional, select subset of channels from each input audio file.
                                If `None`, all channels will be loaded.
        target_channel_selector: Optional, select subset of channels from each input audio file.
                                 If `None`, all channels will be loaded.
    """

    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        input_key: str,
        target_key: str,
        embedding_key: str,
        audio_duration: Optional[float] = None,
        random_offset: bool = False,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        max_utts: Optional[int] = None,
        input_channel_selector: Optional[int] = None,
        target_channel_selector: Optional[int] = None,
    ):
        audio_to_manifest_key = {
            'input_signal': input_key,
            'target_signal': target_key,
            'embedding_vector': embedding_key,
        }

        collection = collections.AudioCollection(
            manifest_files=manifest_filepath,
            audio_to_manifest_key=audio_to_manifest_key,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )

        audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
        audio_processor.sync_setup = SignalSetup(
            signals=['input_signal', 'target_signal'],
            duration=audio_duration,
            channel_selectors=[input_channel_selector, target_channel_selector],
        )
        audio_processor.embedding_setup = SignalSetup(signals=['embedding_vector'])

        super().__init__(
            collection=collection, audio_processor=audio_processor, output_type=AudioToTargetWithEmbeddingExample
        )

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.

        Returns:
            Ordered dictionary in the following form:
            ```
            {
                'input_signal': batched single- or multi-channel format,
                'input_length': batched original length of each input signal
                'target_signal': batched single- or multi-channel format,
                'target_length': batched original length of each target signal
                'embedding_vector': batched embedded vector format,
                'embedding_length': batched original length of each embedding vector
            }
            ```
        """
        sc_audio_type = NeuralType(('B', 'T'), AudioSignal())
        mc_audio_type = NeuralType(('B', 'C', 'T'), AudioSignal())

        return OrderedDict(
            input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
            input_length=NeuralType(('B',), LengthsType()),
            target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
            target_length=NeuralType(('B',), LengthsType()),
            embedding_vector=NeuralType(('B', 'D'), EncodedRepresentation()),
            embedding_length=NeuralType(('B',), LengthsType()),
        )
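

# Illustrative sketch (not part of the original file): a hypothetical manifest line for
# AudioToTargetWithEmbeddingDataset, where the embedding is a speaker vector stored as .npy.
#
#   {"input_filepath": "mix/utt1.wav", "target_filepath": "clean/utt1.wav",
#    "embedding_filepath": "emb/spk1.npy", "duration": 6.0}
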
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio_dataset.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.asr.data import audio_to_audio


def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset:
    """Instantiates an audio-to-audio dataset.

    Args:
        config: Config of AudioToTargetDataset.

    Returns:
        An instance of AudioToTargetDataset
    """
    dataset = audio_to_audio.AudioToTargetDataset(
        manifest_filepath=config['manifest_filepath'],
        sample_rate=config['sample_rate'],
        input_key=config['input_key'],
        target_key=config['target_key'],
        audio_duration=config.get('audio_duration', None),
        random_offset=config.get('random_offset', False),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        input_channel_selector=config.get('input_channel_selector', None),
        target_channel_selector=config.get('target_channel_selector', None),
    )
    return dataset
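

# Illustrative sketch (not part of the original file): a hypothetical config dict for the
# factory above. Only the four required keys are mandatory; the rest fall back to defaults.
_example_audio_to_target_config = {
    'manifest_filepath': 'path/to/manifest.json',
    'sample_rate': 16000,
    'input_key': 'input_filepath',
    'target_key': 'target_filepath',
    'audio_duration': 4.0,
    'random_offset': True,
}
# dataset = get_audio_to_target_dataset(_example_audio_to_target_config)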


def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset:
    """Instantiates an audio-to-audio dataset.

    Args:
        config: Config of AudioToTargetWithReferenceDataset.

    Returns:
        An instance of AudioToTargetWithReferenceDataset
    """
    dataset = audio_to_audio.AudioToTargetWithReferenceDataset(
        manifest_filepath=config['manifest_filepath'],
        sample_rate=config['sample_rate'],
        input_key=config['input_key'],
        target_key=config['target_key'],
        reference_key=config['reference_key'],
        audio_duration=config.get('audio_duration', None),
        random_offset=config.get('random_offset', False),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        input_channel_selector=config.get('input_channel_selector', None),
        target_channel_selector=config.get('target_channel_selector', None),
        reference_channel_selector=config.get('reference_channel_selector', None),
        reference_is_synchronized=config.get('reference_is_synchronized', True),
        reference_duration=config.get('reference_duration', None),
    )
    return dataset


def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset:
    """Instantiates an audio-to-audio dataset.

    Args:
        config: Config of AudioToTargetWithEmbeddingDataset.

    Returns:
        An instance of AudioToTargetWithEmbeddingDataset
    """
    dataset = audio_to_audio.AudioToTargetWithEmbeddingDataset(
        manifest_filepath=config['manifest_filepath'],
        sample_rate=config['sample_rate'],
        input_key=config['input_key'],
        target_key=config['target_key'],
        embedding_key=config['embedding_key'],
        audio_duration=config.get('audio_duration', None),
        random_offset=config.get('random_offset', False),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        input_channel_selector=config.get('input_channel_selector', None),
        target_channel_selector=config.get('target_channel_selector', None),
    )
    return dataset
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_ctm_dataset.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Tuple

from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
from nemo.utils import logging


@dataclass
class FrameCtmUnit:
    """A container class for one CTM unit with start and length countable in frames.
    """

    label: str
    start_frame: int
    length: int
    probability: float

    def __repr__(self) -> str:
        return f"{self.label}\t({self.probability:1.3f}): [{self.start_frame:6d}, {self.length:6d}]"

    @property
    def end_frame(self):
        return self.start_frame + self.length

    def to_ctm_str(self, time_per_frame: int) -> str:
        """Represents the data as part of the CTM line.

        The CTM line format is
            <utterance_name> <channel> <start_time> <duration> <label_str> <probability>
        This method prepares the last four entities."""
        return f"{self.start_frame * time_per_frame :.3f} {self.length * time_per_frame :.3f} {self.label} {self.probability :1.3f}"
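

# Illustrative sketch (not part of the original file): a FrameCtmUnit covering 50 frames
# starting at frame 100; with a 10 ms frame shift, `to_ctm_str` renders the tail of a CTM line.
_example_ctm_unit = FrameCtmUnit(label='hello', start_frame=100, length=50, probability=0.9)
_example_ctm_tail = _example_ctm_unit.to_ctm_str(0.01)  # '1.000 0.500 hello 0.900'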


class ASRCTMPredictionWriter(ASRPredictionWriter):
    def __init__(self, dataset, output_file: str, output_ctm_dir: str, time_per_frame: float):
        super().__init__(dataset, output_file)
        self.output_ctm_dir = output_ctm_dir
        self.time_per_frame = time_per_frame
        os.makedirs(self.output_ctm_dir, exist_ok=True)

    def write_ctm(self, name, filepath, frameCtmUnits):
        with open(filepath, "tw", encoding="utf-8") as f:
            for unit in frameCtmUnits:
                f.write(f"{name} 1 {unit.to_ctm_str(self.time_per_frame)}\n")

    def write_on_batch_end(
        self,
        trainer,
        pl_module: 'LightningModule',
        prediction: Tuple[int, List[FrameCtmUnit]],
        batch_indices: List[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ):
        for sample_id, units in prediction:
            sample = self.dataset.get_manifest_sample(sample_id)
            with_ctm = True
            if len(units) == 0:
                logging.warning(
                    f"""Not producing CTM output for item `{sample.audio_file}`.
                    Check if text is empty or if duration is too short: `{sample.text_raw}`, {sample.duration}"""
                )
                with_ctm = False
            item = {}
            item["audio_filepath"] = sample.audio_file
            item["duration"] = sample.duration
            item["text"] = sample.text_raw
            if with_ctm:
                utt_name = Path(sample.audio_file).stem
                ctm_filepath = os.path.join(self.output_ctm_dir, utt_name) + ".ctm"
                self.write_ctm(utt_name, ctm_filepath, units)
                item["ctm_filepath"] = ctm_filepath
            else:
                item["ctm_filepath"] = ""
            self.outf.write(json.dumps(item) + "\n")
            self.samples_num += 1
        return
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_diar_label.py
ADDED
@@ -0,0 +1,853 @@
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from collections import OrderedDict
|
17 |
+
from statistics import mode
|
18 |
+
from typing import Dict, Optional
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat
|
23 |
+
from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data
|
24 |
+
from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel
|
25 |
+
from nemo.core.classes import Dataset
|
26 |
+
from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType, SpectrogramType
|
27 |
+
|
28 |
+
|
29 |
+
def get_scale_mapping_list(uniq_timestamps):
|
30 |
+
"""
|
31 |
+
Call get_argmin_mat function to find the index of the non-base-scale segment that is closest to the
|
32 |
+
given base-scale segment. For each scale and each segment, a base-scale segment is assigned.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
uniq_timestamps: (dict)
|
36 |
+
The dictionary containing embeddings, timestamps and multiscale weights.
|
37 |
+
If uniq_timestamps contains only one scale, single scale diarization is performed.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
scale_mapping_argmat (torch.tensor):
|
41 |
+
|
42 |
+
The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
|
43 |
+
segment index that has the closest center distance to the (n+1)-th segment in the base scale.
|
44 |
+
|
45 |
+
- Example:
|
46 |
+
`scale_mapping_argmat[2][101] = 85`
|
47 |
+
|
48 |
+
In the above example, the mapping means that the 86-th segment in the 3rd scale (Python index 2) is
|
49 |
+
mapped to the 102-nd segment in the base scale. Thus, longer segments are bound to have more repeating
|
50 |
+
numbers, since multiple base-scale segments (the base scale has the shortest segment length) fall into the
|
51 |
+
range of the longer segments. At the same time, each row contains N indices, where N is the number
|
52 |
+
of segments in the base-scale (i.e., the finest scale).
|
53 |
+
"""
|
54 |
+
timestamps_in_scales = []
|
55 |
+
for key, val in uniq_timestamps['scale_dict'].items():
|
56 |
+
timestamps_in_scales.append(torch.tensor(val['time_stamps']))
|
57 |
+
session_scale_mapping_list = get_argmin_mat(timestamps_in_scales)
|
58 |
+
scale_mapping_argmat = [[] for _ in range(len(uniq_timestamps['scale_dict'].keys()))]
|
59 |
+
for scale_idx in range(len(session_scale_mapping_list)):
|
60 |
+
scale_mapping_argmat[scale_idx] = session_scale_mapping_list[scale_idx]
|
61 |
+
scale_mapping_argmat = torch.stack(scale_mapping_argmat)
|
62 |
+
return scale_mapping_argmat
|
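# A minimal usage sketch for get_scale_mapping_list(); the two-scale timestamp values
# below are hypothetical and only mirror the layout the function reads from
# uniq_timestamps['scale_dict'][scale]['time_stamps'] (numeric segment-timestamp pairs).
def _example_get_scale_mapping_list():
    toy_timestamps = {
        'scale_dict': {
            0: {'time_stamps': [[0.0, 1.5], [0.75, 2.25], [1.5, 3.0]]},  # coarse scale
            1: {'time_stamps': [[0.0, 0.5], [0.5, 1.0], [1.0, 1.5], [1.5, 2.0], [2.0, 2.5], [2.5, 3.0]]},  # base scale
        }
    }
    scale_map = get_scale_mapping_list(toy_timestamps)
    # scale_map[0][n] is the coarse-scale segment index closest to base-scale segment n.
    return scale_map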
63 |
+
|
64 |
+
|
65 |
+
def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_spks=None):
|
66 |
+
"""
|
67 |
+
Get RTTM lines containing speaker labels, start time and end time. target_spks contains two targeted
|
68 |
+
speaker indices for creating groundtruth label files. Only speakers in target_spks variable will be
|
69 |
+
included in the output lists.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
uniq_id (str):
|
73 |
+
Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file.
|
74 |
+
rttm_lines (list):
|
75 |
+
List containing RTTM lines in str format.
|
76 |
+
mapping_dict (dict):
|
77 |
+
Mapping between the estimated speakers and the speakers in the ground-truth annotation.
|
78 |
+
`mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode.
|
79 |
+
Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation.
|
80 |
+
Returns:
|
81 |
+
rttm_tup (tuple):
|
82 |
+
Tuple containing lists of start time, end time and speaker labels.
|
83 |
+
|
84 |
+
"""
|
85 |
+
stt_list, end_list, speaker_list, pairwise_infer_spks = [], [], [], []
|
86 |
+
if target_spks:
|
87 |
+
inv_map = {v: k for k, v in mapping_dict.items()}
|
88 |
+
for spk_idx in target_spks:
|
89 |
+
spk_str = f'speaker_{spk_idx}'
|
90 |
+
if spk_str in inv_map:
|
91 |
+
pairwise_infer_spks.append(inv_map[spk_str])
|
92 |
+
for rttm_line in rttm_lines:
|
93 |
+
start, end, speaker = convert_rttm_line(rttm_line)
|
94 |
+
if target_spks is None or speaker in pairwise_infer_spks:
|
95 |
+
end_list.append(end)
|
96 |
+
stt_list.append(start)
|
97 |
+
speaker_list.append(speaker)
|
98 |
+
rttm_tup = (stt_list, end_list, speaker_list)
|
99 |
+
return rttm_tup
|
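# A minimal sketch of extract_seg_info_from_rttm() on two standard-format RTTM lines;
# the session name, timings and speaker labels are hypothetical.
def _example_extract_seg_info_from_rttm():
    rttm_lines = [
        "SPEAKER session0 1 0.00 1.50 <NA> <NA> speaker_0 <NA> <NA>",
        "SPEAKER session0 1 1.20 2.30 <NA> <NA> speaker_1 <NA> <NA>",
    ]
    # Without target_spks and mapping_dict, all speakers in the RTTM are kept.
    stt_list, end_list, speaker_list = extract_seg_info_from_rttm("session0", rttm_lines)
    return stt_list, end_list, speaker_list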
100 |
+
|
101 |
+
|
102 |
+
def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, target_spks, min_spks=2):
|
103 |
+
"""
|
104 |
+
Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM.
|
105 |
+
The unit-length is the frame shift length of the acoustic feature. The feature-level annotations
|
106 |
+
`fr_level_target` will later be converted to base-segment-level diarization labels.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
rttm_timestamps (list):
|
110 |
+
List containing start and end time for each speaker segment label.
|
111 |
+
stt_list, end_list and speaker_list are contained.
|
112 |
+
frame_per_sec (int):
|
113 |
+
Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module.
|
114 |
+
target_spks (tuple):
|
115 |
+
Speaker indices that are generated from combinations. If there are only one or two speakers,
|
116 |
+
only a single target_spks tuple is generated.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
fr_level_target (torch.tensor):
|
120 |
+
Tensor containing label for each feature level frame.
|
121 |
+
"""
|
122 |
+
stt_list, end_list, speaker_list = rttm_timestamps
|
123 |
+
if len(speaker_list) == 0:
|
124 |
+
return None
|
125 |
+
else:
|
126 |
+
sorted_speakers = sorted(list(set(speaker_list)))
|
127 |
+
total_fr_len = int(max(end_list) * (10 ** round_digits))
|
128 |
+
spk_num = max(len(sorted_speakers), min_spks)
|
129 |
+
speaker_mapping_dict = {rttm_key: x_int for x_int, rttm_key in enumerate(sorted_speakers)}
|
130 |
+
fr_level_target = torch.zeros(total_fr_len, spk_num)
|
131 |
+
|
132 |
+
# If RTTM is not provided, then there is no speaker mapping dict in target_spks.
|
133 |
+
# Thus, return a zero-filled tensor as a placeholder.
|
134 |
+
for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)):
|
135 |
+
stt, end = round(stt, round_digits), round(end, round_digits)
|
136 |
+
spk = speaker_mapping_dict[spk_rttm_key]
|
137 |
+
stt_fr, end_fr = int(round(stt, 2) * frame_per_sec), int(round(end, round_digits) * frame_per_sec)
|
138 |
+
fr_level_target[stt_fr:end_fr, spk] = 1
|
139 |
+
return fr_level_target
|
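# A minimal sketch of assign_frame_level_spk_vector() on hand-made timestamp lists;
# frame_per_sec=100 corresponds to a 10 ms window stride, and all values are hypothetical.
def _example_assign_frame_level_spk_vector():
    rttm_timestamps = ([0.0, 1.2], [1.5, 2.3], ['speaker_0', 'speaker_1'])  # (stt_list, end_list, speaker_list)
    fr_level_target = assign_frame_level_spk_vector(
        rttm_timestamps, round_digits=2, frame_per_sec=100, target_spks=(0, 1)
    )
    # fr_level_target has shape (int(max(end_list) * 100), 2): one binary row per 10 ms frame.
    return fr_level_target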
140 |
+
|
141 |
+
|
142 |
+
class _AudioMSDDTrainDataset(Dataset):
|
143 |
+
"""
|
144 |
+
Dataset class that loads a json file containing paths to audio files,
|
145 |
+
RTTM files and number of speakers. This Dataset class is designed for
|
146 |
+
training or fine-tuning speaker embedding extractor and diarization decoder
|
147 |
+
at the same time.
|
148 |
+
|
149 |
+
Example:
|
150 |
+
{"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
|
151 |
+
"rttm_filepath": "/path/to/diar_label_0.rttm}
|
152 |
+
...
|
153 |
+
{"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
|
154 |
+
"rttm_filepath": "/path/to/diar_label_n.rttm}
|
155 |
+
|
156 |
+
Args:
|
157 |
+
manifest_filepath (str):
|
158 |
+
Path to input manifest json files.
|
159 |
+
multiscale_args_dict (dict):
|
160 |
+
Dictionary containing the parameters for multiscale segmentation and clustering.
|
161 |
+
emb_dir (str):
|
162 |
+
Path to a temporary folder where segmentation information for embedding extraction is saved.
|
163 |
+
soft_label_thres (float):
|
164 |
+
Threshold that determines the label of each segment based on RTTM file information.
|
165 |
+
featurizer:
|
166 |
+
Featurizer instance for generating features from the raw waveform.
|
167 |
+
window_stride (float):
|
168 |
+
Window stride for acoustic feature. This value is used for calculating the number of feature-level frames.
|
169 |
+
emb_batch_size (int):
|
170 |
+
Number of embedding vectors that are trained with attached computational graphs.
|
171 |
+
pairwise_infer (bool):
|
172 |
+
This variable should be True if dataloader is created for an inference task.
|
173 |
+
random_flip (bool):
|
174 |
+
If True, the two speaker labels and the corresponding input signals are randomly flipped every epoch during training.
|
175 |
+
"""
|
176 |
+
|
177 |
+
@property
|
178 |
+
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
179 |
+
"""Returns definitions of module output ports."""
|
180 |
+
output_types = {
|
181 |
+
"features": NeuralType(('B', 'T'), AudioSignal()),
|
182 |
+
"feature_length": NeuralType(('B'), LengthsType()),
|
183 |
+
"ms_seg_timestamps": NeuralType(('B', 'C', 'T', 'D'), LengthsType()),
|
184 |
+
"ms_seg_counts": NeuralType(('B', 'C'), LengthsType()),
|
185 |
+
"clus_label_index": NeuralType(('B', 'T'), LengthsType()),
|
186 |
+
"scale_mapping": NeuralType(('B', 'C', 'T'), LengthsType()),
|
187 |
+
"targets": NeuralType(('B', 'T', 'C'), ProbsType()),
|
188 |
+
}
|
189 |
+
|
190 |
+
return output_types
|
191 |
+
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
*,
|
195 |
+
manifest_filepath: str,
|
196 |
+
multiscale_args_dict: str,
|
197 |
+
emb_dir: str,
|
198 |
+
soft_label_thres: float,
|
199 |
+
featurizer,
|
200 |
+
window_stride,
|
201 |
+
emb_batch_size,
|
202 |
+
pairwise_infer: bool,
|
203 |
+
random_flip: bool = True,
|
204 |
+
global_rank: int = 0,
|
205 |
+
):
|
206 |
+
super().__init__()
|
207 |
+
self.collection = DiarizationSpeechLabel(
|
208 |
+
manifests_files=manifest_filepath.split(','),
|
209 |
+
emb_dict=None,
|
210 |
+
clus_label_dict=None,
|
211 |
+
pairwise_infer=pairwise_infer,
|
212 |
+
)
|
213 |
+
self.featurizer = featurizer
|
214 |
+
self.multiscale_args_dict = multiscale_args_dict
|
215 |
+
self.emb_dir = emb_dir
|
216 |
+
self.round_digits = 2
|
217 |
+
self.decim = 10 ** self.round_digits
|
218 |
+
self.soft_label_thres = soft_label_thres
|
219 |
+
self.pairwise_infer = pairwise_infer
|
220 |
+
self.max_spks = 2
|
221 |
+
self.frame_per_sec = int(1 / window_stride)
|
222 |
+
self.emb_batch_size = emb_batch_size
|
223 |
+
self.random_flip = random_flip
|
224 |
+
self.global_rank = global_rank
|
225 |
+
self.manifest_filepath = manifest_filepath
|
226 |
+
self.multiscale_timestamp_dict = prepare_split_data(
|
227 |
+
self.manifest_filepath, self.emb_dir, self.multiscale_args_dict, self.global_rank,
|
228 |
+
)
|
229 |
+
|
230 |
+
def __len__(self):
|
231 |
+
return len(self.collection)
|
232 |
+
|
233 |
+
def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label):
|
234 |
+
"""
|
235 |
+
Assign the generated speaker labels from the base scale (the finest scale) to the longer scales.
|
236 |
+
This process is needed to get the cluster labels for each scale. The cluster labels are needed to
|
237 |
+
calculate the cluster-average speaker embedding for each scale.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
uniq_id (str):
|
241 |
+
Unique sample ID for training.
|
242 |
+
base_scale_clus_label (torch.tensor):
|
243 |
+
Tensor variable containing the speaker labels for the base-scale segments.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
per_scale_clus_label (torch.tensor):
|
247 |
+
Tensor variable containing the speaker labels for each segment in each scale.
|
248 |
+
Note that the total length of the speaker label sequence differs over scale since
|
249 |
+
each scale has a different number of segments for the same session.
|
250 |
+
|
251 |
+
scale_mapping (torch.tensor):
|
252 |
+
Matrix containing the segment indices of each scale. scale_mapping is necessary for reshaping the
|
253 |
+
multiscale embeddings to form an input matrix for the MSDD model.
|
254 |
+
"""
|
255 |
+
per_scale_clus_label = []
|
256 |
+
self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict'])
|
257 |
+
uniq_scale_mapping = get_scale_mapping_list(self.multiscale_timestamp_dict[uniq_id])
|
258 |
+
for scale_index in range(self.scale_n):
|
259 |
+
new_clus_label = []
|
260 |
+
scale_seq_len = len(self.multiscale_timestamp_dict[uniq_id]["scale_dict"][scale_index]["time_stamps"])
|
261 |
+
for seg_idx in range(scale_seq_len):
|
262 |
+
if seg_idx in uniq_scale_mapping[scale_index]:
|
263 |
+
seg_clus_label = mode(base_scale_clus_label[uniq_scale_mapping[scale_index] == seg_idx])
|
264 |
+
else:
|
265 |
+
seg_clus_label = 0 if len(new_clus_label) == 0 else new_clus_label[-1]
|
266 |
+
new_clus_label.append(seg_clus_label)
|
267 |
+
per_scale_clus_label.extend(new_clus_label)
|
268 |
+
per_scale_clus_label = torch.tensor(per_scale_clus_label)
|
269 |
+
return per_scale_clus_label, uniq_scale_mapping
|
270 |
+
|
271 |
+
def get_diar_target_labels(self, uniq_id, sample, fr_level_target):
|
272 |
+
"""
|
273 |
+
Convert frame-level diarization target variable into segment-level target variable. Since the granularity is reduced
|
274 |
+
from frame level (10ms) to segment level (100ms~500ms), we need a threshold value, `soft_label_thres`, which determines
|
275 |
+
the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
uniq_id (str):
|
279 |
+
Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file.
|
280 |
+
sample:
|
281 |
+
`DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath.
|
282 |
+
fr_level_target (torch.tensor):
|
283 |
+
Tensor containing label for each feature-level frame.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
seg_target (torch.tensor):
|
287 |
+
Tensor containing binary speaker labels for base-scale segments.
|
288 |
+
base_clus_label (torch.tensor):
|
289 |
+
Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment.
|
290 |
+
-1 means that there is no corresponding speaker in the target_spks tuple.
|
291 |
+
"""
|
292 |
+
seg_target_list, base_clus_label = [], []
|
293 |
+
self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict'])
|
294 |
+
subseg_time_stamp_list = self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"]
|
295 |
+
for (seg_stt, seg_end) in subseg_time_stamp_list:
|
296 |
+
seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec)
|
297 |
+
soft_label_vec_sess = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / (
|
298 |
+
seg_end_fr - seg_stt_fr
|
299 |
+
)
|
300 |
+
label_int_sess = torch.argmax(soft_label_vec_sess)
|
301 |
+
soft_label_vec = soft_label_vec_sess.unsqueeze(0)[:, sample.target_spks].squeeze()
|
302 |
+
if label_int_sess in sample.target_spks and torch.sum(soft_label_vec_sess) > 0:
|
303 |
+
label_int = sample.target_spks.index(label_int_sess)
|
304 |
+
else:
|
305 |
+
label_int = -1
|
306 |
+
label_vec = (soft_label_vec > self.soft_label_thres).float()
|
307 |
+
seg_target_list.append(label_vec.detach())
|
308 |
+
base_clus_label.append(label_int)
|
309 |
+
seg_target = torch.stack(seg_target_list)
|
310 |
+
base_clus_label = torch.tensor(base_clus_label)
|
311 |
+
return seg_target, base_clus_label
|
312 |
+
|
313 |
+
def parse_rttm_for_ms_targets(self, sample):
|
314 |
+
"""
|
315 |
+
Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file.
|
316 |
+
This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level
|
317 |
+
diarization label in a matrix form.
|
318 |
+
|
319 |
+
Example of seg_target:
|
320 |
+
[[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]]
|
321 |
+
|
322 |
+
Args:
|
323 |
+
sample:
|
324 |
+
`DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath.
|
325 |
+
target_spks (tuple):
|
326 |
+
Speaker indices that are generated from combinations. If there are only one or two speakers,
|
327 |
+
only a single target_spks tuple is generated.
|
328 |
+
|
329 |
+
Returns:
|
330 |
+
clus_label_index (torch.tensor):
|
331 |
+
Groundtruth clustering label (cluster index for each segment) from RTTM files for training purpose.
|
332 |
+
seg_target (torch.tensor):
|
333 |
+
Tensor variable containing hard-labels of speaker activity in each base-scale segment.
|
334 |
+
scale_mapping (torch.tensor):
|
335 |
+
Matrix containing the segment indices of each scale. scale_mapping is necessary for reshaping the
|
336 |
+
multiscale embeddings to form an input matrix for the MSDD model.
|
337 |
+
|
338 |
+
"""
|
339 |
+
rttm_lines = open(sample.rttm_file).readlines()
|
340 |
+
uniq_id = self.get_uniq_id_with_range(sample)
|
341 |
+
rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines)
|
342 |
+
fr_level_target = assign_frame_level_spk_vector(
|
343 |
+
rttm_timestamps, self.round_digits, self.frame_per_sec, target_spks=sample.target_spks
|
344 |
+
)
|
345 |
+
seg_target, base_clus_label = self.get_diar_target_labels(uniq_id, sample, fr_level_target)
|
346 |
+
clus_label_index, scale_mapping = self.assign_labels_to_longer_segs(uniq_id, base_clus_label)
|
347 |
+
return clus_label_index, seg_target, scale_mapping
|
348 |
+
|
349 |
+
def get_uniq_id_with_range(self, sample, deci=3):
|
350 |
+
"""
|
351 |
+
Generate a unique training sample ID from the unique file ID, offset and duration. The ID with appended
|
352 |
+
start and end times is required for identifying the sample, since multiple short audio samples are generated from a single
|
353 |
+
audio file. The start and end times of the audio stream use millisecond units if `deci=3`.
|
354 |
+
|
355 |
+
Args:
|
356 |
+
sample:
|
357 |
+
`DiarizationSpeechLabel` instance from collections.
|
358 |
+
|
359 |
+
Returns:
|
360 |
+
uniq_id (str):
|
361 |
+
Unique sample ID which includes start and end time of the audio stream.
|
362 |
+
Example: abc1001_3122_6458
|
363 |
+
|
364 |
+
"""
|
365 |
+
bare_uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0]
|
366 |
+
offset = str(int(round(sample.offset, deci) * pow(10, deci)))
|
367 |
+
endtime = str(int(round(sample.offset + sample.duration, deci) * pow(10, deci)))
|
368 |
+
uniq_id = f"{bare_uniq_id}_{offset}_{endtime}"
|
369 |
+
return uniq_id
|
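# A minimal sketch of the ID format built above, using float-safe hypothetical values:
# offset 3.25 s and duration 3.25 s on a file "abc1001" with deci=3 (millisecond units).
def _example_uniq_id_with_range(deci=3):
    bare_uniq_id = "abc1001"        # hypothetical base file ID
    offset, duration = 3.25, 3.25   # hypothetical offset and duration in seconds
    start = str(int(round(offset, deci) * pow(10, deci)))
    end = str(int(round(offset + duration, deci) * pow(10, deci)))
    return f"{bare_uniq_id}_{start}_{end}"  # -> "abc1001_3250_6500"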
370 |
+
|
371 |
+
def get_ms_seg_timestamps(self, sample):
|
372 |
+
"""
|
373 |
+
Get start and end time of segments in each scale.
|
374 |
+
|
375 |
+
Args:
|
376 |
+
sample:
|
377 |
+
`DiarizationSpeechLabel` instance from preprocessing.collections
|
378 |
+
Returns:
|
379 |
+
ms_seg_timestamps (torch.tensor):
|
380 |
+
Tensor containing Multiscale segment timestamps.
|
381 |
+
ms_seg_counts (torch.tensor):
|
382 |
+
Number of segments for each scale. This information is used for reshaping embedding batch
|
383 |
+
during forward propagation.
|
384 |
+
"""
|
385 |
+
uniq_id = self.get_uniq_id_with_range(sample)
|
386 |
+
ms_seg_timestamps_list = []
|
387 |
+
max_seq_len = len(self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"])
|
388 |
+
ms_seg_counts = [0 for _ in range(self.scale_n)]
|
389 |
+
for scale_idx in range(self.scale_n):
|
390 |
+
scale_ts_list = []
|
391 |
+
for k, (seg_stt, seg_end) in enumerate(
|
392 |
+
self.multiscale_timestamp_dict[uniq_id]["scale_dict"][scale_idx]["time_stamps"]
|
393 |
+
):
|
394 |
+
stt, end = (
|
395 |
+
int((seg_stt - sample.offset) * self.frame_per_sec),
|
396 |
+
int((seg_end - sample.offset) * self.frame_per_sec),
|
397 |
+
)
|
398 |
+
scale_ts_list.append(torch.tensor([stt, end]).detach())
|
399 |
+
ms_seg_counts[scale_idx] = len(
|
400 |
+
self.multiscale_timestamp_dict[uniq_id]["scale_dict"][scale_idx]["time_stamps"]
|
401 |
+
)
|
402 |
+
scale_ts = torch.stack(scale_ts_list)
|
403 |
+
scale_ts_padded = torch.cat([scale_ts, torch.zeros(max_seq_len - len(scale_ts_list), 2)], dim=0)
|
404 |
+
ms_seg_timestamps_list.append(scale_ts_padded.detach())
|
405 |
+
ms_seg_timestamps = torch.stack(ms_seg_timestamps_list)
|
406 |
+
ms_seg_counts = torch.tensor(ms_seg_counts)
|
407 |
+
return ms_seg_timestamps, ms_seg_counts
|
408 |
+
|
409 |
+
def __getitem__(self, index):
|
410 |
+
sample = self.collection[index]
|
411 |
+
if sample.offset is None:
|
412 |
+
sample.offset = 0
|
413 |
+
clus_label_index, targets, scale_mapping = self.parse_rttm_for_ms_targets(sample)
|
414 |
+
features = self.featurizer.process(sample.audio_file, offset=sample.offset, duration=sample.duration)
|
415 |
+
feature_length = torch.tensor(features.shape[0]).long()
|
416 |
+
ms_seg_timestamps, ms_seg_counts = self.get_ms_seg_timestamps(sample)
|
417 |
+
if self.random_flip:
|
418 |
+
torch.manual_seed(index)
|
419 |
+
flip = torch.cat([torch.randperm(self.max_spks), torch.tensor(-1).unsqueeze(0)])
|
420 |
+
clus_label_index, targets = flip[clus_label_index], targets[:, flip[: self.max_spks]]
|
421 |
+
return features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets
|
422 |
+
|
423 |
+
|
424 |
+
class _AudioMSDDInferDataset(Dataset):
|
425 |
+
"""
|
426 |
+
Dataset class that loads a json file containing paths to audio files,
|
427 |
+
RTTM files and number of speakers. This Dataset class is built for diarization inference and
|
428 |
+
evaluation. Speaker embedding sequences, segment timestamps, and cluster-average speaker embeddings
|
429 |
+
are loaded from memory and fed into the dataloader.
|
430 |
+
|
431 |
+
Example:
|
432 |
+
{"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
|
433 |
+
"rttm_filepath": "/path/to/diar_label_0.rttm}
|
434 |
+
...
|
435 |
+
{"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
|
436 |
+
"rttm_filepath": "/path/to/diar_label_n.rttm}
|
437 |
+
|
438 |
+
Args:
|
439 |
+
manifest_filepath (str):
|
440 |
+
Path to input manifest json files.
|
441 |
+
emb_dict (dict):
|
442 |
+
Dictionary containing cluster-average embeddings and speaker mapping information.
|
443 |
+
emb_seq (dict):
|
444 |
+
Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps.
|
445 |
+
clus_label_dict (dict):
|
446 |
+
Subsegment-level (from base-scale) speaker labels from clustering results.
|
447 |
+
soft_label_thres (float):
|
448 |
+
A threshold that determines the label of each segment based on RTTM file information.
|
449 |
+
featurizer:
|
450 |
+
Featurizer instance for generating features from raw waveform.
|
451 |
+
seq_eval_mode (bool):
|
452 |
+
If True, F1 score will be calculated for each speaker pair during inference mode.
|
453 |
+
window_stride (float):
|
454 |
+
Window stride for acoustic feature. This value is used for calculating the number of feature-level frames.
|
455 |
+
use_single_scale_clus (bool):
|
456 |
+
Use only one scale for clustering instead of using multiple scales of embeddings for clustering.
|
457 |
+
pairwise_infer (bool):
|
458 |
+
This variable should be True if dataloader is created for an inference task.
|
459 |
+
"""
|
460 |
+
|
461 |
+
@property
|
462 |
+
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
463 |
+
"""Returns definitions of module output ports."""
|
464 |
+
output_types = OrderedDict(
|
465 |
+
{
|
466 |
+
"ms_emb_seq": NeuralType(('B', 'T', 'C', 'D'), SpectrogramType()),
|
467 |
+
"length": NeuralType(tuple('B'), LengthsType()),
|
468 |
+
"ms_avg_embs": NeuralType(('B', 'C', 'D', 'C'), EncodedRepresentation()),
|
469 |
+
"targets": NeuralType(('B', 'T', 'C'), ProbsType()),
|
470 |
+
}
|
471 |
+
)
|
472 |
+
return output_types
|
473 |
+
|
474 |
+
def __init__(
|
475 |
+
self,
|
476 |
+
*,
|
477 |
+
manifest_filepath: str,
|
478 |
+
emb_dict: Dict,
|
479 |
+
emb_seq: Dict,
|
480 |
+
clus_label_dict: Dict,
|
481 |
+
soft_label_thres: float,
|
482 |
+
seq_eval_mode: bool,
|
483 |
+
window_stride: float,
|
484 |
+
use_single_scale_clus: bool,
|
485 |
+
pairwise_infer: bool,
|
486 |
+
):
|
487 |
+
super().__init__()
|
488 |
+
self.collection = DiarizationSpeechLabel(
|
489 |
+
manifests_files=manifest_filepath.split(','),
|
490 |
+
emb_dict=emb_dict,
|
491 |
+
clus_label_dict=clus_label_dict,
|
492 |
+
seq_eval_mode=seq_eval_mode,
|
493 |
+
pairwise_infer=pairwise_infer,
|
494 |
+
)
|
495 |
+
self.emb_dict = emb_dict
|
496 |
+
self.emb_seq = emb_seq
|
497 |
+
self.clus_label_dict = clus_label_dict
|
498 |
+
self.round_digits = 2
|
499 |
+
self.decim = 10 ** self.round_digits
|
500 |
+
self.frame_per_sec = int(1 / window_stride)
|
501 |
+
self.soft_label_thres = soft_label_thres
|
502 |
+
self.pairwise_infer = pairwise_infer
|
503 |
+
self.max_spks = 2
|
504 |
+
self.use_single_scale_clus = use_single_scale_clus
|
505 |
+
self.seq_eval_mode = seq_eval_mode
|
506 |
+
|
507 |
+
def __len__(self):
|
508 |
+
return len(self.collection)
|
509 |
+
|
510 |
+
def parse_rttm_multiscale(self, sample):
|
511 |
+
"""
|
512 |
+
Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file.
|
513 |
+
This function is only used when ``self.seq_eval_mode=True`` and RTTM files are provided. This function converts
|
514 |
+
(start, end, speaker_id) format into base-scale (the finest scale) segment level diarization label in a matrix
|
515 |
+
form to create target matrix.
|
516 |
+
|
517 |
+
Args:
|
518 |
+
sample:
|
519 |
+
DiarizationSpeechLabel instance containing sample information such as audio filepath and RTTM filepath.
|
520 |
+
target_spks (tuple):
|
521 |
+
Two indices of targeted speakers for evaluation.
|
522 |
+
Example of target_spks: (2, 3)
|
523 |
+
Returns:
|
524 |
+
seg_target (torch.tensor):
|
525 |
+
Tensor variable containing hard-labels of speaker activity in each base-scale segment.
|
526 |
+
"""
|
527 |
+
if sample.rttm_file is None:
|
528 |
+
raise ValueError(f"RTTM file is not provided for this sample {sample}")
|
529 |
+
rttm_lines = open(sample.rttm_file).readlines()
|
530 |
+
uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0]
|
531 |
+
mapping_dict = self.emb_dict[max(self.emb_dict.keys())][uniq_id]['mapping']
|
532 |
+
rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict, sample.target_spks)
|
533 |
+
fr_level_target = assign_frame_level_spk_vector(
|
534 |
+
rttm_timestamps, self.round_digits, self.frame_per_sec, sample.target_spks
|
535 |
+
)
|
536 |
+
seg_target = self.get_diar_target_labels_from_fr_target(uniq_id, fr_level_target)
|
537 |
+
return seg_target
|
538 |
+
|
539 |
+
def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target):
|
540 |
+
"""
|
541 |
+
Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level
|
542 |
+
speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate
|
543 |
+
ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared with `soft_label_thres`
|
544 |
+
to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has
|
545 |
+
a dimension of (number of base-scale segments x 2).
|
546 |
+
|
547 |
+
Example of seg_target:
|
548 |
+
[[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]]
|
549 |
+
|
550 |
+
Args:
|
551 |
+
uniq_id (str):
|
552 |
+
Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file.
|
553 |
+
fr_level_target (torch.tensor):
|
554 |
+
frame-level binary speaker annotation (1: exists, 0: does not exist) generated from the RTTM file.
|
555 |
+
|
556 |
+
Returns:
|
557 |
+
seg_target (torch.tensor):
|
558 |
+
Tensor variable containing binary hard-labels of speaker activity in each base-scale segment.
|
559 |
+
|
560 |
+
"""
|
561 |
+
if fr_level_target is None:
|
562 |
+
return None
|
563 |
+
else:
|
564 |
+
seg_target_list = []
|
565 |
+
for (seg_stt, seg_end, label_int) in self.clus_label_dict[uniq_id]:
|
566 |
+
seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec)
|
567 |
+
soft_label_vec = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / (
|
568 |
+
seg_end_fr - seg_stt_fr
|
569 |
+
)
|
570 |
+
label_vec = (soft_label_vec > self.soft_label_thres).int()
|
571 |
+
seg_target_list.append(label_vec)
|
572 |
+
seg_target = torch.stack(seg_target_list)
|
573 |
+
return seg_target
|
574 |
+
|
575 |
+
def __getitem__(self, index):
|
576 |
+
sample = self.collection[index]
|
577 |
+
if sample.offset is None:
|
578 |
+
sample.offset = 0
|
579 |
+
|
580 |
+
uniq_id = os.path.splitext(os.path.basename(sample.audio_file))[0]
|
581 |
+
scale_n = len(self.emb_dict.keys())
|
582 |
+
_avg_embs = torch.stack([self.emb_dict[scale_index][uniq_id]['avg_embs'] for scale_index in range(scale_n)])
|
583 |
+
|
584 |
+
if self.pairwise_infer:
|
585 |
+
avg_embs = _avg_embs[:, :, self.collection[index].target_spks]
|
586 |
+
else:
|
587 |
+
avg_embs = _avg_embs
|
588 |
+
|
589 |
+
if avg_embs.shape[2] > self.max_spks:
|
590 |
+
raise ValueError(
|
591 |
+
f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {self.max_spks}"
|
592 |
+
)
|
593 |
+
|
594 |
+
feats = []
|
595 |
+
for scale_index in range(scale_n):
|
596 |
+
repeat_mat = self.emb_seq["session_scale_mapping"][uniq_id][scale_index]
|
597 |
+
feats.append(self.emb_seq[scale_index][uniq_id][repeat_mat, :])
|
598 |
+
feats_out = torch.stack(feats).permute(1, 0, 2)
|
599 |
+
feats_len = feats_out.shape[0]
|
600 |
+
|
601 |
+
if self.seq_eval_mode:
|
602 |
+
targets = self.parse_rttm_multiscale(sample)
|
603 |
+
else:
|
604 |
+
targets = torch.zeros(feats_len, 2).float()
|
605 |
+
|
606 |
+
return feats_out, feats_len, targets, avg_embs
|
607 |
+
|
608 |
+
|
609 |
+
def _msdd_train_collate_fn(self, batch):
|
610 |
+
"""
|
611 |
+
Collate batch of variables that are needed for raw waveform to diarization label training.
|
612 |
+
The following variables are included in training/validation batch:
|
613 |
+
|
614 |
+
Args:
|
615 |
+
batch (tuple):
|
616 |
+
Batch tuple containing the variables for the diarization training.
|
617 |
+
Returns:
|
618 |
+
features (torch.tensor):
|
619 |
+
Raw waveform samples (time series) loaded from the audio_filepath in the input manifest file.
|
620 |
+
feature_length (torch.tensor):
|
621 |
+
A list of lengths of the raw waveform samples.
|
622 |
+
ms_seg_timestamps (torch.tensor):
|
623 |
+
Matrix containing the start time and end time (timestamps) for each segment and each scale.
|
624 |
+
ms_seg_timestamps is needed for extracting acoustic features from raw waveforms.
|
625 |
+
ms_seg_counts (torch.tensor):
|
626 |
+
Matrix containing the number of segments for each scale. ms_seg_counts is necessary for reshaping
|
627 |
+
the input matrix for the MSDD model.
|
628 |
+
clus_label_index (torch.tensor):
|
629 |
+
Groundtruth Clustering label (cluster index for each segment) from RTTM files for training purpose.
|
630 |
+
clus_label_index is necessary for calculating cluster-average embedding.
|
631 |
+
scale_mapping (torch.tensor):
|
632 |
+
Matrix containing the segment indices of each scale. scale_mapping is necessary for reshaping the
|
633 |
+
multiscale embeddings to form an input matrix for the MSDD model.
|
634 |
+
targets (torch.tensor):
|
635 |
+
Groundtruth Speaker label for the given input embedding sequence.
|
636 |
+
"""
|
637 |
+
packed_batch = list(zip(*batch))
|
638 |
+
features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = packed_batch
|
639 |
+
features_list, feature_length_list = [], []
|
640 |
+
ms_seg_timestamps_list, ms_seg_counts_list, scale_clus_label_list, scale_mapping_list, targets_list = (
|
641 |
+
[],
|
642 |
+
[],
|
643 |
+
[],
|
644 |
+
[],
|
645 |
+
[],
|
646 |
+
)
|
647 |
+
|
648 |
+
max_raw_feat_len = max([x.shape[0] for x in features])
|
649 |
+
max_target_len = max([x.shape[0] for x in targets])
|
650 |
+
max_total_seg_len = max([x.shape[0] for x in clus_label_index])
|
651 |
+
|
652 |
+
for feat, feat_len, ms_seg_ts, ms_seg_ct, scale_clus, scl_map, tgt in batch:
|
653 |
+
seq_len = tgt.shape[0]
|
654 |
+
pad_feat = (0, max_raw_feat_len - feat_len)
|
655 |
+
pad_tgt = (0, 0, 0, max_target_len - seq_len)
|
656 |
+
pad_sm = (0, max_target_len - seq_len)
|
657 |
+
pad_ts = (0, 0, 0, max_target_len - seq_len)
|
658 |
+
pad_sc = (0, max_total_seg_len - scale_clus.shape[0])
|
659 |
+
padded_feat = torch.nn.functional.pad(feat, pad_feat)
|
660 |
+
padded_tgt = torch.nn.functional.pad(tgt, pad_tgt)
|
661 |
+
padded_sm = torch.nn.functional.pad(scl_map, pad_sm)
|
662 |
+
padded_ms_seg_ts = torch.nn.functional.pad(ms_seg_ts, pad_ts)
|
663 |
+
padded_scale_clus = torch.nn.functional.pad(scale_clus, pad_sc)
|
664 |
+
|
665 |
+
features_list.append(padded_feat)
|
666 |
+
feature_length_list.append(feat_len.clone().detach())
|
667 |
+
ms_seg_timestamps_list.append(padded_ms_seg_ts)
|
668 |
+
ms_seg_counts_list.append(ms_seg_ct.clone().detach())
|
669 |
+
scale_clus_label_list.append(padded_scale_clus)
|
670 |
+
scale_mapping_list.append(padded_sm)
|
671 |
+
targets_list.append(padded_tgt)
|
672 |
+
|
673 |
+
features = torch.stack(features_list)
|
674 |
+
feature_length = torch.stack(feature_length_list)
|
675 |
+
ms_seg_timestamps = torch.stack(ms_seg_timestamps_list)
|
676 |
+
clus_label_index = torch.stack(scale_clus_label_list)
|
677 |
+
ms_seg_counts = torch.stack(ms_seg_counts_list)
|
678 |
+
scale_mapping = torch.stack(scale_mapping_list)
|
679 |
+
targets = torch.stack(targets_list)
|
680 |
+
return features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets
|
681 |
+
|
682 |
+
|
683 |
+
def _msdd_infer_collate_fn(self, batch):
|
684 |
+
"""
|
685 |
+
Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings.
|
686 |
+
|
687 |
+
Args:
|
688 |
+
batch (tuple):
|
689 |
+
Batch tuple containing feats, feats_len, targets and ms_avg_embs.
|
690 |
+
Returns:
|
691 |
+
feats (torch.tensor):
|
692 |
+
Collated speaker embedding with unified length.
|
693 |
+
feats_len (torch.tensor):
|
694 |
+
The actual length of each embedding sequence without zero padding.
|
695 |
+
targets (torch.tensor):
|
696 |
+
Groundtruth Speaker label for the given input embedding sequence.
|
697 |
+
ms_avg_embs (torch.tensor):
|
698 |
+
Cluster-average speaker embedding vectors.
|
699 |
+
"""
|
700 |
+
|
701 |
+
packed_batch = list(zip(*batch))
|
702 |
+
feats, feats_len, targets, ms_avg_embs = packed_batch
|
703 |
+
feats_list, flen_list, targets_list, ms_avg_embs_list = [], [], [], []
|
704 |
+
max_audio_len = max(feats_len)
|
705 |
+
max_target_len = max([x.shape[0] for x in targets])
|
706 |
+
|
707 |
+
for feature, feat_len, target, ivector in batch:
|
708 |
+
flen_list.append(feat_len)
|
709 |
+
ms_avg_embs_list.append(ivector)
|
710 |
+
if feat_len < max_audio_len:
|
711 |
+
pad_a = (0, 0, 0, 0, 0, max_audio_len - feat_len)
|
712 |
+
pad_t = (0, 0, 0, max_target_len - target.shape[0])
|
713 |
+
padded_feature = torch.nn.functional.pad(feature, pad_a)
|
714 |
+
padded_target = torch.nn.functional.pad(target, pad_t)
|
715 |
+
feats_list.append(padded_feature)
|
716 |
+
targets_list.append(padded_target)
|
717 |
+
else:
|
718 |
+
targets_list.append(target.clone().detach())
|
719 |
+
feats_list.append(feature.clone().detach())
|
720 |
+
|
721 |
+
feats = torch.stack(feats_list)
|
722 |
+
feats_len = torch.tensor(flen_list)
|
723 |
+
targets = torch.stack(targets_list)
|
724 |
+
ms_avg_embs = torch.stack(ms_avg_embs_list)
|
725 |
+
return feats, feats_len, targets, ms_avg_embs
|
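# A minimal sketch of the inference collate function on a hypothetical batch of two
# sessions (scale_n=3, embedding dim 192, 2 target speakers); `self` is unused by the
# function body, so None can stand in for it here.
def _example_msdd_infer_collate():
    batch = [
        (torch.randn(6, 3, 192), 6, torch.zeros(6, 2), torch.randn(3, 192, 2)),
        (torch.randn(4, 3, 192), 4, torch.zeros(4, 2), torch.randn(3, 192, 2)),
    ]
    feats, feats_len, targets, ms_avg_embs = _msdd_infer_collate_fn(None, batch)
    # The shorter session is zero-padded along the segment axis: feats -> (2, 6, 3, 192).
    return feats.shape, feats_len, targets.shape, ms_avg_embs.shape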
726 |
+
|
727 |
+
|
728 |
+
class AudioToSpeechMSDDTrainDataset(_AudioMSDDTrainDataset):
|
729 |
+
"""
|
730 |
+
Dataset class that loads a json file containing paths to audio files,
|
731 |
+
rttm files and number of speakers. This Dataset class is designed for
|
732 |
+
training or fine-tuning speaker embedding extractor and diarization decoder
|
733 |
+
at the same time.
|
734 |
+
|
735 |
+
Example:
|
736 |
+
{"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
|
737 |
+
"rttm_filepath": "/path/to/diar_label_0.rttm}
|
738 |
+
...
|
739 |
+
{"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
|
740 |
+
"rttm_filepath": "/path/to/diar_label_n.rttm}
|
741 |
+
|
742 |
+
Args:
|
743 |
+
manifest_filepath (str):
|
744 |
+
Path to input manifest json files.
|
745 |
+
multiscale_args_dict (dict):
|
746 |
+
Dictionary containing the parameters for multiscale segmentation and clustering.
|
747 |
+
emb_dir (str):
|
748 |
+
Path to a temporary folder where segmentation information for embedding extraction is saved.
|
749 |
+
soft_label_thres (float):
|
750 |
+
A threshold that determines the label of each segment based on RTTM file information.
|
751 |
+
featurizer:
|
752 |
+
Featurizer instance for generating features from the raw waveform.
|
753 |
+
window_stride (float):
|
754 |
+
Window stride for acoustic feature. This value is used for calculating the number of feature-level frames.
|
755 |
+
emb_batch_size (int):
|
756 |
+
Number of embedding vectors that are trained with attached computational graphs.
|
757 |
+
pairwise_infer (bool):
|
758 |
+
This variable should be True if dataloader is created for an inference task.
|
759 |
+
"""
|
760 |
+
|
761 |
+
def __init__(
|
762 |
+
self,
|
763 |
+
*,
|
764 |
+
manifest_filepath: str,
|
765 |
+
multiscale_args_dict: Dict,
|
766 |
+
emb_dir: str,
|
767 |
+
soft_label_thres: float,
|
768 |
+
featurizer,
|
769 |
+
window_stride,
|
770 |
+
emb_batch_size,
|
771 |
+
pairwise_infer: bool,
|
772 |
+
global_rank: int,
|
773 |
+
):
|
774 |
+
super().__init__(
|
775 |
+
manifest_filepath=manifest_filepath,
|
776 |
+
multiscale_args_dict=multiscale_args_dict,
|
777 |
+
emb_dir=emb_dir,
|
778 |
+
soft_label_thres=soft_label_thres,
|
779 |
+
featurizer=featurizer,
|
780 |
+
window_stride=window_stride,
|
781 |
+
emb_batch_size=emb_batch_size,
|
782 |
+
pairwise_infer=pairwise_infer,
|
783 |
+
global_rank=global_rank,
|
784 |
+
)
|
785 |
+
|
786 |
+
def msdd_train_collate_fn(self, batch):
|
787 |
+
return _msdd_train_collate_fn(self, batch)
|
788 |
+
|
789 |
+
|
790 |
+
class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset):
|
791 |
+
"""
|
792 |
+
Dataset class that loads a json file containing paths to audio files,
|
793 |
+
rttm files and number of speakers. The created labels are used for diarization inference.
|
794 |
+
|
795 |
+
Example:
|
796 |
+
{"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
|
797 |
+
"rttm_filepath": "/path/to/diar_label_0.rttm}
|
798 |
+
...
|
799 |
+
{"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
|
800 |
+
"rttm_filepath": "/path/to/diar_label_n.rttm}
|
801 |
+
|
802 |
+
Args:
|
803 |
+
manifest_filepath (str):
|
804 |
+
Path to input manifest json files.
|
805 |
+
emb_dict (dict):
|
806 |
+
Dictionary containing cluster-average embeddings and speaker mapping information.
|
807 |
+
emb_seq (dict):
|
808 |
+
Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps.
|
809 |
+
clus_label_dict (dict):
|
810 |
+
Subsegment-level (from base-scale) speaker labels from clustering results.
|
811 |
+
soft_label_thres (float):
|
812 |
+
Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps.
|
813 |
+
featurizer:
|
814 |
+
Featurizer instance for generating features from raw waveform.
|
815 |
+
use_single_scale_clus (bool):
|
816 |
+
Use only one scale for clustering instead of using multiple scales of embeddings for clustering.
|
817 |
+
seq_eval_mode (bool):
|
818 |
+
If True, F1 score will be calculated for each speaker pair during inference mode.
|
819 |
+
window_stride (float):
|
820 |
+
Window stride for acoustic feature. This value is used for calculating the number of feature-level frames.
|
821 |
+
pairwise_infer (bool):
|
822 |
+
If True, this Dataset class operates in inference mode. In inference mode, a set of speakers in the input audio
|
823 |
+
is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then
|
824 |
+
fed into the MSDD to merge the individual results.
|
825 |
+
"""
|
826 |
+
|
827 |
+
def __init__(
|
828 |
+
self,
|
829 |
+
*,
|
830 |
+
manifest_filepath: str,
|
831 |
+
emb_dict: Dict,
|
832 |
+
emb_seq: Dict,
|
833 |
+
clus_label_dict: Dict,
|
834 |
+
soft_label_thres: float,
|
835 |
+
use_single_scale_clus: bool,
|
836 |
+
seq_eval_mode: bool,
|
837 |
+
window_stride: float,
|
838 |
+
pairwise_infer: bool,
|
839 |
+
):
|
840 |
+
super().__init__(
|
841 |
+
manifest_filepath=manifest_filepath,
|
842 |
+
emb_dict=emb_dict,
|
843 |
+
emb_seq=emb_seq,
|
844 |
+
clus_label_dict=clus_label_dict,
|
845 |
+
soft_label_thres=soft_label_thres,
|
846 |
+
use_single_scale_clus=use_single_scale_clus,
|
847 |
+
window_stride=window_stride,
|
848 |
+
seq_eval_mode=seq_eval_mode,
|
849 |
+
pairwise_infer=pairwise_infer,
|
850 |
+
)
|
851 |
+
|
852 |
+
def msdd_infer_collate_fn(self, batch):
|
853 |
+
return _msdd_infer_collate_fn(self, batch)
|
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label.py
ADDED
@@ -0,0 +1,1294 @@
1 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import io
|
15 |
+
import os
|
16 |
+
from typing import Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import webdataset as wd
|
20 |
+
|
21 |
+
from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_sharded_filepaths
|
22 |
+
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
|
23 |
+
from nemo.collections.asr.parts.preprocessing.segment import available_formats as valid_sf_formats
|
24 |
+
from nemo.collections.common.parts.preprocessing import collections
|
25 |
+
from nemo.core.classes import Dataset, IterableDataset
|
26 |
+
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType, RegressionValuesType
|
27 |
+
from nemo.utils import logging
|
28 |
+
|
29 |
+
# List of valid file formats (prioritized by order of importance)
|
30 |
+
VALID_FILE_FORMATS = ';'.join(['wav', 'mp3', 'flac'] + [fmt.lower() for fmt in valid_sf_formats.keys()])
|
31 |
+
|
32 |
+
|
33 |
+
def repeat_signal(signal: torch.Tensor, sig_len: int, required_length: int) -> torch.Tensor:
|
34 |
+
repeat the signal so that a short signal reaches required_length
|
35 |
+
Args:
|
36 |
+
signal (Tensor): input signal
|
37 |
+
sig_len (int): length of input signal
|
38 |
+
required_length (int): length of generated signal
|
39 |
+
Returns:
|
40 |
+
signal (Tensor): signal of length required_length, generated by repeating the input signal.
|
41 |
+
"""
|
42 |
+
sub: torch.Tensor = torch.tensor([])
|
43 |
+
repeat = int(required_length // sig_len)
|
44 |
+
rem = int(required_length % sig_len)
|
45 |
+
sub: torch.Tensor = torch.tensor([])
|
46 |
+
rep_sig: torch.Tensor = torch.cat(repeat * [signal])
|
47 |
+
if rem > 0:
|
48 |
+
sub = signal[-rem:]
|
49 |
+
signal = torch.cat((rep_sig, sub))
|
50 |
+
else:
|
51 |
+
signal = rep_sig
|
52 |
+
return signal
|
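# A minimal sketch of repeat_signal(): a 3-sample signal tiled and completed to 8 samples.
def _example_repeat_signal():
    sig = torch.tensor([1.0, 2.0, 3.0])
    out = repeat_signal(sig, sig_len=3, required_length=8)
    # out -> tensor([1., 2., 3., 1., 2., 3., 2., 3.]); the remainder is taken from the tail.
    return out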
53 |
+
|
54 |
+
|
55 |
+
def normalize(signal):
|
56 |
+
"""normalize signal
|
57 |
+
Args:
|
58 |
+
signal(FloatTensor): signal to be normalized.
|
59 |
+
"""
|
60 |
+
signal_minusmean = signal - signal.mean()
|
61 |
+
return signal_minusmean / signal_minusmean.abs().max()
|
62 |
+
|
63 |
+
|
64 |
+
def count_occurence(manifest_file_id):
|
65 |
+
Count the number of wav files in the manifest_file_id dict. Used for _TarredAudioToLabelDataset.
|
66 |
+
Args:
|
67 |
+
manifest_file_id (Dict): Dict of files and their corresponding id. {'A-sub0' : 1, ..., 'S-sub10':100}
|
68 |
+
Returns:
|
69 |
+
count (Dict): Dict of wav files {'A' : 2, ..., 'S':10}
|
70 |
+
"""
|
71 |
+
count = dict()
|
72 |
+
for i in manifest_file_id:
|
73 |
+
audio_filename = i.split("-sub")[0]
|
74 |
+
count[audio_filename] = count.get(audio_filename, 0) + 1
|
75 |
+
return count
|
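# A minimal sketch of count_occurence() on a hypothetical tarred-manifest ID dict.
def _example_count_occurence():
    manifest_file_id = {'A-sub0': 1, 'A-sub1': 2, 'S-sub0': 3}
    return count_occurence(manifest_file_id)  # -> {'A': 2, 'S': 1}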
76 |
+
|
77 |
+
|
78 |
+
def _speech_collate_fn(batch, pad_id):
|
79 |
+
"""collate batch of audio sig, audio len, tokens, tokens len
|
80 |
+
Args:
|
81 |
+
batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
|
82 |
+
LongTensor): A tuple of tuples of signal, signal lengths,
|
83 |
+
encoded tokens, and encoded tokens length. This collate func
|
84 |
+
assumes the signals are 1d torch tensors (i.e. mono audio).
|
85 |
+
"""
|
86 |
+
_, audio_lengths, _, tokens_lengths = zip(*batch)
|
87 |
+
max_audio_len = 0
|
88 |
+
has_audio = audio_lengths[0] is not None
|
89 |
+
if has_audio:
|
90 |
+
max_audio_len = max(audio_lengths).item()
|
91 |
+
max_tokens_len = max(tokens_lengths).item()
|
92 |
+
|
93 |
+
audio_signal, tokens = [], []
|
94 |
+
for sig, sig_len, tokens_i, tokens_i_len in batch:
|
95 |
+
if has_audio:
|
96 |
+
sig_len = sig_len.item()
|
97 |
+
if sig_len < max_audio_len:
|
98 |
+
pad = (0, max_audio_len - sig_len)
|
99 |
+
sig = torch.nn.functional.pad(sig, pad)
|
100 |
+
audio_signal.append(sig)
|
101 |
+
tokens_i_len = tokens_i_len.item()
|
102 |
+
if tokens_i_len < max_tokens_len:
|
103 |
+
pad = (0, max_tokens_len - tokens_i_len)
|
104 |
+
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
|
105 |
+
tokens.append(tokens_i)
|
106 |
+
|
107 |
+
if has_audio:
|
108 |
+
audio_signal = torch.stack(audio_signal)
|
109 |
+
audio_lengths = torch.stack(audio_lengths)
|
110 |
+
else:
|
111 |
+
audio_signal, audio_lengths = None, None
|
112 |
+
tokens = torch.stack(tokens)
|
113 |
+
tokens_lengths = torch.stack(tokens_lengths)
|
114 |
+
|
115 |
+
return audio_signal, audio_lengths, tokens, tokens_lengths
|
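# A minimal sketch of _speech_collate_fn() on two mono waveforms of different length;
# the sample rates, lengths and token values are hypothetical.
def _example_speech_collate():
    batch = [
        (torch.randn(16000), torch.tensor(16000), torch.tensor([1]), torch.tensor(1)),
        (torch.randn(8000), torch.tensor(8000), torch.tensor([2]), torch.tensor(1)),
    ]
    audio, audio_len, tokens, tokens_len = _speech_collate_fn(batch, pad_id=0)
    # The shorter waveform is zero-padded: audio -> (2, 16000).
    return audio.shape, audio_len, tokens.shape, tokens_len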
116 |
+
|
117 |
+
|
118 |
+
def _fixed_seq_collate_fn(self, batch):
|
119 |
+
"""collate batch of audio sig, audio len, tokens, tokens len
|
120 |
+
Args:
|
121 |
+
batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
|
122 |
+
LongTensor): A tuple of tuples of signal, signal lengths,
|
123 |
+
encoded tokens, and encoded tokens length. This collate func
|
124 |
+
assumes the signals are 1d torch tensors (i.e. mono audio).
|
125 |
+
"""
|
126 |
+
_, audio_lengths, _, tokens_lengths = zip(*batch)
|
127 |
+
|
128 |
+
has_audio = audio_lengths[0] is not None
|
129 |
+
fixed_length = int(max(audio_lengths))
|
130 |
+
|
131 |
+
audio_signal, tokens, new_audio_lengths = [], [], []
|
132 |
+
for sig, sig_len, tokens_i, _ in batch:
|
133 |
+
if has_audio:
|
134 |
+
sig_len = sig_len.item()
|
135 |
+
chunk_len = sig_len - fixed_length
|
136 |
+
|
137 |
+
if chunk_len < 0:
|
138 |
+
repeat = fixed_length // sig_len
|
139 |
+
rem = fixed_length % sig_len
|
140 |
+
sub = sig[-rem:] if rem > 0 else torch.tensor([])
|
141 |
+
rep_sig = torch.cat(repeat * [sig])
|
142 |
+
sig = torch.cat((rep_sig, sub))
|
143 |
+
new_audio_lengths.append(torch.tensor(fixed_length))
|
144 |
+
|
145 |
+
audio_signal.append(sig)
|
146 |
+
|
147 |
+
tokens.append(tokens_i)
|
148 |
+
|
149 |
+
if has_audio:
|
150 |
+
audio_signal = torch.stack(audio_signal)
|
151 |
+
audio_lengths = torch.stack(new_audio_lengths)
|
152 |
+
else:
|
153 |
+
audio_signal, audio_lengths = None, None
|
154 |
+
tokens = torch.stack(tokens)
|
155 |
+
tokens_lengths = torch.stack(tokens_lengths)
|
156 |
+
|
157 |
+
return audio_signal, audio_lengths, tokens, tokens_lengths
|
158 |
+
|
159 |
+
|
160 |
+
def _vad_frame_seq_collate_fn(self, batch):
|
161 |
+
"""collate batch of audio sig, audio len, tokens, tokens len
|
162 |
+
Args:
|
163 |
+
batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
|
164 |
+
LongTensor): A tuple of tuples of signal, signal lengths,
|
165 |
+
encoded tokens, and encoded tokens length. This collate func
|
166 |
+
assumes the signals are 1d torch tensors (i.e. mono audio).
|
167 |
+
batch size equals 1.
|
168 |
+
"""
|
169 |
+
slice_length = int(self.featurizer.sample_rate * self.window_length_in_sec)
|
170 |
+
_, audio_lengths, _, tokens_lengths = zip(*batch)
|
171 |
+
slice_length = int(min(slice_length, max(audio_lengths)))
|
172 |
+
shift = int(self.featurizer.sample_rate * self.shift_length_in_sec)
|
173 |
+
has_audio = audio_lengths[0] is not None
|
174 |
+
|
175 |
+
audio_signal, num_slices, tokens, audio_lengths = [], [], [], []
|
176 |
+
|
177 |
+
append_len_start = slice_length // 2
|
178 |
+
append_len_end = slice_length - slice_length // 2
|
179 |
+
for sig, sig_len, tokens_i, _ in batch:
|
180 |
+
if self.normalize_audio:
|
181 |
+
sig = normalize(sig)
|
182 |
+
start = torch.zeros(append_len_start)
|
183 |
+
end = torch.zeros(append_len_end)
|
184 |
+
sig = torch.cat((start, sig, end))
|
185 |
+
sig_len += slice_length
|
186 |
+
|
187 |
+
if has_audio:
|
188 |
+
slices = torch.div(sig_len - slice_length, shift, rounding_mode='trunc')
|
189 |
+
for slice_id in range(slices):
|
190 |
+
start_idx = slice_id * shift
|
191 |
+
end_idx = start_idx + slice_length
|
192 |
+
signal = sig[start_idx:end_idx]
|
193 |
+
audio_signal.append(signal)
|
194 |
+
|
195 |
+
num_slices.append(slices)
|
196 |
+
tokens.extend([tokens_i] * slices)
|
197 |
+
audio_lengths.extend([slice_length] * slices)
|
198 |
+
|
199 |
+
if has_audio:
|
200 |
+
audio_signal = torch.stack(audio_signal)
|
201 |
+
audio_lengths = torch.tensor(audio_lengths)
|
202 |
+
else:
|
203 |
+
audio_signal, audio_lengths = None, None
|
204 |
+
|
205 |
+
tokens = torch.stack(tokens)
|
206 |
+
tokens_lengths = torch.tensor(num_slices)
|
207 |
+
return audio_signal, audio_lengths, tokens, tokens_lengths
|
208 |
+
|
209 |
+
|
210 |
+
class _AudioLabelDataset(Dataset):
|
211 |
+
"""
|
212 |
+
Dataset that loads tensors via a json file containing paths to audio files,
|
213 |
+
labels, durations, and offsets (in seconds). Each new line is a
|
214 |
+
different sample, mapping an audio file to its target label.
|
215 |
+
JSON files should be of the following format::
|
216 |
+
{"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
|
217 |
+
target_label_0, "offset": offset_in_sec_0}
|
218 |
+
...
|
219 |
+
{"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
|
220 |
+
target_label_n, "offset": offset_in_sec_n}
|
221 |
+
Args:
|
222 |
+
manifest_filepath (Union[str, List[str]]): Dataset parameter. Path to JSON containing data.
|
223 |
+
labels (list): Dataset parameter. List of target classes that can be output by the speaker recognition model.
|
224 |
+
featurizer
|
225 |
+
min_duration (float): Dataset parameter. All training files which have a duration less than min_duration
|
226 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
227 |
+
Defaults to 0.1.
|
228 |
+
max_duration (float): Dataset parameter.
|
229 |
+
All training files which have a duration more than max_duration
|
230 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
231 |
+
Defaults to None.
|
232 |
+
trim (bool): Whether to use trim silence from beginning and end of audio signal using librosa.effects.trim().
|
233 |
+
Defaults to False.
|
234 |
+
"""

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports."""

        output_types = {
            'audio_signal': NeuralType(
                ('B', 'T'),
                AudioSignal(freq=self._sample_rate)
                if self is not None and hasattr(self, '_sample_rate')
                else AudioSignal(),
            ),
            'a_sig_length': NeuralType(tuple('B'), LengthsType()),
        }

        if self.is_regression_task:
            output_types.update(
                {
                    'targets': NeuralType(tuple('B'), RegressionValuesType()),
                    'targets_length': NeuralType(tuple('B'), LengthsType()),
                }
            )
        else:
            output_types.update(
                {'label': NeuralType(tuple('B'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),}
            )

        return output_types
|
265 |
+
|
266 |
+
def __init__(
|
267 |
+
self,
|
268 |
+
*,
|
269 |
+
manifest_filepath: Union[str, List[str]],
|
270 |
+
labels: List[str],
|
271 |
+
featurizer,
|
272 |
+
min_duration: Optional[float] = 0.1,
|
273 |
+
max_duration: Optional[float] = None,
|
274 |
+
trim: bool = False,
|
275 |
+
is_regression_task: bool = False,
|
276 |
+
cal_labels_occurrence: Optional[bool] = False,
|
277 |
+
):
|
278 |
+
super().__init__()
|
279 |
+
if isinstance(manifest_filepath, str):
|
280 |
+
manifest_filepath = manifest_filepath.split(',')
|
281 |
+
cache_datastore_manifests(manifest_filepaths=manifest_filepath, cache_audio=True)
|
282 |
+
self.collection = collections.ASRSpeechLabel(
|
283 |
+
manifests_files=manifest_filepath,
|
284 |
+
min_duration=min_duration,
|
285 |
+
max_duration=max_duration,
|
286 |
+
is_regression_task=is_regression_task,
|
287 |
+
cal_labels_occurrence=cal_labels_occurrence,
|
288 |
+
)
|
289 |
+
|
290 |
+
self.featurizer = featurizer
|
291 |
+
self.trim = trim
|
292 |
+
self.is_regression_task = is_regression_task
|
293 |
+
|
294 |
+
if not is_regression_task:
|
295 |
+
self.labels = labels if labels else self.collection.uniq_labels
|
296 |
+
self.num_classes = len(self.labels) if self.labels is not None else 1
|
297 |
+
self.label2id, self.id2label = {}, {}
|
298 |
+
self.id2occurrence, self.labels_occurrence = {}, []
|
299 |
+
|
300 |
+
for label_id, label in enumerate(self.labels):
|
301 |
+
self.label2id[label] = label_id
|
302 |
+
self.id2label[label_id] = label
|
303 |
+
if cal_labels_occurrence:
|
304 |
+
self.id2occurrence[label_id] = self.collection.labels_occurrence[label]
|
305 |
+
|
306 |
+
if cal_labels_occurrence:
|
307 |
+
self.labels_occurrence = [self.id2occurrence[k] for k in sorted(self.id2occurrence)]
|
308 |
+
|
309 |
+
for idx in range(len(self.labels[:5])):
|
310 |
+
logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
|
311 |
+
|
312 |
+
else:
|
313 |
+
self.labels = []
|
314 |
+
self.num_classes = 1
|
315 |
+
|
316 |
+
def __len__(self):
|
317 |
+
return len(self.collection)
|
318 |
+
|
319 |
+
def __getitem__(self, index):
|
320 |
+
sample = self.collection[index]
|
321 |
+
|
322 |
+
offset = sample.offset
|
323 |
+
|
324 |
+
if offset is None:
|
325 |
+
offset = 0
|
326 |
+
|
327 |
+
features = self.featurizer.process(sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim)
|
328 |
+
f, fl = features, torch.tensor(features.shape[0]).long()
|
329 |
+
|
330 |
+
if not self.is_regression_task:
|
331 |
+
t = torch.tensor(self.label2id[sample.label]).long()
|
332 |
+
else:
|
333 |
+
t = torch.tensor(sample.label).float()
|
334 |
+
|
335 |
+
tl = torch.tensor(1).long() # For compatibility with collate_fn used later
|
336 |
+
|
337 |
+
return f, fl, t, tl
|
338 |
+
|
339 |
+
|
340 |
+
# Ported from https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/data/speech2text/speech_commands.py
|
341 |
+
class AudioToClassificationLabelDataset(_AudioLabelDataset):
|
342 |
+
"""
|
343 |
+
Dataset that loads tensors via a json file containing paths to audio
|
344 |
+
files, command class, and durations (in seconds). Each new line is a
|
345 |
+
different sample. Example below:
|
346 |
+
{"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
|
347 |
+
target_label_0, "offset": offset_in_sec_0}
|
348 |
+
...
|
349 |
+
{"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
|
350 |
+
target_label_n, "offset": offset_in_sec_n}
|
351 |
+
Args:
|
352 |
+
manifest_filepath (Union[str, List[str]]): Path to manifest json as described above. Can
|
353 |
+
be comma-separated paths.
|
354 |
+
labels (Optional[list]): String containing all the possible labels to map to
|
355 |
+
if None then automatically picks from ASRSpeechLabel collection.
|
356 |
+
featurizer: Initialized featurizer class that converts paths of
|
357 |
+
audio to feature tensors
|
358 |
+
max_duration: If audio exceeds this length, do not include in dataset
|
359 |
+
min_duration: If audio is less than this length, do not include
|
360 |
+
in dataset
|
361 |
+
trim: Boolean flag whether to trim the audio
|
362 |
+
"""
|
363 |
+
|
364 |
+
def _collate_fn(self, batch):
|
365 |
+
return _speech_collate_fn(batch, pad_id=0)
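# Illustrative sketch: writing a one-line manifest of the form documented in the docstrings above.
# The audio path, duration and label are hypothetical placeholders, not shipped data.
def _example_write_classification_manifest(manifest_path: str = "/tmp/train_manifest.json") -> None:
    import json

    entry = {"audio_filepath": "/data/audio_wav_0.wav", "duration": 1.2, "label": "yes", "offset": 0.0}
    with open(manifest_path, "w", encoding="utf-8") as fout:
        fout.write(json.dumps(entry) + "\n")  # one JSON object per line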
|
366 |
+
|
367 |
+
|
368 |
+
class AudioToSpeechLabelDataset(_AudioLabelDataset):
|
369 |
+
"""
|
370 |
+
Dataset that loads tensors via a json file containing paths to audio
|
371 |
+
files, command class, and durations (in seconds). Each new line is a
|
372 |
+
different sample. Example below:
|
373 |
+
{"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
|
374 |
+
target_label_0, "offset": offset_in_sec_0}
|
375 |
+
...
|
376 |
+
{"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
|
377 |
+
target_label_n, "offset": offset_in_sec_n}
|
378 |
+
Args:
|
379 |
+
manifest_filepath (Union[str, List[str]]): Path to manifest json as described above. Can
|
380 |
+
be comma-separated paths.
|
381 |
+
labels (Optional[list]): String containing all the possible labels to map to
|
382 |
+
if None then automatically picks from ASRSpeechLabel collection.
|
383 |
+
min_duration (float): Dataset parameter.
|
384 |
+
All training files which have a duration less than min_duration
|
385 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
386 |
+
Defaults to 0.1.
|
387 |
+
max_duration (float): Dataset parameter.
|
388 |
+
All training files which have a duration more than max_duration
|
389 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
390 |
+
Defaults to None.
|
391 |
+
trim (bool): Whether to trim silence from the beginning and end
|
392 |
+
of audio signal using librosa.effects.trim().
|
393 |
+
Defaults to False.
|
394 |
+
window_length_in_sec (float): length of window/slice (in seconds)
|
395 |
+
Use this for speaker recognition and VAD tasks.
|
396 |
+
shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch
|
397 |
+
Use this for VAD task during inference.
|
398 |
+
normalize_audio (bool): Whether to normalize audio signal.
|
399 |
+
Defaults to False.
|
400 |
+
is_regression_task (bool): Whether the dataset is for a regression task instead of classification.
|
401 |
+
Defaults to False.
|
402 |
+
cal_labels_occurrence (bool): Whether to calculate occurrence of labels
|
403 |
+
Defaults to False.
|
404 |
+
"""
|
405 |
+
|
406 |
+
def __init__(
|
407 |
+
self,
|
408 |
+
*,
|
409 |
+
manifest_filepath: Union[str, List[str]],
|
410 |
+
labels: List[str],
|
411 |
+
featurizer,
|
412 |
+
min_duration: Optional[float] = 0.1,
|
413 |
+
max_duration: Optional[float] = None,
|
414 |
+
trim: bool = False,
|
415 |
+
window_length_in_sec: Optional[float] = 8,
|
416 |
+
shift_length_in_sec: Optional[float] = 1,
|
417 |
+
normalize_audio: bool = False,
|
418 |
+
is_regression_task: bool = False,
|
419 |
+
cal_labels_occurrence: Optional[bool] = False,
|
420 |
+
):
|
421 |
+
self.window_length_in_sec = window_length_in_sec
|
422 |
+
self.shift_length_in_sec = shift_length_in_sec
|
423 |
+
self.normalize_audio = normalize_audio
|
424 |
+
|
425 |
+
logging.debug("Window/slice length considered for collate func is {}".format(self.window_length_in_sec))
|
426 |
+
logging.debug("Shift length considered for collate func is {}".format(self.shift_length_in_sec))
|
427 |
+
|
428 |
+
super().__init__(
|
429 |
+
manifest_filepath=manifest_filepath,
|
430 |
+
labels=labels,
|
431 |
+
featurizer=featurizer,
|
432 |
+
min_duration=min_duration,
|
433 |
+
max_duration=max_duration,
|
434 |
+
trim=trim,
|
435 |
+
is_regression_task=is_regression_task,
|
436 |
+
cal_labels_occurrence=cal_labels_occurrence,
|
437 |
+
)
|
438 |
+
|
439 |
+
def fixed_seq_collate_fn(self, batch):
|
440 |
+
return _fixed_seq_collate_fn(self, batch)
|
441 |
+
|
442 |
+
def vad_frame_seq_collate_fn(self, batch):
|
443 |
+
return _vad_frame_seq_collate_fn(self, batch)
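# Illustrative sketch: wiring AudioToSpeechLabelDataset to a DataLoader. The manifest path, label
# list and sample rate are assumptions made for the example; WaveformFeaturizer and torch come from
# this module's imports.
def _example_speech_label_loader():
    from torch.utils.data import DataLoader

    featurizer = WaveformFeaturizer(sample_rate=16000)
    dataset = AudioToSpeechLabelDataset(
        manifest_filepath="/data/speaker_manifest.json",  # hypothetical manifest
        labels=["spk_0", "spk_1"],
        featurizer=featurizer,
        window_length_in_sec=3.0,
        shift_length_in_sec=0.75,
    )
    # fixed_seq_collate_fn batches the signals at the fixed window length configured above
    return DataLoader(dataset, batch_size=8, collate_fn=dataset.fixed_seq_collate_fn)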
|
444 |
+
|
445 |
+
|
446 |
+
class _TarredAudioLabelDataset(IterableDataset):
|
447 |
+
"""
|
448 |
+
A similar Dataset to the AudioLabelDataSet, but which loads tarred audio files.
|
449 |
+
|
450 |
+
Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToSpeechLabelDataset),
|
451 |
+
as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
|
452 |
+
contain the information for one audio file, including at least the label and name of the audio
|
453 |
+
file within the tarball.
|
454 |
+
|
455 |
+
Valid formats for the audio_tar_filepaths argument include:
|
456 |
+
(1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
|
457 |
+
(2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
|
458 |
+
|
459 |
+
Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference.
|
460 |
+
This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements.
|
461 |
+
Supported opening braces - { <=> (, [, < and the special tag _OP_.
|
462 |
+
Supported closing braces - } <=> ), ], > and the special tag _CL_.
|
463 |
+
For SLURM based tasks, we suggest the use of the special tags for ease of use.
|
464 |
+
|
465 |
+
See the documentation for more information about accepted data and input formats.
|
466 |
+
|
467 |
+
If using multiple processes the number of shards should be divisible by the number of workers to ensure an
|
468 |
+
even split among workers. If it is not divisible, logging will give a warning but training will proceed.
|
469 |
+
In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
|
470 |
+
is applied. We currently do not check for this, but your program may hang if the shards are uneven!
|
471 |
+
|
472 |
+
Notice that a few arguments are different from the AudioLabelDataSet; for example, shuffle (bool) has been
|
473 |
+
replaced by shuffle_n (int).
|
474 |
+
|
475 |
+
Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
|
476 |
+
after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
|
477 |
+
|
478 |
+
Args:
|
479 |
+
audio_tar_filepaths: Either a list of audio tarball filepaths, or a
|
480 |
+
string (can be brace-expandable).
|
481 |
+
manifest_filepath (str): Path to the manifest.
|
482 |
+
labels (list): Dataset parameter.
|
483 |
+
List of target classes that can be output by the speaker recognition model.
|
484 |
+
featurizer
|
485 |
+
shuffle_n (int): How many samples to look ahead and load to be shuffled.
|
486 |
+
See WebDataset documentation for more details.
|
487 |
+
Defaults to 0.
|
488 |
+
min_duration (float): Dataset parameter.
|
489 |
+
All training files which have a duration less than min_duration
|
490 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
491 |
+
Defaults to 0.1.
|
492 |
+
max_duration (float): Dataset parameter.
|
493 |
+
All training files which have a duration more than max_duration
|
494 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
495 |
+
Defaults to None.
|
496 |
+
trim (bool): Whether to trim silence from the beginning and end
|
497 |
+
of audio signal using librosa.effects.trim().
|
498 |
+
Defaults to False.
|
499 |
+
window_length_in_sec (float): length of slice/window (in seconds) # Pass this only for speaker recognition and VAD task
|
500 |
+
shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch. # Pass this only for VAD task during inference.
|
501 |
+
normalize_audio (bool): Whether to normalize audio signal. Defaults to False.
|
502 |
+
shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
|
503 |
+
- `scatter`: The default shard strategy applied by WebDataset, where each node gets
|
504 |
+
a unique set of shards, which are permanently pre-allocated and never changed at runtime.
|
505 |
+
- `replicate`: Optional shard strategy, where each node gets all of the set of shards
|
506 |
+
available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
|
507 |
+
The benefit of replication is that it allows each node to sample data points from the entire
|
508 |
+
dataset independently of other nodes, and reduces dependence on the value of `shuffle_n`.
|
509 |
+
|
510 |
+
.. warning::
|
511 |
+
Replicated strategy allows every node to sample the entire set of available tarfiles,
|
512 |
+
and therefore more than one node may sample the same tarfile, and even sample the same
|
513 |
+
data points! As such, there is no assured guarantee that all samples in the dataset will be
|
514 |
+
sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
|
515 |
+
occasions (when the number of shards is not divisible with ``world_size``), will not sample
|
516 |
+
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
|
517 |
+
or test datasets.
|
518 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
519 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
520 |
+
is_regression_task (bool): Whether it is a regression task. Defaults to False.
|
521 |
+
"""
|
522 |
+
|
523 |
+
def __init__(
|
524 |
+
self,
|
525 |
+
*,
|
526 |
+
audio_tar_filepaths: Union[str, List[str]],
|
527 |
+
manifest_filepath: Union[str, List[str]],
|
528 |
+
labels: List[str],
|
529 |
+
featurizer,
|
530 |
+
shuffle_n: int = 0,
|
531 |
+
min_duration: Optional[float] = 0.1,
|
532 |
+
max_duration: Optional[float] = None,
|
533 |
+
trim: bool = False,
|
534 |
+
shard_strategy: str = "scatter",
|
535 |
+
global_rank: int = 0,
|
536 |
+
world_size: int = 0,
|
537 |
+
is_regression_task: bool = False,
|
538 |
+
):
|
539 |
+
cache_datastore_manifests(manifest_filepaths=manifest_filepath)
|
540 |
+
self.collection = collections.ASRSpeechLabel(
|
541 |
+
manifests_files=manifest_filepath,
|
542 |
+
min_duration=min_duration,
|
543 |
+
max_duration=max_duration,
|
544 |
+
index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID
|
545 |
+
)
|
546 |
+
|
547 |
+
self.file_occurence = count_occurence(self.collection.mapping)
|
548 |
+
|
549 |
+
self.featurizer = featurizer
|
550 |
+
self.trim = trim
|
551 |
+
|
552 |
+
self.labels = labels if labels else self.collection.uniq_labels
|
553 |
+
self.num_classes = len(self.labels)
|
554 |
+
|
555 |
+
self.label2id, self.id2label = {}, {}
|
556 |
+
for label_id, label in enumerate(self.labels):
|
557 |
+
self.label2id[label] = label_id
|
558 |
+
self.id2label[label_id] = label
|
559 |
+
|
560 |
+
for idx in range(len(self.labels[:5])):
|
561 |
+
logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
|
562 |
+
|
563 |
+
audio_tar_filepaths = expand_sharded_filepaths(
|
564 |
+
sharded_filepaths=audio_tar_filepaths,
|
565 |
+
shard_strategy=shard_strategy,
|
566 |
+
world_size=world_size,
|
567 |
+
global_rank=global_rank,
|
568 |
+
)
|
569 |
+
# Put together WebDataset
|
570 |
+
self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
|
571 |
+
|
572 |
+
if shuffle_n > 0:
|
573 |
+
self._dataset = self._dataset.shuffle(shuffle_n)
|
574 |
+
else:
|
575 |
+
logging.info("WebDataset will not shuffle files within the tar files.")
|
576 |
+
|
577 |
+
self._dataset = (
|
578 |
+
self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
|
579 |
+
.to_tuple('audio', 'key')
|
580 |
+
.pipe(self._filter)
|
581 |
+
.map(f=self._build_sample)
|
582 |
+
)
|
583 |
+
|
584 |
+
def _filter(self, iterator):
|
585 |
+
"""This function is used to remove samples that have been filtered out by ASRSpeechLabel already.
|
586 |
+
Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample
|
587 |
+
that was filtered out (e.g. for duration).
|
588 |
+
Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard,
|
589 |
+
which may make your code hang as one process will finish before the other.
|
590 |
+
"""
|
591 |
+
|
592 |
+
class TarredAudioFilter:
|
593 |
+
def __init__(self, collection, file_occurence):
|
594 |
+
self.iterator = iterator
|
595 |
+
self.collection = collection
|
596 |
+
self.file_occurence = file_occurence
|
597 |
+
self._iterable = self._internal_generator()
|
598 |
+
|
599 |
+
def __iter__(self):
|
600 |
+
self._iterable = self._internal_generator()
|
601 |
+
return self
|
602 |
+
|
603 |
+
def __next__(self):
|
604 |
+
try:
|
605 |
+
values = next(self._iterable)
|
606 |
+
except StopIteration:
|
607 |
+
# reset generator
|
608 |
+
self._iterable = self._internal_generator()
|
609 |
+
values = next(self._iterable)
|
610 |
+
|
611 |
+
return values
|
612 |
+
|
613 |
+
def _internal_generator(self):
|
614 |
+
"""
|
615 |
+
WebDataset requires an Iterator, but we require an iterable that yields 1-or-more
|
616 |
+
values per value inside self.iterator.
|
617 |
+
|
618 |
+
Therefore wrap the iterator with a generator function that will yield 1-or-more
|
619 |
+
values per sample in the iterator.
|
620 |
+
"""
|
621 |
+
for _, tup in enumerate(self.iterator):
|
622 |
+
audio_bytes, audio_filename = tup
|
623 |
+
|
624 |
+
file_id, _ = os.path.splitext(os.path.basename(audio_filename))
|
625 |
+
if audio_filename in self.file_occurence:
|
626 |
+
for j in range(0, self.file_occurence[file_id]):
|
627 |
+
if j == 0:
|
628 |
+
audio_filename = file_id
|
629 |
+
else:
|
630 |
+
audio_filename = file_id + "-sub" + str(j)
|
631 |
+
yield audio_bytes, audio_filename
|
632 |
+
|
633 |
+
return TarredAudioFilter(self.collection, self.file_occurence)
|
634 |
+
|
635 |
+
def _build_sample(self, tup):
|
636 |
+
"""Builds the training sample by combining the data from the WebDataset with the manifest info.
|
637 |
+
"""
|
638 |
+
audio_bytes, audio_filename = tup
|
639 |
+
# Grab manifest entry from self.collection
|
640 |
+
file_id, _ = os.path.splitext(os.path.basename(audio_filename))
|
641 |
+
|
642 |
+
manifest_idx = self.collection.mapping[file_id]
|
643 |
+
manifest_entry = self.collection[manifest_idx]
|
644 |
+
|
645 |
+
offset = manifest_entry.offset
|
646 |
+
if offset is None:
|
647 |
+
offset = 0
|
648 |
+
|
649 |
+
# Convert audio bytes to IO stream for processing (for SoundFile to read)
|
650 |
+
audio_filestream = io.BytesIO(audio_bytes)
|
651 |
+
features = self.featurizer.process(
|
652 |
+
audio_filestream, offset=offset, duration=manifest_entry.duration, trim=self.trim,
|
653 |
+
)
|
654 |
+
|
655 |
+
audio_filestream.close()
|
656 |
+
|
657 |
+
# Audio features
|
658 |
+
f, fl = features, torch.tensor(features.shape[0]).long()
|
659 |
+
|
660 |
+
t = self.label2id[manifest_entry.label]
|
661 |
+
tl = 1 # For compatibility with collate_fn used later
|
662 |
+
|
663 |
+
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
664 |
+
|
665 |
+
def __iter__(self):
|
666 |
+
return self._dataset.__iter__()
|
667 |
+
|
668 |
+
def __len__(self):
|
669 |
+
return len(self.collection)
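# Illustrative sketch: the accepted shapes for `audio_tar_filepaths` described in the docstring
# above. The shard names are hypothetical; `_OP_`/`_CL_` are the SLURM-safe brace replacements.
_EXAMPLE_TAR_SPECS = [
    "/data/shards/audio_{0..63}.tar",                          # brace-expandable string
    "/data/shards/audio__OP_0..63_CL_.tar",                    # same range written with the special tags
    ["/data/shards/audio_0.tar", "/data/shards/audio_1.tar"],  # explicit list, used verbatim
]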
|
670 |
+
|
671 |
+
|
672 |
+
class TarredAudioToClassificationLabelDataset(_TarredAudioLabelDataset):
|
673 |
+
"""
|
674 |
+
A similar Dataset to the AudioToClassificationLabelDataset, but which loads tarred audio files.
|
675 |
+
|
676 |
+
Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToClassificationLabelDataset),
|
677 |
+
as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
|
678 |
+
contain the information for one audio file, including at least the transcript and name of the audio
|
679 |
+
file within the tarball.
|
680 |
+
|
681 |
+
Valid formats for the audio_tar_filepaths argument include:
|
682 |
+
(1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
|
683 |
+
(2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
|
684 |
+
|
685 |
+
See the WebDataset documentation for more information about accepted data and input formats.
|
686 |
+
|
687 |
+
If using multiple processes the number of shards should be divisible by the number of workers to ensure an
|
688 |
+
even split among workers. If it is not divisible, logging will give a warning but training will proceed.
|
689 |
+
In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
|
690 |
+
is applied. We currently do not check for this, but your program may hang if the shards are uneven!
|
691 |
+
|
692 |
+
Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been
|
693 |
+
replaced by shuffle_n (int).
|
694 |
+
|
695 |
+
Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
|
696 |
+
after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
|
697 |
+
|
698 |
+
Args:
|
699 |
+
audio_tar_filepaths: Either a list of audio tarball filepaths, or a
|
700 |
+
string (can be brace-expandable).
|
701 |
+
manifest_filepath (str): Path to the manifest.
|
702 |
+
labels (list): Dataset parameter.
|
703 |
+
List of target classes that can be output by the speaker recognition model.
|
704 |
+
featurizer
|
705 |
+
shuffle_n (int): How many samples to look ahead and load to be shuffled.
|
706 |
+
See WebDataset documentation for more details.
|
707 |
+
Defaults to 0.
|
708 |
+
min_duration (float): Dataset parameter.
|
709 |
+
All training files which have a duration less than min_duration
|
710 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
711 |
+
Defaults to 0.1.
|
712 |
+
max_duration (float): Dataset parameter.
|
713 |
+
All training files which have a duration more than max_duration
|
714 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
715 |
+
Defaults to None.
|
716 |
+
trim (bool): Whether to trim silence from the beginning and end
|
717 |
+
of audio signal using librosa.effects.trim().
|
718 |
+
Defaults to False.
|
719 |
+
shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
|
720 |
+
- `scatter`: The default shard strategy applied by WebDataset, where each node gets
|
721 |
+
a unique set of shards, which are permanently pre-allocated and never changed at runtime.
|
722 |
+
- `replicate`: Optional shard strategy, where each node gets all of the set of shards
|
723 |
+
available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
|
724 |
+
The benefit of replication is that it allows each node to sample data points from the entire
|
725 |
+
dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
|
726 |
+
|
727 |
+
.. warning::
|
728 |
+
Replicated strategy allows every node to sample the entire set of available tarfiles,
|
729 |
+
and therefore more than one node may sample the same tarfile, and even sample the same
|
730 |
+
data points! As such, there is no assured guarantee that all samples in the dataset will be
|
731 |
+
sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
|
732 |
+
occasions (when the number of shards is not divisible with ``world_size``), will not sample
|
733 |
+
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
|
734 |
+
or test datasets.
|
735 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
736 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
737 |
+
is_regression_task (bool): Whether it is a regression task. Defaults to False.
|
738 |
+
"""
|
739 |
+
|
740 |
+
def _collate_fn(self, batch):
|
741 |
+
return _speech_collate_fn(batch, pad_id=0)
|
742 |
+
|
743 |
+
|
744 |
+
class TarredAudioToSpeechLabelDataset(_TarredAudioLabelDataset):
|
745 |
+
"""
|
746 |
+
A similar Dataset to the AudioToSpeechLabelDataset, but which loads tarred audio files.
|
747 |
+
|
748 |
+
Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToSpeechLabelDataset),
|
749 |
+
as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
|
750 |
+
contain the information for one audio file, including at least the transcript and name of the audio
|
751 |
+
file within the tarball.
|
752 |
+
|
753 |
+
Valid formats for the audio_tar_filepaths argument include:
|
754 |
+
(1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
|
755 |
+
(2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
|
756 |
+
|
757 |
+
See the WebDataset documentation for more information about accepted data and input formats.
|
758 |
+
|
759 |
+
If using multiple processes the number of shards should be divisible by the number of workers to ensure an
|
760 |
+
even split among workers. If it is not divisible, logging will give a warning but training will proceed.
|
761 |
+
In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
|
762 |
+
is applied. We currently do not check for this, but your program may hang if the shards are uneven!
|
763 |
+
|
764 |
+
Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been
|
765 |
+
replaced by shuffle_n (int).
|
766 |
+
|
767 |
+
Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
|
768 |
+
after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
|
769 |
+
|
770 |
+
Args:
|
771 |
+
audio_tar_filepaths: Either a list of audio tarball filepaths, or a
|
772 |
+
string (can be brace-expandable).
|
773 |
+
manifest_filepath (str): Path to the manifest.
|
774 |
+
labels (list): Dataset parameter.
|
775 |
+
List of target classes that can be output by the speaker recognition model.
|
776 |
+
featurizer
|
777 |
+
shuffle_n (int): How many samples to look ahead and load to be shuffled.
|
778 |
+
See WebDataset documentation for more details.
|
779 |
+
Defaults to 0.
|
780 |
+
min_duration (float): Dataset parameter.
|
781 |
+
All training files which have a duration less than min_duration
|
782 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
783 |
+
Defaults to 0.1.
|
784 |
+
max_duration (float): Dataset parameter.
|
785 |
+
All training files which have a duration more than max_duration
|
786 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
787 |
+
Defaults to None.
|
788 |
+
trim (bool): Whether to trim silence from the beginning and end
|
789 |
+
of audio signal using librosa.effects.trim().
|
790 |
+
Defaults to False.
|
791 |
+
window_length_in_sec (float): time length of window/slice (in seconds) # Pass this only for speaker recognition and VAD task
|
792 |
+
shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch. # Pass this only for VAD task during inference.
|
793 |
+
normalize_audio (bool): Whether to normalize audio signal. Defaults to False.
|
794 |
+
shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
|
795 |
+
- `scatter`: The default shard strategy applied by WebDataset, where each node gets
|
796 |
+
a unique set of shards, which are permanently pre-allocated and never changed at runtime.
|
797 |
+
- `replicate`: Optional shard strategy, where each node gets all of the set of shards
|
798 |
+
available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
|
799 |
+
The benefit of replication is that it allows each node to sample data points from the entire
|
800 |
+
dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
|
801 |
+
|
802 |
+
.. warning::
|
803 |
+
Replicated strategy allows every node to sample the entire set of available tarfiles,
|
804 |
+
and therefore more than one node may sample the same tarfile, and even sample the same
|
805 |
+
data points! As such, there is no assured guarantee that all samples in the dataset will be
|
806 |
+
sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
|
807 |
+
occasions (when the number of shards is not divisible with ``world_size``), will not sample
|
808 |
+
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
|
809 |
+
or test datasets.
|
810 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
811 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
812 |
+
"""
|
813 |
+
|
814 |
+
def __init__(
|
815 |
+
self,
|
816 |
+
*,
|
817 |
+
audio_tar_filepaths: Union[str, List[str]],
|
818 |
+
manifest_filepath: Union[str, List[str]],
|
819 |
+
labels: List[str],
|
820 |
+
featurizer,
|
821 |
+
shuffle_n: int = 0,
|
822 |
+
min_duration: Optional[float] = 0.1,
|
823 |
+
max_duration: Optional[float] = None,
|
824 |
+
trim: bool = False,
|
825 |
+
window_length_in_sec: Optional[float] = 8,
|
826 |
+
shift_length_in_sec: Optional[float] = 1,
|
827 |
+
normalize_audio: bool = False,
|
828 |
+
shard_strategy: str = "scatter",
|
829 |
+
global_rank: int = 0,
|
830 |
+
world_size: int = 0,
|
831 |
+
):
|
832 |
+
logging.info("Window/slice length considered for collate func is {}".format(window_length_in_sec))
|
833 |
+
logging.info("Shift length considered for collate func is {}".format(shift_length_in_sec))
|
834 |
+
self.window_length_in_sec = window_length_in_sec
|
835 |
+
self.shift_length_in_sec = shift_length_in_sec
|
836 |
+
self.normalize_audio = normalize_audio
|
837 |
+
|
838 |
+
super().__init__(
|
839 |
+
audio_tar_filepaths=audio_tar_filepaths,
|
840 |
+
manifest_filepath=manifest_filepath,
|
841 |
+
labels=labels,
|
842 |
+
featurizer=featurizer,
|
843 |
+
shuffle_n=shuffle_n,
|
844 |
+
min_duration=min_duration,
|
845 |
+
max_duration=max_duration,
|
846 |
+
trim=trim,
|
847 |
+
shard_strategy=shard_strategy,
|
848 |
+
global_rank=global_rank,
|
849 |
+
world_size=world_size,
|
850 |
+
)
|
851 |
+
|
852 |
+
def fixed_seq_collate_fn(self, batch):
|
853 |
+
return _fixed_seq_collate_fn(self, batch)
|
854 |
+
|
855 |
+
def sliced_seq_collate_fn(self, batch):
|
856 |
+
raise NotImplementedError
|
857 |
+
|
858 |
+
def vad_frame_seq_collate_fn(self, batch):
|
859 |
+
return _vad_frame_seq_collate_fn(self, batch)
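# Illustrative sketch: constructing the tarred speech-label dataset under DDP. The rank and world
# size would normally come from the trainer; the paths, labels and shuffle_n below are placeholders.
def _example_tarred_speech_label_dataset(global_rank: int = 0, world_size: int = 2):
    featurizer = WaveformFeaturizer(sample_rate=16000)
    return TarredAudioToSpeechLabelDataset(
        audio_tar_filepaths="/data/shards/audio_{0..63}.tar",
        manifest_filepath="/data/shards/tarred_manifest.json",
        labels=["spk_0", "spk_1"],
        featurizer=featurizer,
        shuffle_n=2048,
        shard_strategy="scatter",  # each rank receives a disjoint subset of the 64 shards
        global_rank=global_rank,
        world_size=world_size,
    )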
|
860 |
+
|
861 |
+
|
862 |
+
class AudioToMultiLabelDataset(Dataset):
|
863 |
+
"""
|
864 |
+
Dataset that loads a json file containing paths to audio files, durations (in seconds), and a sequence of labels.
|
865 |
+
Each new line is a different sample. Example below:
|
866 |
+
{"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
|
867 |
+
"0 1 1 0 1", "offset": offset_in_sec_0}
|
868 |
+
...
|
869 |
+
{"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
|
870 |
+
"0 1 0 0 1", "offset": offset_in_sec_n}
|
871 |
+
Args:
|
872 |
+
manifest_filepath (Union[str, List[str]]): Path to manifest json as described above. Can
|
873 |
+
be comma-separated paths.
|
874 |
+
labels (Optional[list]): String containing all the possible labels to map to
|
875 |
+
if None then automatically picks from ASRSpeechLabel collection.
|
876 |
+
min_duration (float): Dataset parameter.
|
877 |
+
All training files which have a duration less than min_duration
|
878 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
879 |
+
Defaults to 0.1.
|
880 |
+
max_duration (float): Dataset parameter.
|
881 |
+
All training files which have a duration more than max_duration
|
882 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
883 |
+
Defaults to None.
|
884 |
+
trim (bool): Whether to trim silence from the beginning and end
|
885 |
+
of audio signal using librosa.effects.trim().
|
886 |
+
Defaults to False.
|
887 |
+
window_length_in_sec (float): length of window/slice (in seconds)
|
888 |
+
Use this for speaker recognition and VAD tasks.
|
889 |
+
shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch
|
890 |
+
Use this for VAD task during inference.
|
891 |
+
normalize_audio (bool): Whether to normalize audio signal.
|
892 |
+
Defaults to False.
|
893 |
+
is_regression_task (bool): Whether the dataset is for a regression task instead of classification.
|
894 |
+
Defaults to False.
|
895 |
+
cal_labels_occurrence (bool): Whether to calculate occurrence of labels
|
896 |
+
Defaults to False.
|
897 |
+
delimiter (Optional[str]): Delimiter to use when splitting the label string, default to None.
|
898 |
+
normalize_audio_db (Optional[float]): normalize audio signal to a target db, default to None.
|
899 |
+
"""
|
900 |
+
|
901 |
+
@property
|
902 |
+
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
903 |
+
"""Returns definitions of module output ports.
|
904 |
+
"""
|
905 |
+
|
906 |
+
output_types = {
|
907 |
+
'audio_signal': NeuralType(
|
908 |
+
('B', 'T'),
|
909 |
+
AudioSignal(freq=self._sample_rate)
|
910 |
+
if self is not None and hasattr(self, '_sample_rate')
|
911 |
+
else AudioSignal(),
|
912 |
+
),
|
913 |
+
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
|
914 |
+
}
|
915 |
+
|
916 |
+
if self.is_regression_task:
|
917 |
+
output_types.update(
|
918 |
+
{
|
919 |
+
'targets': NeuralType(('B', 'T'), RegressionValuesType()),
|
920 |
+
'targets_length': NeuralType(tuple('B'), LengthsType()),
|
921 |
+
}
|
922 |
+
)
|
923 |
+
else:
|
924 |
+
output_types.update(
|
925 |
+
{'label': NeuralType(('B', 'T'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),}
|
926 |
+
)
|
927 |
+
|
928 |
+
return output_types
|
929 |
+
|
930 |
+
def __init__(
|
931 |
+
self,
|
932 |
+
*,
|
933 |
+
manifest_filepath: Union[str, List[str]],
|
934 |
+
sample_rate: int,
|
935 |
+
labels: Optional[List[str]] = None,
|
936 |
+
int_values: bool = False,
|
937 |
+
augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
|
938 |
+
min_duration: Optional[float] = 0.1,
|
939 |
+
max_duration: Optional[float] = None,
|
940 |
+
trim_silence: bool = False,
|
941 |
+
is_regression_task: bool = False,
|
942 |
+
cal_labels_occurrence: Optional[bool] = False,
|
943 |
+
delimiter: Optional[str] = None,
|
944 |
+
normalize_audio_db: Optional[float] = None,
|
945 |
+
):
|
946 |
+
super().__init__()
|
947 |
+
if isinstance(manifest_filepath, str):
|
948 |
+
manifest_filepath = manifest_filepath.split(',')
|
949 |
+
|
950 |
+
self.delimiter = delimiter
|
951 |
+
self.normalize_audio_db = normalize_audio_db
|
952 |
+
|
953 |
+
self.collection = collections.ASRSpeechLabel(
|
954 |
+
manifests_files=manifest_filepath,
|
955 |
+
min_duration=min_duration,
|
956 |
+
max_duration=max_duration,
|
957 |
+
is_regression_task=is_regression_task,
|
958 |
+
cal_labels_occurrence=cal_labels_occurrence,
|
959 |
+
delimiter=delimiter,
|
960 |
+
)
|
961 |
+
|
962 |
+
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
|
963 |
+
self.trim = trim_silence
|
964 |
+
self.is_regression_task = is_regression_task
|
965 |
+
self.id2occurrence = {}
|
966 |
+
self.labels_occurrence = None
|
967 |
+
|
968 |
+
if not is_regression_task:
|
969 |
+
self.labels = labels if labels else self._get_label_set()
|
970 |
+
self.num_classes = len(self.labels) if self.labels is not None else 1
|
971 |
+
self.label2id, self.id2label = {}, {}
|
972 |
+
for label_id, label in enumerate(self.labels):
|
973 |
+
self.label2id[label] = label_id
|
974 |
+
self.id2label[label_id] = label
|
975 |
+
if cal_labels_occurrence:
|
976 |
+
self.id2occurrence[label_id] = self.collection.labels_occurrence[label]
|
977 |
+
self.labels_occurrence.append(self.id2occurrence[label_id])
|
978 |
+
|
979 |
+
for idx in range(len(self.labels[:5])):
|
980 |
+
logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
|
981 |
+
else:
|
982 |
+
self.labels = []
|
983 |
+
self.num_classes = 1
|
984 |
+
|
985 |
+
def _get_label_set(self):
|
986 |
+
labels = []
|
987 |
+
for sample in self.collection:
|
988 |
+
label_str = sample.label
|
989 |
+
if label_str:
|
990 |
+
label_str_list = label_str.split(self.delimiter) if self.delimiter else label_str.split()
|
991 |
+
labels.extend(label_str_list)
|
992 |
+
return sorted(set(labels))
|
993 |
+
|
994 |
+
def _label_str_to_tensor(self, label_str: str):
|
995 |
+
labels = label_str.split(self.delimiter) if self.delimiter else label_str.split()
|
996 |
+
|
997 |
+
if self.is_regression_task:
|
998 |
+
labels = [float(s) for s in labels]
|
999 |
+
labels = torch.tensor(labels).float()
|
1000 |
+
else:
|
1001 |
+
labels = [self.label2id[s] for s in labels]
|
1002 |
+
labels = torch.tensor(labels).long()
|
1003 |
+
return labels
|
1004 |
+
|
1005 |
+
def __len__(self):
|
1006 |
+
return len(self.collection)
|
1007 |
+
|
1008 |
+
def __getitem__(self, index):
|
1009 |
+
sample = self.collection[index]
|
1010 |
+
|
1011 |
+
offset = sample.offset
|
1012 |
+
|
1013 |
+
if offset is None:
|
1014 |
+
offset = 0
|
1015 |
+
|
1016 |
+
features = self.featurizer.process(
|
1017 |
+
sample.audio_file,
|
1018 |
+
offset=offset,
|
1019 |
+
duration=sample.duration,
|
1020 |
+
trim=self.trim,
|
1021 |
+
normalize_db=self.normalize_audio_db,
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
f, fl = features, torch.tensor(features.size(0)).long()
|
1025 |
+
|
1026 |
+
t = self._label_str_to_tensor(sample.label)
|
1027 |
+
|
1028 |
+
tl = torch.tensor(t.size(0)).long()
|
1029 |
+
|
1030 |
+
return f, fl, t, tl
|
1031 |
+
|
1032 |
+
def _collate_fn(self, batch):
|
1033 |
+
return _speech_collate_fn(batch, pad_id=0)
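# Illustrative sketch of the multi-label manifest convention documented above: the "label" field
# holds a delimited sequence of labels (space-separated when `delimiter` is None), which
# _label_str_to_tensor maps to a LongTensor of class ids, or a FloatTensor for regression.
# The values below are placeholders for the example.
def _example_multi_label_entry() -> dict:
    return {
        "audio_filepath": "/data/audio_wav_0.wav",
        "duration": 2.0,
        "offset": 0.0,
        "label": "0 1 1 0 1",
    }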
|
1034 |
+
|
1035 |
+
|
1036 |
+
class TarredAudioToMultiLabelDataset(IterableDataset):
|
1037 |
+
"""
|
1038 |
+
A similar Dataset to the AudioToMultiLabelDataset, but which loads tarred audio files.
|
1039 |
+
|
1040 |
+
Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToSpeechLabelDataset),
|
1041 |
+
as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
|
1042 |
+
contain the information for one audio file, including at least the transcript and name of the audio
|
1043 |
+
file within the tarball.
|
1044 |
+
|
1045 |
+
Valid formats for the audio_tar_filepaths argument include:
|
1046 |
+
(1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
|
1047 |
+
(2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
|
1048 |
+
|
1049 |
+
See the WebDataset documentation for more information about accepted data and input formats.
|
1050 |
+
|
1051 |
+
If using multiple processes the number of shards should be divisible by the number of workers to ensure an
|
1052 |
+
even split among workers. If it is not divisible, logging will give a warning but training will proceed.
|
1053 |
+
In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
|
1054 |
+
is applied. We currently do not check for this, but your program may hang if the shards are uneven!
|
1055 |
+
|
1056 |
+
Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been
|
1057 |
+
replaced by shuffle_n (int).
|
1058 |
+
|
1059 |
+
Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
|
1060 |
+
after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
|
1061 |
+
|
1062 |
+
Args:
|
1063 |
+
audio_tar_filepaths: Either a list of audio tarball filepaths, or a
|
1064 |
+
string (can be brace-expandable).
|
1065 |
+
manifest_filepath (str): Path to the manifest.
|
1066 |
+
labels (list): Dataset parameter.
|
1067 |
+
List of target classes that can be output by the speaker recognition model.
|
1068 |
+
shuffle_n (int): How many samples to look ahead and load to be shuffled.
|
1069 |
+
See WebDataset documentation for more details.
|
1070 |
+
Defaults to 0.
|
1071 |
+
min_duration (float): Dataset parameter.
|
1072 |
+
All training files which have a duration less than min_duration
|
1073 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
1074 |
+
Defaults to 0.1.
|
1075 |
+
max_duration (float): Dataset parameter.
|
1076 |
+
All training files which have a duration more than max_duration
|
1077 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
1078 |
+
Defaults to None.
|
1079 |
+
trim (bool): Whether to trim silence from the beginning and end
|
1080 |
+
of audio signal using librosa.effects.trim().
|
1081 |
+
Defaults to False.
|
1082 |
+
window_length_in_sec (float): time length of window/slice (in seconds) # Pass this only for speaker recognition and VAD task
|
1083 |
+
shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch. # Pass this only for VAD task during inference.
|
1084 |
+
normalize_audio (bool): Whether to normalize audio signal. Defaults to False.
|
1085 |
+
shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
|
1086 |
+
- `scatter`: The default shard strategy applied by WebDataset, where each node gets
|
1087 |
+
a unique set of shards, which are permanently pre-allocated and never changed at runtime.
|
1088 |
+
- `replicate`: Optional shard strategy, where each node gets all of the set of shards
|
1089 |
+
available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
|
1090 |
+
The benefit of replication is that it allows each node to sample data points from the entire
|
1091 |
+
dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
|
1092 |
+
|
1093 |
+
.. warning::
|
1094 |
+
Replicated strategy allows every node to sample the entire set of available tarfiles,
|
1095 |
+
and therefore more than one node may sample the same tarfile, and even sample the same
|
1096 |
+
data points! As such, there is no assured guarantee that all samples in the dataset will be
|
1097 |
+
sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
|
1098 |
+
occasions (when the number of shards is not divisible with ``world_size``), will not sample
|
1099 |
+
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
|
1100 |
+
or test datasets.
|
1101 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
1102 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
1103 |
+
delimiter (Optional[str]): Delimiter to use when splitting the label string, default to None.
|
1104 |
+
normalize_audio_db (Optional[float]): normalize audio signal to a target db, default to None.
|
1105 |
+
"""
|
1106 |
+
|
1107 |
+
def __init__(
|
1108 |
+
self,
|
1109 |
+
*,
|
1110 |
+
audio_tar_filepaths: Union[str, List[str]],
|
1111 |
+
manifest_filepath: Union[str, List[str]],
|
1112 |
+
sample_rate: int,
|
1113 |
+
labels: Optional[List[str]] = None,
|
1114 |
+
shuffle_n: int = 0,
|
1115 |
+
int_values: bool = False,
|
1116 |
+
augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
|
1117 |
+
min_duration: Optional[float] = 0.1,
|
1118 |
+
max_duration: Optional[float] = None,
|
1119 |
+
trim_silence: bool = False,
|
1120 |
+
is_regression_task: bool = False,
|
1121 |
+
shard_strategy: str = "scatter",
|
1122 |
+
global_rank: int = 0,
|
1123 |
+
world_size: int = 0,
|
1124 |
+
delimiter: Optional[str] = None,
|
1125 |
+
normalize_audio_db: Optional[float] = None,
|
1126 |
+
):
|
1127 |
+
super().__init__()
|
1128 |
+
if isinstance(manifest_filepath, str):
|
1129 |
+
manifest_filepath = manifest_filepath.split(',')
|
1130 |
+
|
1131 |
+
self.trim = trim_silence
|
1132 |
+
self.is_regression_task = is_regression_task
|
1133 |
+
self.delimiter = delimiter
|
1134 |
+
self.normalize_audio_db = normalize_audio_db
|
1135 |
+
|
1136 |
+
self.collection = collections.ASRSpeechLabel(
|
1137 |
+
manifests_files=manifest_filepath,
|
1138 |
+
min_duration=min_duration,
|
1139 |
+
max_duration=max_duration,
|
1140 |
+
is_regression_task=is_regression_task,
|
1141 |
+
index_by_file_id=True,
|
1142 |
+
)
|
1143 |
+
self.file_occurence = count_occurence(self.collection.mapping)
|
1144 |
+
|
1145 |
+
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
|
1146 |
+
|
1147 |
+
if not is_regression_task:
|
1148 |
+
self.labels = labels if labels else self._get_label_set()
|
1149 |
+
self.num_classes = len(self.labels) if self.labels is not None else 1
|
1150 |
+
self.label2id, self.id2label = {}, {}
|
1151 |
+
for label_id, label in enumerate(self.labels):
|
1152 |
+
self.label2id[label] = label_id
|
1153 |
+
self.id2label[label_id] = label
|
1154 |
+
for idx in range(len(self.labels[:5])):
|
1155 |
+
logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
|
1156 |
+
else:
|
1157 |
+
self.labels = []
|
1158 |
+
self.num_classes = 1
|
1159 |
+
|
1160 |
+
audio_tar_filepaths = expand_sharded_filepaths(
|
1161 |
+
sharded_filepaths=audio_tar_filepaths,
|
1162 |
+
shard_strategy=shard_strategy,
|
1163 |
+
world_size=world_size,
|
1164 |
+
global_rank=global_rank,
|
1165 |
+
)
|
1166 |
+
# Put together WebDataset
|
1167 |
+
self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
|
1168 |
+
|
1169 |
+
if shuffle_n > 0:
|
1170 |
+
self._dataset = self._dataset.shuffle(shuffle_n)
|
1171 |
+
else:
|
1172 |
+
logging.info("WebDataset will not shuffle files within the tar files.")
|
1173 |
+
|
1174 |
+
self._dataset = (
|
1175 |
+
self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
|
1176 |
+
.to_tuple('audio', 'key')
|
1177 |
+
.pipe(self._filter)
|
1178 |
+
.map(f=self._build_sample)
|
1179 |
+
)
|
1180 |
+
|
1181 |
+
def _get_label_set(self):
|
1182 |
+
labels = []
|
1183 |
+
for sample in self.collection:
|
1184 |
+
label_str = sample.label
|
1185 |
+
if label_str:
|
1186 |
+
label_str_list = label_str.split(self.delimiter) if self.delimiter else label_str.split()
|
1187 |
+
labels.extend(label_str_list)
|
1188 |
+
return sorted(set(labels))
|
1189 |
+
|
1190 |
+
def _label_str_to_tensor(self, label_str: str):
|
1191 |
+
labels = label_str.split(self.delimiter) if self.delimiter else label_str.split()
|
1192 |
+
|
1193 |
+
if self.is_regression_task:
|
1194 |
+
labels = [float(s) for s in labels]
|
1195 |
+
labels = torch.tensor(labels).float()
|
1196 |
+
else:
|
1197 |
+
labels = [self.label2id[s] for s in labels]
|
1198 |
+
labels = torch.tensor(labels).long()
|
1199 |
+
return labels
|
1200 |
+
|
1201 |
+
def _filter(self, iterator):
|
1202 |
+
"""This function is used to remove samples that have been filtered out by ASRSpeechLabel already.
|
1203 |
+
Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample
|
1204 |
+
that was filtered out (e.g. for duration).
|
1205 |
+
Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard,
|
1206 |
+
which may make your code hang as one process will finish before the other.
|
1207 |
+
"""
|
1208 |
+
|
1209 |
+
class TarredAudioFilter:
|
1210 |
+
def __init__(self, collection, file_occurence):
|
1211 |
+
self.iterator = iterator
|
1212 |
+
self.collection = collection
|
1213 |
+
self.file_occurence = file_occurence
|
1214 |
+
self._iterable = self._internal_generator()
|
1215 |
+
|
1216 |
+
def __iter__(self):
|
1217 |
+
self._iterable = self._internal_generator()
|
1218 |
+
return self
|
1219 |
+
|
1220 |
+
def __next__(self):
|
1221 |
+
try:
|
1222 |
+
values = next(self._iterable)
|
1223 |
+
except StopIteration:
|
1224 |
+
# reset generator
|
1225 |
+
self._iterable = self._internal_generator()
|
1226 |
+
values = next(self._iterable)
|
1227 |
+
|
1228 |
+
return values
|
1229 |
+
|
1230 |
+
def _internal_generator(self):
|
1231 |
+
"""
|
1232 |
+
WebDataset requires an Iterator, but we require an iterable that yields 1-or-more
|
1233 |
+
values per value inside self.iterator.
|
1234 |
+
|
1235 |
+
Therefore wrap the iterator with a generator function that will yield 1-or-more
|
1236 |
+
values per sample in the iterator.
|
1237 |
+
"""
|
1238 |
+
for _, tup in enumerate(self.iterator):
|
1239 |
+
audio_bytes, audio_filename = tup
|
1240 |
+
|
1241 |
+
file_id, _ = os.path.splitext(os.path.basename(audio_filename))
|
1242 |
+
if audio_filename in self.file_occurence:
|
1243 |
+
for j in range(0, self.file_occurence[file_id]):
|
1244 |
+
if j == 0:
|
1245 |
+
audio_filename = file_id
|
1246 |
+
else:
|
1247 |
+
audio_filename = file_id + "-sub" + str(j)
|
1248 |
+
yield audio_bytes, audio_filename
|
1249 |
+
|
1250 |
+
return TarredAudioFilter(self.collection, self.file_occurence)
|
1251 |
+
|
1252 |
+
def _build_sample(self, tup):
|
1253 |
+
"""Builds the training sample by combining the data from the WebDataset with the manifest info.
|
1254 |
+
"""
|
1255 |
+
audio_bytes, audio_filename = tup
|
1256 |
+
# Grab manifest entry from self.collection
|
1257 |
+
file_id, _ = os.path.splitext(os.path.basename(audio_filename))
|
1258 |
+
|
1259 |
+
manifest_idx = self.collection.mapping[file_id]
|
1260 |
+
manifest_entry = self.collection[manifest_idx]
|
1261 |
+
|
1262 |
+
offset = manifest_entry.offset
|
1263 |
+
if offset is None:
|
1264 |
+
offset = 0
|
1265 |
+
|
1266 |
+
# Convert audio bytes to IO stream for processing (for SoundFile to read)
|
1267 |
+
audio_filestream = io.BytesIO(audio_bytes)
|
1268 |
+
features = self.featurizer.process(
|
1269 |
+
audio_filestream,
|
1270 |
+
offset=offset,
|
1271 |
+
duration=manifest_entry.duration,
|
1272 |
+
trim=self.trim,
|
1273 |
+
normalize_db=self.normalize_audio_db,
|
1274 |
+
)
|
1275 |
+
|
1276 |
+
audio_filestream.close()
|
1277 |
+
|
1278 |
+
# Audio features
|
1279 |
+
f, fl = features, torch.tensor(features.shape[0]).long()
|
1280 |
+
|
1281 |
+
t = self._label_str_to_tensor(manifest_entry.label)
|
1282 |
+
|
1283 |
+
tl = torch.tensor(t.size(0)).long()
|
1284 |
+
|
1285 |
+
return f, fl, t, tl
|
1286 |
+
|
1287 |
+
def __iter__(self):
|
1288 |
+
return self._dataset.__iter__()
|
1289 |
+
|
1290 |
+
def __len__(self):
|
1291 |
+
return len(self.collection)
|
1292 |
+
|
1293 |
+
def _collate_fn(self, batch):
|
1294 |
+
return _speech_collate_fn(batch, pad_id=0)
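# Illustrative sketch: the tarred classes above are IterableDatasets, so they are consumed by
# iteration rather than indexing. `dataset` is assumed to be an instance of one of the tarred
# classification/multi-label classes above, which expose `_collate_fn`.
def _example_iterate_tarred(dataset):
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=16, collate_fn=dataset._collate_fn, num_workers=4)
    for audio_signal, audio_lengths, labels, label_lengths in loader:
        # audio_signal: [B, T] padded waveforms; audio_lengths: [B]; labels/label_lengths per dataset type
        break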
|
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label_dataset.py
ADDED
@@ -0,0 +1,304 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from omegaconf import DictConfig

from nemo.collections.asr.data import audio_to_label
from nemo.collections.asr.data.audio_to_text_dataset import convert_to_config_list, get_chain_dataset
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.dataset import ConcatDataset

def get_classification_label_dataset(featurizer, config: dict) -> audio_to_label.AudioToClassificationLabelDataset:
    """
    Instantiates a Classification AudioLabelDataset.

    Args:
        config: Config of the AudioToClassificationLabelDataset.

    Returns:
        An instance of AudioToClassificationLabelDataset.
    """
    dataset = audio_to_label.AudioToClassificationLabelDataset(
        manifest_filepath=config['manifest_filepath'],
        labels=config['labels'],
        featurizer=featurizer,
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        trim=config.get('trim_silence', False),
        is_regression_task=config.get('is_regression_task', False),
        cal_labels_occurrence=config.get('cal_labels_occurrence', False),
    )
    return dataset
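# Illustrative sketch: a minimal config dict accepted by the factory above. The keys mirror the
# config.get(...) calls in the function body; the manifest path and label set are placeholders.
_EXAMPLE_CLASSIFICATION_DS_CONFIG = {
    'manifest_filepath': '/data/train_manifest.json',
    'labels': ['yes', 'no', 'up', 'down'],
    'max_duration': 16.7,
    'min_duration': 0.1,
    'trim_silence': False,
    'is_regression_task': False,
    'cal_labels_occurrence': False,
}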


def get_speech_label_dataset(featurizer, config: dict) -> audio_to_label.AudioToSpeechLabelDataset:
    """
    Instantiates a Speech Label (e.g. VAD, speaker recognition) AudioLabelDataset.

    Args:
        config: Config of the AudioToSpeechLabelDataSet.

    Returns:
        An instance of AudioToSpeechLabelDataset.
    """
    dataset = audio_to_label.AudioToSpeechLabelDataset(
        manifest_filepath=config['manifest_filepath'],
        labels=config['labels'],
        featurizer=featurizer,
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        trim=config.get('trim_silence', False),
        window_length_in_sec=config.get('window_length_in_sec', 0.31),
        shift_length_in_sec=config.get('shift_length_in_sec', 0.01),
        normalize_audio=config.get('normalize_audio', False),
        cal_labels_occurrence=config.get('cal_labels_occurrence', False),
    )
    return dataset
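# Illustrative sketch: the additional windowing keys read by get_speech_label_dataset for VAD-style
# training. The values are example settings, not defaults beyond those shown in the function above.
_EXAMPLE_VAD_DS_CONFIG = {
    'manifest_filepath': '/data/vad_manifest.json',
    'labels': ['background', 'speech'],
    'window_length_in_sec': 0.63,
    'shift_length_in_sec': 0.08,
    'normalize_audio': False,
}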
|
70 |
+
|
71 |
+
|
72 |
+
def get_tarred_classification_label_dataset(
|
73 |
+
featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int
|
74 |
+
) -> audio_to_label.TarredAudioToClassificationLabelDataset:
|
75 |
+
"""
|
76 |
+
Instantiates a Classification TarredAudioLabelDataset.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
config: Config of the TarredAudioToClassificationLabelDataset.
|
80 |
+
shuffle_n: How many samples to look ahead and load to be shuffled.
|
81 |
+
See WebDataset documentation for more details.
|
82 |
+
global_rank: Global rank of this device.
|
83 |
+
world_size: Global world size in the training method.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
An instance of TarredAudioToClassificationLabelDataset.
|
87 |
+
"""
|
88 |
+
tarred_audio_filepaths = config['tarred_audio_filepaths']
|
89 |
+
manifest_filepaths = config['manifest_filepath']
|
90 |
+
datasets = []
|
91 |
+
tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
|
92 |
+
manifest_filepaths = convert_to_config_list(manifest_filepaths)
|
93 |
+
|
94 |
+
bucketing_weights = config.get('bucketing_weights', None) # For upsampling buckets
|
95 |
+
if bucketing_weights:
|
96 |
+
for idx, weight in enumerate(bucketing_weights):
|
97 |
+
if not isinstance(weight, int) or weight <= 0:
|
98 |
+
raise ValueError(f"bucket weights must be positive integers")
|
99 |
+
|
100 |
+
if len(manifest_filepaths) != len(tarred_audio_filepaths):
|
101 |
+
raise ValueError(
|
102 |
+
f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
|
103 |
+
)
|
104 |
+
|
105 |
+
for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
|
106 |
+
zip(tarred_audio_filepaths, manifest_filepaths)
|
107 |
+
):
|
108 |
+
if len(tarred_audio_filepath) == 1:
|
109 |
+
tarred_audio_filepath = tarred_audio_filepath[0]
|
110 |
+
dataset = audio_to_label.TarredAudioToClassificationLabelDataset(
|
111 |
+
audio_tar_filepaths=tarred_audio_filepath,
|
112 |
+
manifest_filepath=manifest_filepath,
|
113 |
+
labels=config['labels'],
|
114 |
+
featurizer=featurizer,
|
115 |
+
shuffle_n=shuffle_n,
|
116 |
+
max_duration=config.get('max_duration', None),
|
117 |
+
min_duration=config.get('min_duration', None),
|
118 |
+
trim=config.get('trim_silence', False),
|
119 |
+
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
|
120 |
+
global_rank=global_rank,
|
121 |
+
world_size=world_size,
|
122 |
+
is_regression_task=config.get('is_regression_task', False),
|
123 |
+
)
|
124 |
+
|
125 |
+
if bucketing_weights:
|
126 |
+
datasets.extend([dataset] * bucketing_weights[dataset_idx])
|
127 |
+
else:
|
128 |
+
datasets.append(dataset)
|
129 |
+
|
130 |
+
return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
|
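As a sketch of the bucket upsampling handled above (paths and weights below are illustrative assumptions), a config fragment with two buckets and `bucketing_weights=[2, 1]` causes the first bucket's dataset to be appended twice, so the chained dataset draws from it roughly twice as often:

# Illustrative config fragment for two buckets; all paths are hypothetical.
config = {
    'manifest_filepath': [['bucket1/tarred_manifest.json'], ['bucket2/tarred_manifest.json']],
    'tarred_audio_filepaths': [['bucket1/audio_{0..7}.tar'], ['bucket2/audio_{0..7}.tar']],
    'labels': ['speech', 'non_speech'],
    'bucketing_weights': [2, 1],  # bucket 1 is sampled ~2x as often as bucket 2
}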
131 |
+
|
132 |
+
|
133 |
+
def get_concat_tarred_speech_label_dataset(
|
134 |
+
featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int,
|
135 |
+
):
|
136 |
+
tarred_audio_filepaths = config['tarred_audio_filepaths']
|
137 |
+
manifest_filepaths = config['manifest_filepath']
|
138 |
+
datasets = []
|
139 |
+
for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
|
140 |
+
zip(tarred_audio_filepaths, manifest_filepaths)
|
141 |
+
):
|
142 |
+
conf = copy.deepcopy(config)
|
143 |
+
conf['manifest_filepath'] = manifest_filepath
|
144 |
+
conf['tarred_audio_filepaths'] = tarred_audio_filepath
|
145 |
+
dataset = get_tarred_speech_label_dataset(
|
146 |
+
config=conf, featurizer=featurizer, shuffle_n=shuffle_n, global_rank=global_rank, world_size=world_size,
|
147 |
+
)
|
148 |
+
datasets.append(dataset)
|
149 |
+
|
150 |
+
dataset = ConcatDataset(
|
151 |
+
datasets,
|
152 |
+
sampling_technique=config.get('concat_sampling_technique', 'temperature'),
|
153 |
+
sampling_temperature=config.get('concat_sampling_temperature', 5),
|
154 |
+
sampling_probabilities=config.get('concat_sampling_probabilities', None),
|
155 |
+
global_rank=global_rank,
|
156 |
+
world_size=world_size,
|
157 |
+
shuffle=config['shuffle'],
|
158 |
+
)
|
159 |
+
return dataset
|
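For reference, a hedged sketch of the concatenation-related keys this helper reads via config.get(...) above (values are illustrative; the supported sampling techniques are defined by ConcatDataset):

# Illustrative fragment; only keys read by this function are shown.
config_fragment = {
    'concat_sampling_technique': 'temperature',   # default used above
    'concat_sampling_temperature': 5,
    'concat_sampling_probabilities': None,        # only relevant for probability-based sampling
    'shuffle': True,
}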
160 |
+
|
161 |
+
|
162 |
+
def get_tarred_speech_label_dataset(
|
163 |
+
featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int,
|
164 |
+
) -> audio_to_label.TarredAudioToSpeechLabelDataset:
|
165 |
+
"""
|
166 |
+
Instantiates a Speech Label (e.g. VAD, speaker recognition) TarredAudioLabelDataset.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
config: Config of the TarredAudioToSpeechLabelDataset.
|
170 |
+
shuffle_n: How many samples to look ahead and load to be shuffled.
|
171 |
+
See WebDataset documentation for more details.
|
172 |
+
global_rank: Global rank of this device.
|
173 |
+
world_size: Global world size in the training method.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
An instance of TarredAudioToSpeechLabelDataset.
|
177 |
+
"""
|
178 |
+
tarred_audio_filepaths = config['tarred_audio_filepaths']
|
179 |
+
manifest_filepaths = config['manifest_filepath']
|
180 |
+
datasets = []
|
181 |
+
tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
|
182 |
+
manifest_filepaths = convert_to_config_list(manifest_filepaths)
|
183 |
+
|
184 |
+
bucketing_weights = config.get('bucketing_weights', None) # For upsampling buckets
|
185 |
+
if bucketing_weights:
|
186 |
+
for idx, weight in enumerate(bucketing_weights):
|
187 |
+
if not isinstance(weight, int) or weight <= 0:
|
188 |
+
raise ValueError(f"bucket weights must be positive integers")
|
189 |
+
|
190 |
+
if len(manifest_filepaths) != len(tarred_audio_filepaths):
|
191 |
+
raise ValueError(
|
192 |
+
f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
|
193 |
+
)
|
194 |
+
|
195 |
+
for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
|
196 |
+
zip(tarred_audio_filepaths, manifest_filepaths)
|
197 |
+
):
|
198 |
+
if len(tarred_audio_filepath) == 1:
|
199 |
+
tarred_audio_filepath = tarred_audio_filepath[0]
|
200 |
+
dataset = audio_to_label.TarredAudioToSpeechLabelDataset(
|
201 |
+
audio_tar_filepaths=tarred_audio_filepath,
|
202 |
+
manifest_filepath=manifest_filepath,
|
203 |
+
labels=config['labels'],
|
204 |
+
featurizer=featurizer,
|
205 |
+
shuffle_n=shuffle_n,
|
206 |
+
max_duration=config.get('max_duration', None),
|
207 |
+
min_duration=config.get('min_duration', None),
|
208 |
+
trim=config.get('trim_silence', False),
|
209 |
+
window_length_in_sec=config.get('window_length_in_sec', 8),
|
210 |
+
shift_length_in_sec=config.get('shift_length_in_sec', 0.075),
|
211 |
+
normalize_audio=config.get('normalize_audio', False),
|
212 |
+
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
|
213 |
+
global_rank=global_rank,
|
214 |
+
world_size=world_size,
|
215 |
+
)
|
216 |
+
|
217 |
+
if bucketing_weights:
|
218 |
+
datasets.extend([dataset] * bucketing_weights[dataset_idx])
|
219 |
+
else:
|
220 |
+
datasets.append(dataset)
|
221 |
+
|
222 |
+
return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
|
223 |
+
|
224 |
+
|
225 |
+
def get_audio_multi_label_dataset(cfg: DictConfig) -> audio_to_label.AudioToMultiLabelDataset:
|
226 |
+
if "augmentor" in cfg:
|
227 |
+
augmentor = process_augmentations(cfg.augmentor)
|
228 |
+
else:
|
229 |
+
augmentor = None
|
230 |
+
|
231 |
+
dataset = audio_to_label.AudioToMultiLabelDataset(
|
232 |
+
manifest_filepath=cfg.get("manifest_filepath"),
|
233 |
+
sample_rate=cfg.get("sample_rate"),
|
234 |
+
labels=cfg.get("labels", None),
|
235 |
+
int_values=cfg.get("int_values", False),
|
236 |
+
augmentor=augmentor,
|
237 |
+
min_duration=cfg.get("min_duration", None),
|
238 |
+
max_duration=cfg.get("max_duration", None),
|
239 |
+
trim_silence=cfg.get("trim_silence", False),
|
240 |
+
is_regression_task=cfg.get("is_regression_task", False),
|
241 |
+
cal_labels_occurrence=cfg.get("cal_labels_occurrence", False),
|
242 |
+
delimiter=cfg.get("delimiter", None),
|
243 |
+
normalize_audio_db=cfg.get("normalize_audio_db", None),
|
244 |
+
)
|
245 |
+
return dataset
|
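A hedged sketch of calling this helper with an OmegaConf config (the manifest path, labels, and augmentor sub-config are assumptions for illustration; the accepted augmentor keys are defined by process_augmentations):

from omegaconf import OmegaConf
from nemo.collections.asr.data.audio_to_label_dataset import get_audio_multi_label_dataset

# Hypothetical values; 'delimiter' splits the label string of each manifest entry.
cfg = OmegaConf.create({
    'manifest_filepath': '/data/multilabel_manifest.json',
    'sample_rate': 16000,
    'labels': ['laughter', 'music', 'speech'],
    'delimiter': ' ',
    'augmentor': {'white_noise': {'prob': 0.3, 'min_level': -90, 'max_level': -46}},  # assumed example
})
dataset = get_audio_multi_label_dataset(cfg)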
246 |
+
|
247 |
+
|
248 |
+
def get_tarred_audio_multi_label_dataset(
|
249 |
+
cfg: DictConfig, shuffle_n: int, global_rank: int, world_size: int
|
250 |
+
) -> audio_to_label.TarredAudioToMultiLabelDataset:
|
251 |
+
|
252 |
+
if "augmentor" in cfg:
|
253 |
+
augmentor = process_augmentations(cfg.augmentor)
|
254 |
+
else:
|
255 |
+
augmentor = None
|
256 |
+
|
257 |
+
tarred_audio_filepaths = cfg['tarred_audio_filepaths']
|
258 |
+
manifest_filepaths = cfg['manifest_filepath']
|
259 |
+
datasets = []
|
260 |
+
tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
|
261 |
+
manifest_filepaths = convert_to_config_list(manifest_filepaths)
|
262 |
+
|
263 |
+
bucketing_weights = cfg.get('bucketing_weights', None) # For upsampling buckets
|
264 |
+
if bucketing_weights:
|
265 |
+
for idx, weight in enumerate(bucketing_weights):
|
266 |
+
if not isinstance(weight, int) or weight <= 0:
|
267 |
+
raise ValueError(f"bucket weights must be positive integers")
|
268 |
+
|
269 |
+
if len(manifest_filepaths) != len(tarred_audio_filepaths):
|
270 |
+
raise ValueError(
|
271 |
+
f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
|
272 |
+
)
|
273 |
+
|
274 |
+
for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
|
275 |
+
zip(tarred_audio_filepaths, manifest_filepaths)
|
276 |
+
):
|
277 |
+
if len(tarred_audio_filepath) == 1:
|
278 |
+
tarred_audio_filepath = tarred_audio_filepath[0]
|
279 |
+
|
280 |
+
dataset = audio_to_label.TarredAudioToMultiLabelDataset(
|
281 |
+
audio_tar_filepaths=tarred_audio_filepath,
|
282 |
+
manifest_filepath=manifest_filepath,
|
283 |
+
sample_rate=cfg["sample_rate"],
|
284 |
+
labels=cfg['labels'],
|
285 |
+
shuffle_n=shuffle_n,
|
286 |
+
int_values=cfg.get("int_values", False),
|
287 |
+
augmentor=augmentor,
|
288 |
+
min_duration=cfg.get('min_duration', None),
|
289 |
+
max_duration=cfg.get('max_duration', None),
|
290 |
+
trim_silence=cfg.get('trim_silence', False),
|
291 |
+
is_regression_task=cfg.get('is_regression_task', False),
|
292 |
+
delimiter=cfg.get("delimiter", None),
|
293 |
+
shard_strategy=cfg.get('tarred_shard_strategy', 'scatter'),
|
294 |
+
global_rank=global_rank,
|
295 |
+
world_size=world_size,
|
296 |
+
normalize_audio_db=cfg.get("normalize_audio_db", None),
|
297 |
+
)
|
298 |
+
|
299 |
+
if bucketing_weights:
|
300 |
+
datasets.extend([dataset] * bucketing_weights[dataset_idx])
|
301 |
+
else:
|
302 |
+
datasets.append(dataset)
|
303 |
+
|
304 |
+
return get_chain_dataset(datasets=datasets, ds_config=cfg, rank=global_rank)
|
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text.py
ADDED
@@ -0,0 +1,1366 @@
1 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import io
|
15 |
+
import json
|
16 |
+
import math
|
17 |
+
import multiprocessing
|
18 |
+
import os
|
19 |
+
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import braceexpand
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import webdataset as wd
|
25 |
+
from torch.utils.data import ChainDataset
|
26 |
+
from tqdm import tqdm
|
27 |
+
|
28 |
+
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
|
29 |
+
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
|
30 |
+
from nemo.collections.common import tokenizers
|
31 |
+
from nemo.collections.common.parts.preprocessing import collections, parsers
|
32 |
+
from nemo.core.classes import Dataset, IterableDataset
|
33 |
+
from nemo.core.neural_types import *
|
34 |
+
from nemo.utils import logging
|
35 |
+
from nemo.utils.data_utils import (
|
36 |
+
DataStoreObject,
|
37 |
+
datastore_object_get,
|
38 |
+
datastore_path_to_webdataset_url,
|
39 |
+
is_datastore_cache_shared,
|
40 |
+
is_datastore_path,
|
41 |
+
is_tarred_path,
|
42 |
+
)
|
43 |
+
from nemo.utils.get_rank import is_global_rank_zero
|
44 |
+
|
45 |
+
__all__ = [
|
46 |
+
'AudioToCharDataset',
|
47 |
+
'AudioToBPEDataset',
|
48 |
+
'TarredAudioToCharDataset',
|
49 |
+
'TarredAudioToBPEDataset',
|
50 |
+
]
|
51 |
+
|
52 |
+
|
53 |
+
def _speech_collate_fn(batch, pad_id):
|
54 |
+
"""collate batch of audio sig, audio len, tokens, tokens len
|
55 |
+
Args:
|
56 |
+
batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
|
57 |
+
LongTensor): A tuple of tuples of signal, signal lengths,
|
58 |
+
encoded tokens, and encoded tokens length. This collate func
|
59 |
+
assumes the signals are 1d torch tensors (i.e. mono audio).
|
60 |
+
"""
|
61 |
+
packed_batch = list(zip(*batch))
|
62 |
+
if len(packed_batch) == 5:
|
63 |
+
_, audio_lengths, _, tokens_lengths, sample_ids = packed_batch
|
64 |
+
elif len(packed_batch) == 4:
|
65 |
+
sample_ids = None
|
66 |
+
_, audio_lengths, _, tokens_lengths = packed_batch
|
67 |
+
else:
|
68 |
+
raise ValueError("Expects 4 or 5 tensors in the batch!")
|
69 |
+
max_audio_len = 0
|
70 |
+
has_audio = audio_lengths[0] is not None
|
71 |
+
if has_audio:
|
72 |
+
max_audio_len = max(audio_lengths).item()
|
73 |
+
max_tokens_len = max(tokens_lengths).item()
|
74 |
+
|
75 |
+
audio_signal, tokens = [], []
|
76 |
+
for b in batch:
|
77 |
+
if len(b) == 5:
|
78 |
+
sig, sig_len, tokens_i, tokens_i_len, _ = b
|
79 |
+
else:
|
80 |
+
sig, sig_len, tokens_i, tokens_i_len = b
|
81 |
+
if has_audio:
|
82 |
+
sig_len = sig_len.item()
|
83 |
+
if sig_len < max_audio_len:
|
84 |
+
pad = (0, max_audio_len - sig_len)
|
85 |
+
sig = torch.nn.functional.pad(sig, pad)
|
86 |
+
audio_signal.append(sig)
|
87 |
+
tokens_i_len = tokens_i_len.item()
|
88 |
+
if tokens_i_len < max_tokens_len:
|
89 |
+
pad = (0, max_tokens_len - tokens_i_len)
|
90 |
+
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
|
91 |
+
tokens.append(tokens_i)
|
92 |
+
|
93 |
+
if has_audio:
|
94 |
+
audio_signal = torch.stack(audio_signal)
|
95 |
+
audio_lengths = torch.stack(audio_lengths)
|
96 |
+
else:
|
97 |
+
audio_signal, audio_lengths = None, None
|
98 |
+
tokens = torch.stack(tokens)
|
99 |
+
tokens_lengths = torch.stack(tokens_lengths)
|
100 |
+
if sample_ids is None:
|
101 |
+
return audio_signal, audio_lengths, tokens, tokens_lengths
|
102 |
+
else:
|
103 |
+
sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
|
104 |
+
return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids
|
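A minimal sketch of the padding behavior using dummy tensors (shapes are illustrative): the shorter signal is zero-padded and the shorter token sequence is padded with pad_id.

import torch

# Two mono signals and two token sequences of different lengths.
batch = [
    (torch.randn(16000), torch.tensor(16000), torch.tensor([1, 2, 3]), torch.tensor(3)),
    (torch.randn(8000), torch.tensor(8000), torch.tensor([4, 5]), torch.tensor(2)),
]
audio, audio_len, tokens, tokens_len = _speech_collate_fn(batch, pad_id=0)
# audio.shape == (2, 16000); tokens.shape == (2, 3)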
105 |
+
|
106 |
+
|
107 |
+
class ASRManifestProcessor:
|
108 |
+
"""
|
109 |
+
Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds).
|
110 |
+
Each new line is a different sample. Example below:
|
111 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
|
112 |
+
...
|
113 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
|
114 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
115 |
+
Args:
|
116 |
+
manifest_filepath: Path to manifest json as described above. Can be comma-separated paths.
|
117 |
+
parser: Str for a language specific preprocessor or a callable.
|
118 |
+
max_duration: If audio exceeds this length, do not include in dataset.
|
119 |
+
min_duration: If audio is less than this length, do not include in dataset.
|
120 |
+
max_utts: Limit number of utterances.
|
121 |
+
bos_id: Id of beginning of sequence symbol to append if not None.
|
122 |
+
eos_id: Id of end of sequence symbol to append if not None.
|
123 |
+
pad_id: Id of pad symbol. Defaults to 0.
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
manifest_filepath: str,
|
129 |
+
parser: Union[str, Callable],
|
130 |
+
max_duration: Optional[float] = None,
|
131 |
+
min_duration: Optional[float] = None,
|
132 |
+
max_utts: int = 0,
|
133 |
+
bos_id: Optional[int] = None,
|
134 |
+
eos_id: Optional[int] = None,
|
135 |
+
pad_id: int = 0,
|
136 |
+
index_by_file_id: bool = False,
|
137 |
+
):
|
138 |
+
self.parser = parser
|
139 |
+
|
140 |
+
self.collection = collections.ASRAudioText(
|
141 |
+
manifests_files=manifest_filepath,
|
142 |
+
parser=parser,
|
143 |
+
min_duration=min_duration,
|
144 |
+
max_duration=max_duration,
|
145 |
+
max_number=max_utts,
|
146 |
+
index_by_file_id=index_by_file_id,
|
147 |
+
)
|
148 |
+
|
149 |
+
self.eos_id = eos_id
|
150 |
+
self.bos_id = bos_id
|
151 |
+
self.pad_id = pad_id
|
152 |
+
|
153 |
+
def process_text_by_id(self, index: int) -> Tuple[List[int], int]:
|
154 |
+
sample = self.collection[index]
|
155 |
+
return self.process_text_by_sample(sample)
|
156 |
+
|
157 |
+
def process_text_by_file_id(self, file_id: str) -> Tuple[List[int], int]:
|
158 |
+
manifest_idx = self.collection.mapping[file_id][0]
|
159 |
+
sample = self.collection[manifest_idx]
|
160 |
+
return self.process_text_by_sample(sample)
|
161 |
+
|
162 |
+
def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -> Tuple[List[int], int]:
|
163 |
+
t, tl = sample.text_tokens, len(sample.text_tokens)
|
164 |
+
|
165 |
+
if self.bos_id is not None:
|
166 |
+
t = [self.bos_id] + t
|
167 |
+
tl += 1
|
168 |
+
if self.eos_id is not None:
|
169 |
+
t = t + [self.eos_id]
|
170 |
+
tl += 1
|
171 |
+
|
172 |
+
return t, tl
|
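For reference, a small sketch of producing a manifest in the one-JSON-object-per-line format described above (paths and transcripts are placeholders):

import json

samples = [
    {"audio_filepath": "/path/to/audio1.wav", "text": "hello world", "duration": 1.2},
    {"audio_filepath": "/path/to/audio2.wav", "text": "good morning", "duration": 0.9},
]
with open("train_manifest.json", "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")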
173 |
+
|
174 |
+
|
175 |
+
def expand_sharded_filepaths(sharded_filepaths, shard_strategy: str, world_size: int, global_rank: int):
|
176 |
+
valid_shard_strategies = ['scatter', 'replicate']
|
177 |
+
if shard_strategy not in valid_shard_strategies:
|
178 |
+
raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")
|
179 |
+
|
180 |
+
if isinstance(sharded_filepaths, str):
|
181 |
+
# Normalize alternative opening braces ('(', '[', '<', '_OP_') to '{'
|
182 |
+
brace_keys_open = ['(', '[', '<', '_OP_']
|
183 |
+
for bkey in brace_keys_open:
|
184 |
+
if bkey in sharded_filepaths:
|
185 |
+
sharded_filepaths = sharded_filepaths.replace(bkey, "{")
|
186 |
+
|
187 |
+
# Normalize alternative closing braces (')', ']', '>', '_CL_') to '}'
|
188 |
+
brace_keys_close = [')', ']', '>', '_CL_']
|
189 |
+
for bkey in brace_keys_close:
|
190 |
+
if bkey in sharded_filepaths:
|
191 |
+
sharded_filepaths = sharded_filepaths.replace(bkey, "}")
|
192 |
+
|
193 |
+
if isinstance(sharded_filepaths, str):
|
194 |
+
# Brace expand, set escape=False for Windows compatibility
|
195 |
+
sharded_filepaths = list(braceexpand.braceexpand(sharded_filepaths, escape=False))
|
196 |
+
|
197 |
+
# Expand store paths into WebDataset URLs
|
198 |
+
sharded_filepaths = [
|
199 |
+
datastore_path_to_webdataset_url(p) if is_datastore_path(p) and is_tarred_path(p) else p
|
200 |
+
for p in sharded_filepaths
|
201 |
+
]
|
202 |
+
|
203 |
+
# Check for distributed and partition shards accordingly
|
204 |
+
if world_size > 1:
|
205 |
+
if shard_strategy == 'scatter':
|
206 |
+
logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
|
207 |
+
|
208 |
+
if len(sharded_filepaths) % world_size != 0:
|
209 |
+
logging.warning(
|
210 |
+
f"Number of shards in tarred dataset ({len(sharded_filepaths)}) is not divisible "
|
211 |
+
f"by number of distributed workers ({world_size})."
|
212 |
+
)
|
213 |
+
|
214 |
+
begin_idx = (len(sharded_filepaths) // world_size) * global_rank
|
215 |
+
end_idx = begin_idx + len(sharded_filepaths) // world_size
|
216 |
+
sharded_filepaths = sharded_filepaths[begin_idx:end_idx]
|
217 |
+
logging.info(
|
218 |
+
"Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
|
219 |
+
)
|
220 |
+
|
221 |
+
elif shard_strategy == 'replicate':
|
222 |
+
logging.info("All tarred dataset shards will be replicated across all nodes.")
|
223 |
+
else:
|
224 |
+
raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")
|
225 |
+
|
226 |
+
return sharded_filepaths
|
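A small illustration of the brace normalization above (the tar names are hypothetical): '(' / '[' / '<' / '_OP_' map to '{' and their counterparts map to '}', so all of the following specs expand to the same four shards.

import braceexpand

for spec in ("audio__OP_0..3_CL_.tar", "audio_(0..3).tar", "audio_{0..3}.tar"):
    normalized = spec
    for bkey in ('(', '[', '<', '_OP_'):
        normalized = normalized.replace(bkey, '{')
    for bkey in (')', ']', '>', '_CL_'):
        normalized = normalized.replace(bkey, '}')
    # Each spec expands to ['audio_0.tar', 'audio_1.tar', 'audio_2.tar', 'audio_3.tar']
    print(list(braceexpand.braceexpand(normalized, escape=False)))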
227 |
+
|
228 |
+
|
229 |
+
def cache_datastore_manifests(
|
230 |
+
manifest_filepaths: Union[str, List[str]],
|
231 |
+
cache_audio: bool = False,
|
232 |
+
shared_cache: Optional[bool] = None,
|
233 |
+
num_workers: Optional[int] = None,
|
234 |
+
max_num_workers: int = 20,
|
235 |
+
):
|
236 |
+
"""Cache manifests and audio from an object store.
|
237 |
+
It is assumed that remote manifests are using relative paths.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
manifest_filepaths: list of paths to manifest files (list of strings or a string with `,` as separator)
|
241 |
+
cache_audio: If True, audio from manifest will also be cached
|
242 |
+
shared_cache: Optional, True if cache is shared across all nodes
|
243 |
+
num_workers: Optional, number of workers to be used for download
|
244 |
+
max_num_workers: max number of workers to be used for download, used when setting num_workers automatically
|
245 |
+
"""
|
246 |
+
if isinstance(manifest_filepaths, str):
|
247 |
+
manifest_filepaths = manifest_filepaths.split(',')
|
248 |
+
|
249 |
+
num_datastore_manifests = sum([is_datastore_path(f) for f in manifest_filepaths])
|
250 |
+
|
251 |
+
if num_datastore_manifests > 0:
|
252 |
+
# Local utility function
|
253 |
+
def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
|
254 |
+
"""Cache manifests and audio data from object store.
|
255 |
+
"""
|
256 |
+
# Determine the number of workers to use
|
257 |
+
if num_workers is None:
|
258 |
+
num_workers = os.cpu_count() - 1
|
259 |
+
num_workers = min(num_workers, max_num_workers)
|
260 |
+
|
261 |
+
# Process each manifest file
|
262 |
+
for manifest_file in manifest_filepaths:
|
263 |
+
# If manifest is on a data store, then cache it.
|
264 |
+
# Otherwise, nothing to do.
|
265 |
+
if is_datastore_path(manifest_file):
|
266 |
+
logging.info('Cache manifest file: %s', manifest_file)
|
267 |
+
cached_manifest_file = DataStoreObject(manifest_file).get()
|
268 |
+
logging.info('Cached at: %s', str(cached_manifest_file))
|
269 |
+
|
270 |
+
if cache_audio:
|
271 |
+
# Each audio file from manifest will be cached.
|
272 |
+
logging.info('Cache audio from manifest file: %s', manifest_file)
|
273 |
+
# Assumes that manifest is using relative paths
|
274 |
+
manifest_dir = os.path.dirname(manifest_file)
|
275 |
+
# Prepare all store objects
|
276 |
+
audio_objects = []
|
277 |
+
with open(cached_manifest_file, 'r') as f:
|
278 |
+
for line in f:
|
279 |
+
item = json.loads(line)
|
280 |
+
store_path = os.path.join(manifest_dir, item['audio_filepath'])
|
281 |
+
audio_objects.append(DataStoreObject(store_path=store_path))
|
282 |
+
|
283 |
+
if num_workers is not None and num_workers > 1:
|
284 |
+
logging.debug('Using multiprocessing with num_workers: %d.', num_workers)
|
285 |
+
with multiprocessing.Pool(processes=num_workers) as p:
|
286 |
+
result = list(
|
287 |
+
tqdm(p.imap(datastore_object_get, audio_objects), total=len(audio_objects))
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
logging.debug('Using a single process.')
|
291 |
+
result = []
|
292 |
+
for audio_object in tqdm(audio_objects):
|
293 |
+
result.append(audio_object.get() is not None)
|
294 |
+
|
295 |
+
if not all(result):
|
296 |
+
raise RuntimeError('Some files not downloaded successfully')
|
297 |
+
logging.info('Caching complete')
|
298 |
+
|
299 |
+
else:
|
300 |
+
# Nothing to do here
|
301 |
+
logging.debug('Manifest is not on a data store: %s', manifest_file)
|
302 |
+
|
303 |
+
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
304 |
+
logging.debug('Distributed environment is available and initialized.')
|
305 |
+
|
306 |
+
# Handle distributed environment
|
307 |
+
if shared_cache is None:
|
308 |
+
shared_cache = is_datastore_cache_shared()
|
309 |
+
|
310 |
+
if shared_cache:
|
311 |
+
logging.debug('Cache is shared among nodes, cache data on global rank zero.')
|
312 |
+
is_rank_zero = is_global_rank_zero()
|
313 |
+
else:
|
314 |
+
logging.debug('Cache is not shared among nodes, cache data on local rank zero.')
|
315 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
316 |
+
is_rank_zero = local_rank == 0
|
317 |
+
|
318 |
+
if is_rank_zero:
|
319 |
+
logging.info('Cache data from %s rank 0', 'global' if shared_cache else 'local')
|
320 |
+
cache_data(
|
321 |
+
manifest_filepaths=manifest_filepaths,
|
322 |
+
cache_audio=cache_audio,
|
323 |
+
num_workers=num_workers,
|
324 |
+
max_num_workers=max_num_workers,
|
325 |
+
)
|
326 |
+
logging.debug('Reached barrier')
|
327 |
+
torch.distributed.barrier()
|
328 |
+
|
329 |
+
elif is_global_rank_zero():
|
330 |
+
# Handle non-distributed environment, e.g., if running on a single GPU
|
331 |
+
logging.warning(
|
332 |
+
'Torch distributed is not initialized and caching may be prone to data race conditions. '
|
333 |
+
'Now caching data from global rank 0. If there are other ranks and they pass this '
|
334 |
+
'before rank 0, errors might result.'
|
335 |
+
)
|
336 |
+
cache_data(
|
337 |
+
manifest_filepaths=manifest_filepaths,
|
338 |
+
cache_audio=cache_audio,
|
339 |
+
num_workers=num_workers,
|
340 |
+
max_num_workers=max_num_workers,
|
341 |
+
)
|
342 |
+
else:
|
343 |
+
raise RuntimeError(
|
344 |
+
'Torch distributed is not initialized and caching on nodes other than global rank zero is disabled '
|
345 |
+
'to avoid race condition between different ranks. To ensure distributed environment is '
|
346 |
+
'initialized, please update data config to use `defer_setup = True`.'
|
347 |
+
)
|
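A hedged usage sketch (the remote URL below is a hypothetical object-store path; manifests that are not on a data store are simply skipped):

# Hypothetical remote manifest plus a local one; only the remote entry is cached.
cache_datastore_manifests(
    manifest_filepaths='ais://bucket/train_manifest.json,/local/dev_manifest.json',
    cache_audio=True,   # also cache every audio file referenced by the remote manifest
    num_workers=4,      # parallel downloads
)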
348 |
+
|
349 |
+
|
350 |
+
"""Optionally expand / shard the list of manifests
|
351 |
+
This is made to use the same notation as the sharded audio files
|
352 |
+
|
353 |
+
Args:
|
354 |
+
manifest_filepaths: list of manifest files (the sharded notation)
|
355 |
+
shard_strategy: scatter or replicate (scatter by default)
|
356 |
+
shard_manifests: bool, if False, no sharding / manifest filepath expansion will be attempted
|
357 |
+
global_rank: int, the rank of this worker
|
358 |
+
world_size: int, total number of workers
|
359 |
+
"""
|
360 |
+
|
361 |
+
|
362 |
+
def shard_manifests_if_needed(
|
363 |
+
manifest_filepaths: Union[str, List[str]],
|
364 |
+
shard_strategy: str,
|
365 |
+
shard_manifests: bool,
|
366 |
+
global_rank: int,
|
367 |
+
world_size: int,
|
368 |
+
):
|
369 |
+
if shard_manifests:
|
370 |
+
if not torch.distributed.is_available():
|
371 |
+
logging.warning("Not running in torch.distributed mode. Manifest sharding not available")
|
372 |
+
return manifest_filepaths
|
373 |
+
|
374 |
+
if not torch.distributed.is_initialized():
|
375 |
+
logging.warning(
|
376 |
+
'Manifest sharding was requested but torch.distributed is not initialized. '
|
377 |
+
'Did you intend to set the defer_setup flag?'
|
378 |
+
)
|
379 |
+
return manifest_filepaths
|
380 |
+
|
381 |
+
manifest_filepaths = expand_sharded_filepaths(
|
382 |
+
sharded_filepaths=manifest_filepaths,
|
383 |
+
shard_strategy=shard_strategy,
|
384 |
+
world_size=world_size,
|
385 |
+
global_rank=global_rank,
|
386 |
+
)
|
387 |
+
|
388 |
+
return manifest_filepaths
|
389 |
+
|
390 |
+
|
391 |
+
class _AudioTextDataset(Dataset):
|
392 |
+
"""
|
393 |
+
Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds).
|
394 |
+
Each new line is a different sample. Example below:
|
395 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
|
396 |
+
...
|
397 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
|
398 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
399 |
+
Args:
|
400 |
+
manifest_filepath: Path to manifest json as described above. Can be comma-separated paths.
|
401 |
+
parser: Str for a language specific preprocessor or a callable.
|
402 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
403 |
+
int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
|
404 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor object used to augment loaded
|
405 |
+
audio
|
406 |
+
max_duration: If audio exceeds this length, do not include in dataset
|
407 |
+
min_duration: If audio is less than this length, do not include in dataset
|
408 |
+
max_utts: Limit number of utterances
|
409 |
+
trim: whether or not to trim silence. Defaults to False
|
410 |
+
bos_id: Id of beginning of sequence symbol to append if not None
|
411 |
+
eos_id: Id of end of sequence symbol to append if not None
|
412 |
+
pad_id: Id of pad symbol. Defaults to 0
|
413 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
414 |
+
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
|
415 |
+
"""
|
416 |
+
|
417 |
+
@property
|
418 |
+
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
419 |
+
"""Returns definitions of module output ports.
|
420 |
+
"""
|
421 |
+
return {
|
422 |
+
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
|
423 |
+
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
|
424 |
+
'transcripts': NeuralType(('B', 'T'), LabelsType()),
|
425 |
+
'transcript_length': NeuralType(tuple('B'), LengthsType()),
|
426 |
+
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
|
427 |
+
}
|
428 |
+
|
429 |
+
def __init__(
|
430 |
+
self,
|
431 |
+
manifest_filepath: str,
|
432 |
+
parser: Union[str, Callable],
|
433 |
+
sample_rate: int,
|
434 |
+
int_values: bool = False,
|
435 |
+
augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
|
436 |
+
max_duration: Optional[int] = None,
|
437 |
+
min_duration: Optional[int] = None,
|
438 |
+
max_utts: int = 0,
|
439 |
+
trim: bool = False,
|
440 |
+
bos_id: Optional[int] = None,
|
441 |
+
eos_id: Optional[int] = None,
|
442 |
+
pad_id: int = 0,
|
443 |
+
return_sample_id: bool = False,
|
444 |
+
channel_selector: Optional[ChannelSelectorType] = None,
|
445 |
+
):
|
446 |
+
if isinstance(manifest_filepath, str):
|
447 |
+
manifest_filepath = manifest_filepath.split(",")
|
448 |
+
|
449 |
+
# If necessary, cache manifests and audio from object store
|
450 |
+
cache_datastore_manifests(manifest_filepaths=manifest_filepath, cache_audio=True)
|
451 |
+
|
452 |
+
self.manifest_processor = ASRManifestProcessor(
|
453 |
+
manifest_filepath=manifest_filepath,
|
454 |
+
parser=parser,
|
455 |
+
max_duration=max_duration,
|
456 |
+
min_duration=min_duration,
|
457 |
+
max_utts=max_utts,
|
458 |
+
bos_id=bos_id,
|
459 |
+
eos_id=eos_id,
|
460 |
+
pad_id=pad_id,
|
461 |
+
)
|
462 |
+
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
|
463 |
+
self.trim = trim
|
464 |
+
self.return_sample_id = return_sample_id
|
465 |
+
self.channel_selector = channel_selector
|
466 |
+
|
467 |
+
def get_manifest_sample(self, sample_id):
|
468 |
+
return self.manifest_processor.collection[sample_id]
|
469 |
+
|
470 |
+
def __getitem__(self, index):
|
471 |
+
sample = self.manifest_processor.collection[index]
|
472 |
+
offset = sample.offset
|
473 |
+
|
474 |
+
if offset is None:
|
475 |
+
offset = 0
|
476 |
+
|
477 |
+
features = self.featurizer.process(
|
478 |
+
sample.audio_file,
|
479 |
+
offset=offset,
|
480 |
+
duration=sample.duration,
|
481 |
+
trim=self.trim,
|
482 |
+
orig_sr=sample.orig_sr,
|
483 |
+
channel_selector=self.channel_selector,
|
484 |
+
)
|
485 |
+
f, fl = features, torch.tensor(features.shape[0]).long()
|
486 |
+
|
487 |
+
t, tl = self.manifest_processor.process_text_by_sample(sample=sample)
|
488 |
+
|
489 |
+
if self.return_sample_id:
|
490 |
+
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
|
491 |
+
else:
|
492 |
+
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
493 |
+
|
494 |
+
return output
|
495 |
+
|
496 |
+
def __len__(self):
|
497 |
+
return len(self.manifest_processor.collection)
|
498 |
+
|
499 |
+
def _collate_fn(self, batch):
|
500 |
+
return _speech_collate_fn(batch, pad_id=self.manifest_processor.pad_id)
|
501 |
+
|
502 |
+
|
503 |
+
class AudioToCharDataset(_AudioTextDataset):
|
504 |
+
"""
|
505 |
+
Dataset that loads tensors via a json file containing paths to audio
|
506 |
+
files, transcripts, and durations (in seconds). Each new line is a
|
507 |
+
different sample. Example below:
|
508 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
509 |
+
"/path/to/audio.txt", "duration": 23.147}
|
510 |
+
...
|
511 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
512 |
+
transcription", "offset": 301.75, "duration": 0.82, "utt":
|
513 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
514 |
+
|
515 |
+
Args:
|
516 |
+
manifest_filepath: Path to manifest json as described above. Can
|
517 |
+
be comma-separated paths.
|
518 |
+
labels: String containing all the possible characters to map to
|
519 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
520 |
+
int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
|
521 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
|
522 |
+
object used to augment loaded audio
|
523 |
+
max_duration: If audio exceeds this length, do not include in dataset
|
524 |
+
min_duration: If audio is less than this length, do not include
|
525 |
+
in dataset
|
526 |
+
max_utts: Limit number of utterances
|
527 |
+
blank_index: blank character index, default = -1
|
528 |
+
unk_index: unk_character index, default = -1
|
529 |
+
normalize: whether to normalize transcript text. Defaults to True.
|
530 |
+
bos_id: Id of beginning of sequence symbol to append if not None
|
531 |
+
eos_id: Id of end of sequence symbol to append if not None
|
532 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
533 |
+
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
|
534 |
+
"""
|
535 |
+
|
536 |
+
@property
|
537 |
+
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
538 |
+
"""Returns definitions of module output ports.
|
539 |
+
"""
|
540 |
+
return {
|
541 |
+
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
|
542 |
+
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
|
543 |
+
'transcripts': NeuralType(('B', 'T'), LabelsType()),
|
544 |
+
'transcript_length': NeuralType(tuple('B'), LengthsType()),
|
545 |
+
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
|
546 |
+
}
|
547 |
+
|
548 |
+
def __init__(
|
549 |
+
self,
|
550 |
+
manifest_filepath: str,
|
551 |
+
labels: Union[str, List[str]],
|
552 |
+
sample_rate: int,
|
553 |
+
int_values: bool = False,
|
554 |
+
augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
|
555 |
+
max_duration: Optional[float] = None,
|
556 |
+
min_duration: Optional[float] = None,
|
557 |
+
max_utts: int = 0,
|
558 |
+
blank_index: int = -1,
|
559 |
+
unk_index: int = -1,
|
560 |
+
normalize: bool = True,
|
561 |
+
trim: bool = False,
|
562 |
+
bos_id: Optional[int] = None,
|
563 |
+
eos_id: Optional[int] = None,
|
564 |
+
pad_id: int = 0,
|
565 |
+
parser: Union[str, Callable] = 'en',
|
566 |
+
return_sample_id: bool = False,
|
567 |
+
channel_selector: Optional[ChannelSelectorType] = None,
|
568 |
+
):
|
569 |
+
self.labels = labels
|
570 |
+
|
571 |
+
parser = parsers.make_parser(
|
572 |
+
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
|
573 |
+
)
|
574 |
+
|
575 |
+
super().__init__(
|
576 |
+
manifest_filepath=manifest_filepath,
|
577 |
+
parser=parser,
|
578 |
+
sample_rate=sample_rate,
|
579 |
+
int_values=int_values,
|
580 |
+
augmentor=augmentor,
|
581 |
+
max_duration=max_duration,
|
582 |
+
min_duration=min_duration,
|
583 |
+
max_utts=max_utts,
|
584 |
+
trim=trim,
|
585 |
+
bos_id=bos_id,
|
586 |
+
eos_id=eos_id,
|
587 |
+
pad_id=pad_id,
|
588 |
+
return_sample_id=return_sample_id,
|
589 |
+
channel_selector=channel_selector,
|
590 |
+
)
|
591 |
+
|
592 |
+
|
593 |
+
class AudioToBPEDataset(_AudioTextDataset):
|
594 |
+
"""
|
595 |
+
Dataset that loads tensors via a json file containing paths to audio
|
596 |
+
files, transcripts, and durations (in seconds). Each new line is a
|
597 |
+
different sample. Example below:
|
598 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
599 |
+
"/path/to/audio.txt", "duration": 23.147}
|
600 |
+
...
|
601 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
602 |
+
transcription", "offset": 301.75, "duration": 0.82, "utt":
|
603 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
604 |
+
|
605 |
+
In practice, the dataset and manifest used for character encoding and byte pair encoding
|
606 |
+
are exactly the same. The only difference lies in how the dataset tokenizes the text in
|
607 |
+
the manifest.
|
608 |
+
|
609 |
+
Args:
|
610 |
+
manifest_filepath: Path to manifest json as described above. Can
|
611 |
+
be comma-separated paths.
|
612 |
+
tokenizer: A subclass of the Tokenizer wrapper found in the common collection,
|
613 |
+
nemo.collections.common.tokenizers.TokenizerSpec. ASR Models support a subset of
|
614 |
+
all available tokenizers.
|
615 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
616 |
+
int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
|
617 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
|
618 |
+
object used to augment loaded audio
|
619 |
+
max_duration: If audio exceeds this length, do not include in dataset
|
620 |
+
min_duration: If audio is less than this length, do not include
|
621 |
+
in dataset
|
622 |
+
max_utts: Limit number of utterances
|
623 |
+
trim: Whether to trim silence segments
|
624 |
+
use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
|
625 |
+
tokens to beginning and ending of speech respectively.
|
626 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
627 |
+
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
|
628 |
+
"""
|
629 |
+
|
630 |
+
@property
|
631 |
+
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
632 |
+
"""Returns definitions of module output ports.
|
633 |
+
"""
|
634 |
+
return {
|
635 |
+
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
|
636 |
+
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
|
637 |
+
'transcripts': NeuralType(('B', 'T'), LabelsType()),
|
638 |
+
'transcript_length': NeuralType(tuple('B'), LengthsType()),
|
639 |
+
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
|
640 |
+
}
|
641 |
+
|
642 |
+
def __init__(
|
643 |
+
self,
|
644 |
+
manifest_filepath: str,
|
645 |
+
tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
|
646 |
+
sample_rate: int,
|
647 |
+
int_values: bool = False,
|
648 |
+
augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
|
649 |
+
max_duration: Optional[int] = None,
|
650 |
+
min_duration: Optional[int] = None,
|
651 |
+
max_utts: int = 0,
|
652 |
+
trim: bool = False,
|
653 |
+
use_start_end_token: bool = True,
|
654 |
+
return_sample_id: bool = False,
|
655 |
+
channel_selector: Optional[ChannelSelectorType] = None,
|
656 |
+
):
|
657 |
+
if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
|
658 |
+
bos_id = tokenizer.bos_id
|
659 |
+
else:
|
660 |
+
bos_id = None
|
661 |
+
|
662 |
+
if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0:
|
663 |
+
eos_id = tokenizer.eos_id
|
664 |
+
else:
|
665 |
+
eos_id = None
|
666 |
+
|
667 |
+
if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0:
|
668 |
+
pad_id = tokenizer.pad_id
|
669 |
+
else:
|
670 |
+
pad_id = 0
|
671 |
+
|
672 |
+
class TokenizerWrapper:
|
673 |
+
def __init__(self, tokenizer):
|
674 |
+
if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer):
|
675 |
+
self.is_aggregate = True
|
676 |
+
else:
|
677 |
+
self.is_aggregate = False
|
678 |
+
self._tokenizer = tokenizer
|
679 |
+
|
680 |
+
def __call__(self, *args):
|
681 |
+
if isinstance(args[0], List) and self.is_aggregate:
|
682 |
+
t = []
|
683 |
+
for span in args[0]:
|
684 |
+
t.extend(self._tokenizer.text_to_ids(span['str'], span['lang']))
|
685 |
+
return t
|
686 |
+
|
687 |
+
t = self._tokenizer.text_to_ids(*args)
|
688 |
+
return t
|
689 |
+
|
690 |
+
super().__init__(
|
691 |
+
manifest_filepath=manifest_filepath,
|
692 |
+
parser=TokenizerWrapper(tokenizer),
|
693 |
+
sample_rate=sample_rate,
|
694 |
+
int_values=int_values,
|
695 |
+
augmentor=augmentor,
|
696 |
+
max_duration=max_duration,
|
697 |
+
min_duration=min_duration,
|
698 |
+
max_utts=max_utts,
|
699 |
+
bos_id=bos_id,
|
700 |
+
eos_id=eos_id,
|
701 |
+
pad_id=pad_id,
|
702 |
+
trim=trim,
|
703 |
+
return_sample_id=return_sample_id,
|
704 |
+
channel_selector=channel_selector,
|
705 |
+
)
|
706 |
+
|
707 |
+
|
708 |
+
class _TarredAudioToTextDataset(IterableDataset):
|
709 |
+
"""
|
710 |
+
A similar Dataset to the AudioToCharDataset/AudioToBPEDataset, but which loads tarred audio files.
|
711 |
+
|
712 |
+
Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset),
|
713 |
+
as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
|
714 |
+
contain the information for one audio file, including at least the transcript and name of the audio
|
715 |
+
file within the tarball.
|
716 |
+
|
717 |
+
Valid formats for the audio_tar_filepaths argument include:
|
718 |
+
(1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
|
719 |
+
(2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
|
720 |
+
|
721 |
+
Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference.
|
722 |
+
This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements.
|
723 |
+
Supported opening braces - { <=> (, [, < and the special tag _OP_.
|
724 |
+
Supported closing braces - } <=> ), ], > and the special tag _CL_.
|
725 |
+
For SLURM based tasks, we suggest the use of the special tags for ease of use.
|
726 |
+
|
727 |
+
See the WebDataset documentation for more information about accepted data and input formats.
|
728 |
+
|
729 |
+
If using multiple workers the number of shards should be divisible by world_size to ensure an
|
730 |
+
even split among workers. If it is not divisible, logging will give a warning but training will proceed.
|
731 |
+
In addition, if using mutiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
|
732 |
+
is applied. We currently do not check for this, but your program may hang if the shards are uneven!
|
733 |
+
|
734 |
+
Notice that a few arguments are different from the AudioToCharDataset; for example, shuffle (bool) has been
|
735 |
+
replaced by shuffle_n (int).
|
736 |
+
|
737 |
+
Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
|
738 |
+
after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
|
739 |
+
|
740 |
+
Args:
|
741 |
+
audio_tar_filepaths: Either a list of audio tarball filepaths, or a
|
742 |
+
string (can be brace-expandable).
|
743 |
+
manifest_filepath (str): Path to the manifest.
|
744 |
+
parser (callable): A callable which is used to pre-process the text output.
|
745 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
746 |
+
int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
|
747 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
|
748 |
+
object used to augment loaded audio
|
749 |
+
shuffle_n (int): How many samples to look ahead and load to be shuffled.
|
750 |
+
See WebDataset documentation for more details.
|
751 |
+
Defaults to 0.
|
752 |
+
min_duration (float): Dataset parameter.
|
753 |
+
All training files which have a duration less than min_duration
|
754 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
755 |
+
Defaults to 0.1.
|
756 |
+
max_duration (float): Dataset parameter.
|
757 |
+
All training files which have a duration more than max_duration
|
758 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
759 |
+
Defaults to None.
|
760 |
+
blank_index (int): Blank character index, defaults to -1.
|
761 |
+
unk_index (int): Unknown character index, defaults to -1.
|
762 |
+
normalize (bool): Dataset parameter.
|
763 |
+
Whether to use automatic text cleaning.
|
764 |
+
It is highly recommended to manually clean text for best results.
|
765 |
+
Defaults to True.
|
766 |
+
trim (bool): Whether to use trim silence from beginning and end
|
767 |
+
of audio signal using librosa.effects.trim().
|
768 |
+
Defaults to False.
|
769 |
+
bos_id (id): Dataset parameter.
|
770 |
+
Beginning of string symbol id used for seq2seq models.
|
771 |
+
Defaults to None.
|
772 |
+
eos_id (id): Dataset parameter.
|
773 |
+
End of string symbol id used for seq2seq models.
|
774 |
+
Defaults to None.
|
775 |
+
pad_id (id): Token used to pad when collating samples in batches.
|
776 |
+
If this is None, pads using 0s.
|
777 |
+
Defaults to None.
|
778 |
+
shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
|
779 |
+
- `scatter`: The default shard strategy applied by WebDataset, where each node gets
|
780 |
+
a unique set of shards, which are permanently pre-allocated and never changed at runtime.
|
781 |
+
- `replicate`: Optional shard strategy, where each node gets all of the set of shards
|
782 |
+
available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
|
783 |
+
The benefit of replication is that it allows each node to sample data points from the entire
|
784 |
+
dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
|
785 |
+
|
786 |
+
.. warning::
|
787 |
+
Replicated strategy allows every node to sample the entire set of available tarfiles,
|
788 |
+
and therefore more than one node may sample the same tarfile, and even sample the same
|
789 |
+
data points! As such, there is no assured guarantee that all samples in the dataset will be
|
790 |
+
sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
|
791 |
+
occasions (when the number of shards is not divisible with ``world_size``), will not sample
|
792 |
+
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
|
793 |
+
or test datasets.
|
794 |
+
shard_manifests (bool): Whether or not to try / shard manifests. Defaults to False.
|
795 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
796 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
797 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
798 |
+
"""
|
799 |
+
|
800 |
+
def __init__(
|
801 |
+
self,
|
802 |
+
audio_tar_filepaths: Union[str, List[str]],
|
803 |
+
manifest_filepath: str,
|
804 |
+
parser: Callable,
|
805 |
+
sample_rate: int,
|
806 |
+
int_values: bool = False,
|
807 |
+
augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
|
808 |
+
shuffle_n: int = 0,
|
809 |
+
min_duration: Optional[float] = None,
|
810 |
+
max_duration: Optional[float] = None,
|
811 |
+
trim: bool = False,
|
812 |
+
bos_id: Optional[int] = None,
|
813 |
+
eos_id: Optional[int] = None,
|
814 |
+
pad_id: int = 0,
|
815 |
+
shard_strategy: str = "scatter",
|
816 |
+
shard_manifests: bool = False,
|
817 |
+
global_rank: int = 0,
|
818 |
+
world_size: int = 0,
|
819 |
+
return_sample_id: bool = False,
|
820 |
+
):
|
821 |
+
self.shard_manifests = shard_manifests
|
822 |
+
|
823 |
+
# Shard manifests if necessary and possible and then expand the paths
|
824 |
+
manifest_filepath = shard_manifests_if_needed(
|
825 |
+
shard_manifests=shard_manifests,
|
826 |
+
shard_strategy=shard_strategy,
|
827 |
+
manifest_filepaths=manifest_filepath,
|
828 |
+
world_size=world_size,
|
829 |
+
global_rank=global_rank,
|
830 |
+
)
|
831 |
+
|
832 |
+
# If necessary, cache manifests from object store
|
833 |
+
cache_datastore_manifests(manifest_filepaths=manifest_filepath)
|
834 |
+
|
835 |
+
self.manifest_processor = ASRManifestProcessor(
|
836 |
+
manifest_filepath=manifest_filepath,
|
837 |
+
parser=parser,
|
838 |
+
max_duration=max_duration,
|
839 |
+
min_duration=min_duration,
|
840 |
+
max_utts=0,
|
841 |
+
bos_id=bos_id,
|
842 |
+
eos_id=eos_id,
|
843 |
+
pad_id=pad_id,
|
844 |
+
index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID
|
845 |
+
)
|
846 |
+
|
847 |
+
self.len = self._compute_len()
|
848 |
+
|
849 |
+
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
|
850 |
+
self.trim = trim
|
851 |
+
self.eos_id = eos_id
|
852 |
+
self.bos_id = bos_id
|
853 |
+
self.pad_id = pad_id
|
854 |
+
self.return_sample_id = return_sample_id
|
855 |
+
|
856 |
+
audio_tar_filepaths = expand_sharded_filepaths(
|
857 |
+
sharded_filepaths=audio_tar_filepaths,
|
858 |
+
shard_strategy=shard_strategy,
|
859 |
+
world_size=world_size,
|
860 |
+
global_rank=global_rank,
|
861 |
+
)
|
862 |
+
|
863 |
+
# Put together WebDataset
|
864 |
+
self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
|
865 |
+
|
866 |
+
if shuffle_n > 0:
|
867 |
+
self._dataset = self._dataset.shuffle(shuffle_n)
|
868 |
+
else:
|
869 |
+
logging.info("WebDataset will not shuffle files within the tar files.")
|
870 |
+
|
871 |
+
self._dataset = (
|
872 |
+
self._dataset.rename(audio='wav;ogg;flac', key='__key__')
|
873 |
+
.to_tuple('audio', 'key')
|
874 |
+
.pipe(self._filter)
|
875 |
+
.pipe(self._loop_offsets)
|
876 |
+
.map(f=self._build_sample)
|
877 |
+
)
|
878 |
+
|
879 |
+
def _filter(self, iterator):
|
880 |
+
"""This function is used to remove samples that have been filtered out by ASRAudioText already.
|
881 |
+
Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample
|
882 |
+
that was filtered out (e.g. for duration).
|
883 |
+
Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard,
|
884 |
+
which may make your code hang as one process will finish before the other.
|
885 |
+
"""
|
886 |
+
|
887 |
+
class TarredAudioFilter:
|
888 |
+
def __init__(self, collection):
|
889 |
+
self.iterator = iterator
|
890 |
+
self.collection = collection
|
891 |
+
|
892 |
+
def __iter__(self):
|
893 |
+
return self
|
894 |
+
|
895 |
+
def __next__(self):
|
896 |
+
while True:
|
897 |
+
audio_bytes, audio_filename = next(self.iterator)
|
898 |
+
file_id, _ = os.path.splitext(os.path.basename(audio_filename))
|
899 |
+
if file_id in self.collection.mapping:
|
900 |
+
return audio_bytes, audio_filename
|
901 |
+
|
902 |
+
return TarredAudioFilter(self.manifest_processor.collection)
|
903 |
+
|
904 |
+
def _loop_offsets(self, iterator):
|
905 |
+
"""This function is used to iterate through utterances with different offsets for each file.
|
906 |
+
"""
|
907 |
+
|
908 |
+
class TarredAudioLoopOffsets:
|
909 |
+
def __init__(self, collection):
|
910 |
+
self.iterator = iterator
|
911 |
+
self.collection = collection
|
912 |
+
self.current_fn = None
|
913 |
+
self.current_bytes = None
|
914 |
+
self.offset_id = 0
|
915 |
+
|
916 |
+
def __iter__(self):
|
917 |
+
return self
|
918 |
+
|
919 |
+
def __next__(self):
|
920 |
+
if self.current_fn is None:
|
921 |
+
self.current_bytes, self.current_fn = next(self.iterator)
|
922 |
+
self.offset_id = 0
|
923 |
+
else:
|
924 |
+
offset_list = self.collection.mapping[self.current_fn]
|
925 |
+
if len(offset_list) == self.offset_id + 1:
|
926 |
+
self.current_bytes, self.current_fn = next(self.iterator)
|
927 |
+
self.offset_id = 0
|
928 |
+
else:
|
929 |
+
self.offset_id += 1
|
930 |
+
|
931 |
+
return self.current_bytes, self.current_fn, self.offset_id
|
932 |
+
|
933 |
+
return TarredAudioLoopOffsets(self.manifest_processor.collection)
|
934 |
+
|
935 |
+
def _collate_fn(self, batch):
|
936 |
+
return _speech_collate_fn(batch, self.pad_id)
|
937 |
+
|
938 |
+
def _build_sample(self, tup):
|
939 |
+
"""Builds the training sample by combining the data from the WebDataset with the manifest info.
|
940 |
+
"""
|
941 |
+
audio_bytes, audio_filename, offset_id = tup
|
942 |
+
|
943 |
+
# Grab the manifest entry from self.manifest_processor.collection
|
944 |
+
file_id, _ = os.path.splitext(os.path.basename(audio_filename))
|
945 |
+
manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id]
|
946 |
+
manifest_entry = self.manifest_processor.collection[manifest_idx]
|
947 |
+
|
948 |
+
offset = manifest_entry.offset
|
949 |
+
if offset is None:
|
950 |
+
offset = 0
|
951 |
+
|
952 |
+
# Convert audio bytes to IO stream for processing (for SoundFile to read)
|
953 |
+
audio_filestream = io.BytesIO(audio_bytes)
|
954 |
+
features = self.featurizer.process(
|
955 |
+
audio_filestream,
|
956 |
+
offset=offset,
|
957 |
+
duration=manifest_entry.duration,
|
958 |
+
trim=self.trim,
|
959 |
+
orig_sr=manifest_entry.orig_sr,
|
960 |
+
)
|
961 |
+
audio_filestream.close()
|
962 |
+
|
963 |
+
# Audio features
|
964 |
+
f, fl = features, torch.tensor(features.shape[0]).long()
|
965 |
+
|
966 |
+
# Text features
|
967 |
+
t, tl = manifest_entry.text_tokens, len(manifest_entry.text_tokens)
|
968 |
+
|
969 |
+
self.manifest_processor.process_text_by_sample(sample=manifest_entry)
|
970 |
+
|
971 |
+
if self.bos_id is not None:
|
972 |
+
t = [self.bos_id] + t
|
973 |
+
tl += 1
|
974 |
+
if self.eos_id is not None:
|
975 |
+
t = t + [self.eos_id]
|
976 |
+
tl += 1
|
977 |
+
|
978 |
+
if self.return_sample_id:
|
979 |
+
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), manifest_idx
|
980 |
+
else:
|
981 |
+
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
982 |
+
|
983 |
+
def get_manifest_sample(self, sample_id):
|
984 |
+
return self.manifest_processor.collection[sample_id]
|
985 |
+
|
986 |
+
def __iter__(self):
|
987 |
+
return self._dataset.__iter__()
|
988 |
+
|
989 |
+
def _compute_len(self):
|
990 |
+
if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized():
|
991 |
+
my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda()
|
992 |
+
torch.distributed.all_reduce(my_len)
|
993 |
+
my_len = my_len.int()
|
994 |
+
logging.info(f'Sharded manifests: Total length: {my_len}')
|
995 |
+
else:
|
996 |
+
my_len = len(self.manifest_processor.collection)
|
997 |
+
|
998 |
+
return my_len
|
999 |
+
|
1000 |
+
def __len__(self):
|
1001 |
+
return self.len
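# --- Illustrative sketch (assumption: not the actual NeMo helper) ---
# A conceptual illustration of the two shard strategies described in the docstrings:
# "scatter" hands each rank a disjoint slice of the shard list, while "replicate"
# hands every rank the full list. The real logic lives in expand_sharded_filepaths.
def select_shards(shards, shard_strategy, global_rank, world_size):
    if shard_strategy == "replicate":
        return list(shards)
    if shard_strategy == "scatter":
        # When len(shards) is not divisible by world_size, trailing shards are dropped,
        # which is why an even split is recommended in the docstrings above.
        per_rank = len(shards) // world_size
        return shards[per_rank * global_rank : per_rank * (global_rank + 1)]
    raise ValueError(f"Unknown shard strategy: {shard_strategy}")

# Example: 4 shards scattered over 2 ranks -> rank 0 sees ['audio_0.tar', 'audio_1.tar'].
# select_shards([f"audio_{i}.tar" for i in range(4)], "scatter", global_rank=0, world_size=2)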
|
1002 |
+
|
1003 |
+
|
1004 |
+
class TarredAudioToCharDataset(_TarredAudioToTextDataset):
|
1005 |
+
"""
|
1006 |
+
A similar Dataset to the AudioToCharDataset, but which loads tarred audio files.
|
1007 |
+
|
1008 |
+
Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset),
|
1009 |
+
as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
|
1010 |
+
contain the information for one audio file, including at least the transcript and name of the audio
|
1011 |
+
file within the tarball.
|
1012 |
+
|
1013 |
+
Valid formats for the audio_tar_filepaths argument include:
|
1014 |
+
(1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
|
1015 |
+
(2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
|
1016 |
+
|
1017 |
+
See the WebDataset documentation for more information about accepted data and input formats.
|
1018 |
+
|
1019 |
+
If using multiple workers, the number of shards should be divisible by world_size to ensure an
|
1020 |
+
even split among workers. If it is not divisible, logging will give a warning but training will proceed.
|
1021 |
+
In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
|
1022 |
+
is applied. We currently do not check for this, but your program may hang if the shards are uneven!
|
1023 |
+
|
1024 |
+
Notice that a few arguments are different from the AudioToCharDataset; for example, shuffle (bool) has been
|
1025 |
+
replaced by shuffle_n (int).
|
1026 |
+
|
1027 |
+
Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
|
1028 |
+
after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
|
1029 |
+
|
1030 |
+
Args:
|
1031 |
+
audio_tar_filepaths: Either a list of audio tarball filepaths, or a
|
1032 |
+
string (can be brace-expandable).
|
1033 |
+
manifest_filepath (str): Path to the manifest.
|
1034 |
+
labels (list): List of characters that can be output by the ASR model.
|
1035 |
+
For Jasper, this is the 28 character set {a-z '}. The CTC blank
|
1036 |
+
symbol is automatically added later for models using ctc.
|
1037 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
1038 |
+
int_values (bool): If True, load samples as 32-bit integers. Defaults to False.
|
1039 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
|
1040 |
+
object used to augment loaded audio
|
1041 |
+
shuffle_n (int): How many samples to look ahead and load to be shuffled.
|
1042 |
+
See WebDataset documentation for more details.
|
1043 |
+
Defaults to 0.
|
1044 |
+
min_duration (float): Dataset parameter.
|
1045 |
+
All training files which have a duration less than min_duration
|
1046 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
1047 |
+
Defaults to 0.1.
|
1048 |
+
max_duration (float): Dataset parameter.
|
1049 |
+
All training files which have a duration more than max_duration
|
1050 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
1051 |
+
Defaults to None.
|
1052 |
+
blank_index (int): Blank character index, defaults to -1.
|
1053 |
+
unk_index (int): Unknown character index, defaults to -1.
|
1054 |
+
normalize (bool): Dataset parameter.
|
1055 |
+
Whether to use automatic text cleaning.
|
1056 |
+
It is highly recommended to manually clean text for best results.
|
1057 |
+
Defaults to True.
|
1058 |
+
trim (bool): Whether to trim silence from the beginning and end
|
1059 |
+
of the audio signal using librosa.effects.trim().
|
1060 |
+
Defaults to False.
|
1061 |
+
bos_id (id): Dataset parameter.
|
1062 |
+
Beginning of string symbol id used for seq2seq models.
|
1063 |
+
Defaults to None.
|
1064 |
+
eos_id (id): Dataset parameter.
|
1065 |
+
End of string symbol id used for seq2seq models.
|
1066 |
+
Defaults to None.
|
1067 |
+
pad_id (id): Token used to pad when collating samples in batches.
|
1068 |
+
If this is None, pads using 0s.
|
1069 |
+
Defaults to None.
|
1070 |
+
shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
|
1071 |
+
|
1072 |
+
- `scatter`: The default shard strategy applied by WebDataset, where each node gets
|
1073 |
+
a unique set of shards, which are permanently pre-allocated and never changed at runtime.
|
1074 |
+
- `replicate`: Optional shard strategy, where each node gets all of the set of shards
|
1075 |
+
available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
|
1076 |
+
The benefit of replication is that it allows each node to sample data points from the entire
|
1077 |
+
dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
|
1078 |
+
|
1079 |
+
.. warning::
|
1080 |
+
|
1081 |
+
Replicated strategy allows every node to sample the entire set of available tarfiles,
|
1082 |
+
and therefore more than one node may sample the same tarfile, and even sample the same
|
1083 |
+
data points! As such, there is no guarantee that all samples in the dataset will be
|
1084 |
+
sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
|
1085 |
+
occasions (when the number of shards is not divisible by ``world_size``), will not sample
|
1086 |
+
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
|
1087 |
+
or test datasets.
|
1088 |
+
|
1089 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
1090 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
1091 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
1092 |
+
"""
|
1093 |
+
|
1094 |
+
def __init__(
|
1095 |
+
self,
|
1096 |
+
audio_tar_filepaths: Union[str, List[str]],
|
1097 |
+
manifest_filepath: str,
|
1098 |
+
labels: List[str],
|
1099 |
+
sample_rate: int,
|
1100 |
+
int_values: bool = False,
|
1101 |
+
augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
|
1102 |
+
shuffle_n: int = 0,
|
1103 |
+
min_duration: Optional[float] = None,
|
1104 |
+
max_duration: Optional[float] = None,
|
1105 |
+
blank_index: int = -1,
|
1106 |
+
unk_index: int = -1,
|
1107 |
+
normalize: bool = True,
|
1108 |
+
trim: bool = False,
|
1109 |
+
bos_id: Optional[int] = None,
|
1110 |
+
eos_id: Optional[int] = None,
|
1111 |
+
parser: Optional[str] = 'en',
|
1112 |
+
pad_id: int = 0,
|
1113 |
+
shard_strategy: str = "scatter",
|
1114 |
+
shard_manifests: bool = False,
|
1115 |
+
global_rank: int = 0,
|
1116 |
+
world_size: int = 0,
|
1117 |
+
return_sample_id: bool = False,
|
1118 |
+
):
|
1119 |
+
self.labels = labels
|
1120 |
+
|
1121 |
+
parser = parsers.make_parser(
|
1122 |
+
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
|
1123 |
+
)
|
1124 |
+
|
1125 |
+
super().__init__(
|
1126 |
+
audio_tar_filepaths=audio_tar_filepaths,
|
1127 |
+
manifest_filepath=manifest_filepath,
|
1128 |
+
parser=parser,
|
1129 |
+
sample_rate=sample_rate,
|
1130 |
+
int_values=int_values,
|
1131 |
+
augmentor=augmentor,
|
1132 |
+
shuffle_n=shuffle_n,
|
1133 |
+
min_duration=min_duration,
|
1134 |
+
max_duration=max_duration,
|
1135 |
+
trim=trim,
|
1136 |
+
bos_id=bos_id,
|
1137 |
+
eos_id=eos_id,
|
1138 |
+
pad_id=pad_id,
|
1139 |
+
shard_strategy=shard_strategy,
|
1140 |
+
shard_manifests=shard_manifests,
|
1141 |
+
global_rank=global_rank,
|
1142 |
+
world_size=world_size,
|
1143 |
+
return_sample_id=return_sample_id,
|
1144 |
+
)
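# --- Illustrative usage (assumption: hypothetical paths and a placeholder label set) ---
# A minimal sketch of constructing the character-level tarred dataset defined above and
# wrapping it in a PyTorch DataLoader. The loader must not shuffle; sample shuffling is
# handled inside the dataset via shuffle_n.
from torch.utils.data import DataLoader

example_labels = [" ", "a", "b", "c"]  # placeholder character set
example_dataset = TarredAudioToCharDataset(
    audio_tar_filepaths="path/to/audio_{0..3}.tar",   # brace-expandable string form
    manifest_filepath="path/to/tarred_manifest.json",
    labels=example_labels,
    sample_rate=16000,
    shuffle_n=128,
)
example_loader = DataLoader(example_dataset, batch_size=4, collate_fn=example_dataset._collate_fn)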
|
1145 |
+
|
1146 |
+
|
1147 |
+
class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
|
1148 |
+
"""
|
1149 |
+
A similar Dataset to the AudioToBPEDataset, but which loads tarred audio files.
|
1150 |
+
|
1151 |
+
Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToBPEDataset),
|
1152 |
+
as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
|
1153 |
+
contain the information for one audio file, including at least the transcript and name of the audio
|
1154 |
+
file within the tarball.
|
1155 |
+
|
1156 |
+
Valid formats for the audio_tar_filepaths argument include:
|
1157 |
+
(1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
|
1158 |
+
(2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
|
1159 |
+
|
1160 |
+
See the WebDataset documentation for more information about accepted data and input formats.
|
1161 |
+
|
1162 |
+
If using multiple workers, the number of shards should be divisible by world_size to ensure an
|
1163 |
+
even split among workers. If it is not divisible, logging will give a warning but training will proceed.
|
1164 |
+
In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
|
1165 |
+
is applied. We currently do not check for this, but your program may hang if the shards are uneven!
|
1166 |
+
|
1167 |
+
Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been
|
1168 |
+
replaced by shuffle_n (int).
|
1169 |
+
|
1170 |
+
Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
|
1171 |
+
after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
|
1172 |
+
|
1173 |
+
Args:
|
1174 |
+
audio_tar_filepaths: Either a list of audio tarball filepaths, or a
|
1175 |
+
string (can be brace-expandable).
|
1176 |
+
manifest_filepath (str): Path to the manifest.
|
1177 |
+
tokenizer (TokenizerSpec): Either a Word Piece Encoding tokenizer (BERT),
|
1178 |
+
or a Sentence Piece Encoding tokenizer (BPE). The CTC blank
|
1179 |
+
symbol is automatically added later for models using ctc.
|
1180 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
1181 |
+
int_values (bool): If True, load samples as 32-bit integers. Defaults to False.
|
1182 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
|
1183 |
+
object used to augment loaded audio
|
1184 |
+
shuffle_n (int): How many samples to look ahead and load to be shuffled.
|
1185 |
+
See WebDataset documentation for more details.
|
1186 |
+
Defaults to 0.
|
1187 |
+
min_duration (float): Dataset parameter.
|
1188 |
+
All training files which have a duration less than min_duration
|
1189 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
1190 |
+
Defaults to 0.1.
|
1191 |
+
max_duration (float): Dataset parameter.
|
1192 |
+
All training files which have a duration more than max_duration
|
1193 |
+
are dropped. Note: Duration is read from the manifest JSON.
|
1194 |
+
Defaults to None.
|
1195 |
+
trim (bool): Whether to trim silence from the beginning and end
|
1196 |
+
of the audio signal using librosa.effects.trim().
|
1197 |
+
Defaults to False.
|
1198 |
+
use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
|
1199 |
+
tokens to beginning and ending of speech respectively.
|
1200 |
+
pad_id (id): Token used to pad when collating samples in batches.
|
1201 |
+
If this is None, pads using 0s.
|
1202 |
+
Defaults to None.
|
1203 |
+
shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
|
1204 |
+
|
1205 |
+
- `scatter`: The default shard strategy applied by WebDataset, where each node gets
|
1206 |
+
a unique set of shards, which are permanently pre-allocated and never changed at runtime.
|
1207 |
+
- `replicate`: Optional shard strategy, where each node gets all of the set of shards
|
1208 |
+
available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
|
1209 |
+
The benefit of replication is that it allows each node to sample data points from the entire
|
1210 |
+
dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
|
1211 |
+
|
1212 |
+
.. warning::
|
1213 |
+
|
1214 |
+
Replicated strategy allows every node to sample the entire set of available tarfiles,
|
1215 |
+
and therefore more than one node may sample the same tarfile, and even sample the same
|
1216 |
+
data points! As such, there is no guarantee that all samples in the dataset will be
|
1217 |
+
sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
|
1218 |
+
occasions (when the number of shards is not divisible by ``world_size``), will not sample
|
1219 |
+
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
|
1220 |
+
or test datasets.
|
1221 |
+
|
1222 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
1223 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
1224 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
1225 |
+
"""
|
1226 |
+
|
1227 |
+
def __init__(
|
1228 |
+
self,
|
1229 |
+
audio_tar_filepaths: Union[str, List[str]],
|
1230 |
+
manifest_filepath: str,
|
1231 |
+
tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
|
1232 |
+
sample_rate: int,
|
1233 |
+
int_values: bool = False,
|
1234 |
+
augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
|
1235 |
+
shuffle_n: int = 0,
|
1236 |
+
min_duration: Optional[float] = None,
|
1237 |
+
max_duration: Optional[float] = None,
|
1238 |
+
trim: bool = False,
|
1239 |
+
use_start_end_token: bool = True,
|
1240 |
+
shard_strategy: str = "scatter",
|
1241 |
+
shard_manifests: bool = False,
|
1242 |
+
global_rank: int = 0,
|
1243 |
+
world_size: int = 0,
|
1244 |
+
return_sample_id: bool = False,
|
1245 |
+
):
|
1246 |
+
if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
|
1247 |
+
bos_id = tokenizer.bos_id
|
1248 |
+
else:
|
1249 |
+
bos_id = None
|
1250 |
+
|
1251 |
+
if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0:
|
1252 |
+
eos_id = tokenizer.eos_id
|
1253 |
+
else:
|
1254 |
+
eos_id = None
|
1255 |
+
|
1256 |
+
if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0:
|
1257 |
+
pad_id = tokenizer.pad_id
|
1258 |
+
else:
|
1259 |
+
pad_id = 0
|
1260 |
+
|
1261 |
+
class TokenizerWrapper:
|
1262 |
+
def __init__(self, tokenizer):
|
1263 |
+
if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer):
|
1264 |
+
self.is_aggregate = True
|
1265 |
+
else:
|
1266 |
+
self.is_aggregate = False
|
1267 |
+
self._tokenizer = tokenizer
|
1268 |
+
|
1269 |
+
def __call__(self, *args):
|
1270 |
+
if isinstance(args[0], List) and self.is_aggregate:
|
1271 |
+
t = []
|
1272 |
+
for span in args[0]:
|
1273 |
+
t.extend(self._tokenizer.text_to_ids(span['str'], span['lang']))
|
1274 |
+
return t
|
1275 |
+
|
1276 |
+
t = self._tokenizer.text_to_ids(*args)
|
1277 |
+
return t
|
1278 |
+
|
1279 |
+
super().__init__(
|
1280 |
+
audio_tar_filepaths=audio_tar_filepaths,
|
1281 |
+
manifest_filepath=manifest_filepath,
|
1282 |
+
parser=TokenizerWrapper(tokenizer),
|
1283 |
+
sample_rate=sample_rate,
|
1284 |
+
int_values=int_values,
|
1285 |
+
augmentor=augmentor,
|
1286 |
+
shuffle_n=shuffle_n,
|
1287 |
+
min_duration=min_duration,
|
1288 |
+
max_duration=max_duration,
|
1289 |
+
trim=trim,
|
1290 |
+
bos_id=bos_id,
|
1291 |
+
eos_id=eos_id,
|
1292 |
+
pad_id=pad_id,
|
1293 |
+
shard_strategy=shard_strategy,
|
1294 |
+
shard_manifests=shard_manifests,
|
1295 |
+
global_rank=global_rank,
|
1296 |
+
world_size=world_size,
|
1297 |
+
return_sample_id=return_sample_id,
|
1298 |
+
)
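# --- Illustrative sketch (assumption: a toy stand-in tokenizer, arbitrary ids) ---
# The constructor above only adopts bos/eos/pad ids from the tokenizer when the attribute
# exists and is positive; otherwise it falls back to None (no BOS/EOS) or 0 (padding).
class _ToyTokenizer:
    bos_id = 1
    eos_id = 2
    pad_id = 0  # non-positive, so the dataset would fall back to pad_id = 0

    def text_to_ids(self, text):
        return [3 + (ord(c) % 10) for c in text]  # hypothetical mapping

_tok = _ToyTokenizer()
_bos = _tok.bos_id if hasattr(_tok, "bos_id") and _tok.bos_id > 0 else None  # -> 1
_eos = _tok.eos_id if hasattr(_tok, "eos_id") and _tok.eos_id > 0 else None  # -> 2
_pad = _tok.pad_id if hasattr(_tok, "pad_id") and _tok.pad_id > 0 else 0     # -> 0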
|
1299 |
+
|
1300 |
+
|
1301 |
+
class BucketingDataset(IterableDataset):
|
1302 |
+
"""
|
1303 |
+
A Dataset which wraps another IterableDataset and adapts it for bucketing
|
1304 |
+
Args:
|
1305 |
+
dataset (IterableDataset): The IterableDataset to be wrapped
|
1306 |
+
bucketing_batch_size (int): Number of samples used to build a batch
|
1307 |
+
"""
|
1308 |
+
|
1309 |
+
def __init__(
|
1310 |
+
self, dataset: IterableDataset, bucketing_batch_size: int,
|
1311 |
+
):
|
1312 |
+
self.wrapped_dataset = dataset
|
1313 |
+
self.bucketing_batch_size = bucketing_batch_size
|
1314 |
+
super().__init__()
|
1315 |
+
|
1316 |
+
def _collate_fn(self, batch):
|
1317 |
+
return _speech_collate_fn(batch[0], self.wrapped_dataset.pad_id)
|
1318 |
+
|
1319 |
+
def __iter__(self):
|
1320 |
+
return BucketingIterator(
|
1321 |
+
wrapped_ds=self.wrapped_dataset._dataset, bucketing_batch_size=self.bucketing_batch_size
|
1322 |
+
).__iter__()
|
1323 |
+
|
1324 |
+
def __len__(self):
|
1325 |
+
return int(math.ceil(len(self.wrapped_dataset) / float(self.bucketing_batch_size)))
|
1326 |
+
|
1327 |
+
|
1328 |
+
class BucketingIterator:
|
1329 |
+
def __init__(self, wrapped_ds, bucketing_batch_size):
|
1330 |
+
self.wrapped_ds = wrapped_ds
|
1331 |
+
self.wrapped_iter = None
|
1332 |
+
self.bucketing_batch_size = bucketing_batch_size
|
1333 |
+
|
1334 |
+
def __iter__(self):
|
1335 |
+
self.wrapped_iter = iter(self.wrapped_ds)
|
1336 |
+
return self
|
1337 |
+
|
1338 |
+
def __next__(self):
|
1339 |
+
batches = []
|
1340 |
+
for idx in range(self.bucketing_batch_size):
|
1341 |
+
try:
|
1342 |
+
sample = next(self.wrapped_iter)
|
1343 |
+
except StopIteration:
|
1344 |
+
break
|
1345 |
+
batches.append(sample)
|
1346 |
+
if len(batches) == 0:
|
1347 |
+
raise StopIteration
|
1348 |
+
return batches
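# --- Illustrative sketch: how BucketingIterator groups samples ---
# Demonstrates the grouping behaviour on a plain Python list standing in for the wrapped
# dataset: items come out in buckets of bucketing_batch_size, with a smaller final bucket
# when the length is not divisible.
_example_buckets = list(BucketingIterator(wrapped_ds=list(range(7)), bucketing_batch_size=3))
# _example_buckets == [[0, 1, 2], [3, 4, 5], [6]]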
|
1349 |
+
|
1350 |
+
|
1351 |
+
class RandomizedChainDataset(ChainDataset):
|
1352 |
+
def __init__(self, datasets: Iterable[Dataset], rnd_seed=0) -> None:
|
1353 |
+
super(RandomizedChainDataset, self).__init__(list(datasets))
|
1354 |
+
self.rnd_gen = np.random.RandomState(rnd_seed)
|
1355 |
+
|
1356 |
+
def __iter__(self):
|
1357 |
+
shuffled_order = self.rnd_gen.permutation(len(self.datasets))
|
1358 |
+
for dataset_idx in shuffled_order:
|
1359 |
+
d = self.datasets[dataset_idx]
|
1360 |
+
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
|
1361 |
+
for idx, x in enumerate(d):
|
1362 |
+
yield x
|
1363 |
+
# in case d is an infinite dataset, we want to break the loop
|
1364 |
+
# so that the other datasets get a chance to yield too
|
1365 |
+
if idx >= len(d) - 1:
|
1366 |
+
break
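# --- Illustrative sketch (assumption: two toy child datasets) ---
# RandomizedChainDataset walks its child datasets in an order fixed by rnd_seed; each child
# must be an IterableDataset (the same base class used elsewhere in this module) with a
# defined __len__ so that the early-break guard above can stop infinite children.
class _RangeDataset(IterableDataset):
    def __init__(self, start, stop):
        self.start, self.stop = start, stop

    def __iter__(self):
        return iter(range(self.start, self.stop))

    def __len__(self):
        return self.stop - self.start

# Yields one child completely, then the other, in a seed-determined order:
# list(RandomizedChainDataset([_RangeDataset(0, 3), _RangeDataset(10, 13)], rnd_seed=0))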
|
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dali.py
ADDED
@@ -0,0 +1,772 @@
1 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import operator
|
17 |
+
import os.path
|
18 |
+
import time
|
19 |
+
from collections.abc import Iterator
|
20 |
+
from typing import Callable, List, Optional, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from omegaconf import DictConfig
|
24 |
+
|
25 |
+
from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor, expand_sharded_filepaths
|
26 |
+
from nemo.collections.common.parts.preprocessing import parsers
|
27 |
+
from nemo.utils import logging, model_utils
|
28 |
+
|
29 |
+
try:
|
30 |
+
import nvidia.dali as dali
|
31 |
+
from nvidia.dali.pipeline import Pipeline
|
32 |
+
from nvidia.dali.plugin.pytorch import DALIGenericIterator as DALIPytorchIterator
|
33 |
+
from nvidia.dali.plugin.pytorch import LastBatchPolicy as LastBatchPolicy
|
34 |
+
|
35 |
+
HAVE_DALI = True
|
36 |
+
except (ImportError, ModuleNotFoundError):
|
37 |
+
HAVE_DALI = False
|
38 |
+
|
39 |
+
__all__ = [
|
40 |
+
'AudioToCharDALIDataset',
|
41 |
+
'AudioToBPEDALIDataset',
|
42 |
+
]
|
43 |
+
|
44 |
+
"""
|
45 |
+
Below minimum version is required to access the "read_idxs" argument in
|
46 |
+
dali.fn.readers.nemo_asr
|
47 |
+
"""
|
48 |
+
__DALI_MINIMUM_VERSION__ = "1.11"
|
49 |
+
|
50 |
+
DALI_INSTALLATION_MESSAGE = (
|
51 |
+
"Could not import `nvidia.dali`.\n"
|
52 |
+
"Please install DALI by following the steps provided here - \n"
|
53 |
+
"https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
def is_dali_supported(min_version: str, verbose: bool = False) -> bool:
|
58 |
+
"""
|
59 |
+
Checks if DALI is installed and its version is >= min_version.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
min_version: A semver str that is the minimum requirement.
|
63 |
+
verbose: Whether to log the installation instructions if DALI is not found.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
bool - whether DALI could be imported or not.
|
67 |
+
"""
|
68 |
+
module_available, _ = model_utils.check_lib_version(
|
69 |
+
'nvidia.dali', checked_version=min_version, operator=operator.ge
|
70 |
+
)
|
71 |
+
|
72 |
+
# If DALI is not installed
|
73 |
+
if module_available is None:
|
74 |
+
if verbose:
|
75 |
+
logging.info(DALI_INSTALLATION_MESSAGE)
|
76 |
+
|
77 |
+
return False
|
78 |
+
|
79 |
+
return module_available
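# Example usage (a minimal sketch): callers can guard DALI-dependent code paths on the
# minimum supported version; verbose=True logs installation instructions when DALI is absent.
def _require_dali():
    if not is_dali_supported(__DALI_MINIMUM_VERSION__, verbose=True):
        raise ModuleNotFoundError(f"NVIDIA DALI >= {__DALI_MINIMUM_VERSION__} is required.")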
|
80 |
+
|
81 |
+
|
82 |
+
class DALIOutputs(object):
|
83 |
+
def __init__(self, out_dict):
|
84 |
+
self._has_processed_signal = 'processed_signal' in out_dict and 'processed_signal_len' in out_dict
|
85 |
+
if not self._has_processed_signal:
|
86 |
+
assert 'audio' in out_dict and 'audio_len' in out_dict
|
87 |
+
assert 'transcript' in out_dict and 'transcript_len' in out_dict
|
88 |
+
if self._has_processed_signal:
|
89 |
+
self._outs = (
|
90 |
+
out_dict['processed_signal'],
|
91 |
+
out_dict['processed_signal_len'].reshape(-1),
|
92 |
+
out_dict['transcript'],
|
93 |
+
out_dict['transcript_len'].reshape(-1),
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
self._outs = (
|
97 |
+
out_dict['audio'],
|
98 |
+
out_dict['audio_len'].reshape(-1),
|
99 |
+
out_dict['transcript'],
|
100 |
+
out_dict['transcript_len'].reshape(-1),
|
101 |
+
)
|
102 |
+
|
103 |
+
@property
|
104 |
+
def has_processed_signal(self):
|
105 |
+
return self._has_processed_signal
|
106 |
+
|
107 |
+
def __getitem__(self, key):
|
108 |
+
return self._outs[key]
|
109 |
+
|
110 |
+
def __len__(self):
|
111 |
+
return len(self._outs)
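# --- Illustrative sketch: how DALIOutputs is consumed (placeholder tensors) ---
# The wrapper exposes the pipeline outputs as a fixed 4-tuple, so downstream code can unpack
# it like a regular batch; has_processed_signal tells whether features or raw audio came back.
_example_out = DALIOutputs(
    {
        'audio': torch.zeros(2, 16000),
        'audio_len': torch.tensor([[16000], [12000]]),
        'transcript': torch.zeros(2, 10, dtype=torch.long),
        'transcript_len': torch.tensor([[10], [8]]),
    }
)
_audio, _audio_len, _transcript, _transcript_len = _example_out
assert not _example_out.has_processed_signal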
|
112 |
+
|
113 |
+
|
114 |
+
class _AudioTextDALIDataset(Iterator):
|
115 |
+
"""
|
116 |
+
NVIDIA DALI pipeline that loads tensors via one or more manifest files where each line contains a sample descriptor in JSON,
|
117 |
+
including audio files, transcripts, and durations (in seconds).
|
118 |
+
Here's an example:
|
119 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
|
120 |
+
...
|
121 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
|
122 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
123 |
+
|
124 |
+
Args:
|
125 |
+
manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths.
|
126 |
+
device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'.
|
127 |
+
batch_size (int): Number of samples in a batch.
|
128 |
+
parser (str, callable): A str for an inbuilt parser, or a callable with signature f(str) -> List[int].
|
129 |
+
sample_rate (int): Sample rate to resample loaded audio to.
|
130 |
+
num_threads (int): Number of CPU processing threads to be created by the DALI pipeline.
|
131 |
+
max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files.
|
132 |
+
min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files.
|
133 |
+
bos_id (int): Id of beginning of sequence symbol to append if not None
|
134 |
+
eos_id (int): Id of end of sequence symbol to append if not None
|
135 |
+
pad_id (int): Id used to pad the input. Defaults to 0 if not provided.
|
136 |
+
trim (bool): If True, it will extract the nonsilent region of the loaded audio signal.
|
137 |
+
shuffle (bool): If set to True, the dataset will be shuffled after loading.
|
138 |
+
drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size.
|
139 |
+
If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
|
140 |
+
device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
|
141 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
142 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
|
143 |
+
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
|
144 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
|
145 |
+
"""
|
146 |
+
|
147 |
+
def __init__(
|
148 |
+
self,
|
149 |
+
manifest_filepath: str,
|
150 |
+
device: str,
|
151 |
+
batch_size: int,
|
152 |
+
parser: Union[str, Callable],
|
153 |
+
audio_tar_filepaths: Optional[Union[str, List[str]]] = None,
|
154 |
+
audio_tar_index_filepaths: Optional[Union[str, List[str]]] = None,
|
155 |
+
sample_rate: int = 16000,
|
156 |
+
num_threads: int = 4,
|
157 |
+
max_duration: float = 0.0,
|
158 |
+
min_duration: float = 0.0,
|
159 |
+
bos_id: Optional[int] = None,
|
160 |
+
eos_id: Optional[int] = None,
|
161 |
+
pad_id: int = 0,
|
162 |
+
trim: bool = False,
|
163 |
+
shuffle: bool = False,
|
164 |
+
drop_last: bool = False,
|
165 |
+
shard_strategy: str = "scatter",
|
166 |
+
device_id: int = 0,
|
167 |
+
global_rank: int = 0,
|
168 |
+
world_size: int = 1,
|
169 |
+
preprocessor_cfg: DictConfig = None,
|
170 |
+
return_sample_id: bool = False,
|
171 |
+
):
|
172 |
+
self.drop_last = drop_last # used by lr_scheduler
|
173 |
+
if return_sample_id:
|
174 |
+
raise ValueError(
|
175 |
+
"Currently DALI data layers don't support returning the sample_id and return_sample_id can not be enabled."
|
176 |
+
)
|
177 |
+
self.return_sample_id = return_sample_id
|
178 |
+
|
179 |
+
if not HAVE_DALI:
|
180 |
+
raise ModuleNotFoundError(
|
181 |
+
f"{self} requires NVIDIA DALI to be installed. "
|
182 |
+
f"See: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html#id1"
|
183 |
+
)
|
184 |
+
|
185 |
+
if device not in ('cpu', 'gpu'):
|
186 |
+
raise ValueError(
|
187 |
+
f"{self} received an unexpected device argument {device}. Supported values are: 'cpu', 'gpu'"
|
188 |
+
)
|
189 |
+
|
190 |
+
device_id = device_id if device == 'gpu' else None
|
191 |
+
|
192 |
+
self.batch_size = batch_size # Used by NeMo
|
193 |
+
|
194 |
+
self.device = device
|
195 |
+
self.device_id = device_id
|
196 |
+
|
197 |
+
if world_size > 1:
|
198 |
+
self.shard_id = global_rank
|
199 |
+
self.num_shards = world_size
|
200 |
+
else:
|
201 |
+
self.shard_id = None
|
202 |
+
self.num_shards = None
|
203 |
+
|
204 |
+
self.eos_id = eos_id
|
205 |
+
self.bos_id = bos_id
|
206 |
+
self.sample_rate = sample_rate
|
207 |
+
|
208 |
+
self.pipe = Pipeline(
|
209 |
+
batch_size=batch_size,
|
210 |
+
num_threads=num_threads,
|
211 |
+
device_id=self.device_id,
|
212 |
+
exec_async=True,
|
213 |
+
exec_pipelined=True,
|
214 |
+
)
|
215 |
+
|
216 |
+
has_preprocessor = preprocessor_cfg is not None
|
217 |
+
if has_preprocessor:
|
218 |
+
if preprocessor_cfg._target_ == "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor":
|
219 |
+
feature_type = "mel_spectrogram"
|
220 |
+
elif preprocessor_cfg._target_ == "nemo.collections.asr.modules.AudioToMFCCPreprocessor":
|
221 |
+
feature_type = "mfcc"
|
222 |
+
else:
|
223 |
+
raise ValueError(
|
224 |
+
f"{self} received an unexpected preprocessor configuration: {preprocessor_cfg._target_}."
|
225 |
+
f" Supported preprocessors are: AudioToMelSpectrogramPreprocessor, AudioToMFCCPreprocessor"
|
226 |
+
)
|
227 |
+
|
228 |
+
# Default values taken from AudioToMelSpectrogramPreprocessor
|
229 |
+
params = preprocessor_cfg
|
230 |
+
self.dither = params['dither'] if 'dither' in params else 0.0
|
231 |
+
self.preemph = params['preemph'] if 'preemph' in params else 0.97
|
232 |
+
self.window_size_sec = params['window_size'] if 'window_size' in params else 0.02
|
233 |
+
self.window_stride_sec = params['window_stride'] if 'window_stride' in params else 0.01
|
234 |
+
self.sample_rate = params['sample_rate'] if 'sample_rate' in params else sample_rate
|
235 |
+
self.window_size = int(self.window_size_sec * self.sample_rate)
|
236 |
+
self.window_stride = int(self.window_stride_sec * self.sample_rate)
|
237 |
+
|
238 |
+
normalize = params['normalize'] if 'normalize' in params else 'per_feature'
|
239 |
+
if normalize == 'per_feature': # Each freq channel independently
|
240 |
+
self.normalization_axes = (1,)
|
241 |
+
elif normalize == 'all_features':
|
242 |
+
self.normalization_axes = (0, 1)
|
243 |
+
else:
|
244 |
+
raise ValueError(
|
245 |
+
f"{self} received {normalize} for the normalize parameter."
|
246 |
+
f" It must be either 'per_feature' or 'all_features'."
|
247 |
+
)
|
248 |
+
|
249 |
+
self.window = None
|
250 |
+
window_name = params['window'] if 'window' in params else 'hann'
|
251 |
+
torch_windows = {
|
252 |
+
'hann': torch.hann_window,
|
253 |
+
'hamming': torch.hamming_window,
|
254 |
+
'blackman': torch.blackman_window,
|
255 |
+
'bartlett': torch.bartlett_window,
|
256 |
+
'none': None,
|
257 |
+
}
|
258 |
+
|
259 |
+
if window_name == 'ones':
|
260 |
+
window_tensor = torch.ones(self.window_size)
|
261 |
+
else:
|
262 |
+
try:
|
263 |
+
window_fn = torch_windows.get(window_name, None)
|
264 |
+
except:
|
265 |
+
raise ValueError(
|
266 |
+
f"{self} received '{window_name}' for the window parameter."
|
267 |
+
f" It must be one of: ('hann', 'ones', 'hamming', 'blackman', 'bartlett', None)."
|
268 |
+
f" None is equivalent to 'hann'."
|
269 |
+
)
|
270 |
+
window_tensor = window_fn(self.window_size, periodic=False) if window_fn else None
|
271 |
+
self.window = window_tensor.numpy().tolist() if window_tensor is not None else None
|
272 |
+
|
273 |
+
self.n_fft = params['n_fft'] if 'n_fft' in params else 2 ** math.ceil(math.log2(self.window_size))
|
274 |
+
self.n_mels = params['n_mels'] if 'n_mels' in params else 64
|
275 |
+
self.n_mfcc = params['n_mfcc'] if 'n_mfcc' in params else 64
|
276 |
+
|
277 |
+
features = params['features'] if 'features' in params else 0
|
278 |
+
if features > 0:
|
279 |
+
if feature_type == 'mel_spectrogram':
|
280 |
+
self.n_mels = features
|
281 |
+
elif feature_type == 'mfcc':
|
282 |
+
self.n_mfcc = features
|
283 |
+
|
284 |
+
# TODO Implement frame splicing
|
285 |
+
if 'frame_splicing' in params:
|
286 |
+
assert params['frame_splicing'] == 1, "Frame splicing is not implemented"
|
287 |
+
|
288 |
+
self.freq_low = params['lowfreq'] if 'lowfreq' in params else 0.0
|
289 |
+
self.freq_high = params['highfreq'] if 'highfreq' in params else self.sample_rate / 2.0
|
290 |
+
self.log_features = params['log'] if 'log' in params else True
|
291 |
+
|
292 |
+
# We want to avoid taking the log of zero
|
293 |
+
# There are two options: either adding or clamping to a small value
|
294 |
+
|
295 |
+
self.log_zero_guard_type = params['log_zero_guard_type'] if 'log_zero_guard_type' in params else 'add'
|
296 |
+
if self.log_zero_guard_type not in ["add", "clamp"]:
|
297 |
+
raise ValueError(
|
298 |
+
f"{self} received {self.log_zero_guard_type} for the "
|
299 |
+
f"log_zero_guard_type parameter. It must be either 'add' or "
|
300 |
+
f"'clamp'."
|
301 |
+
)
|
302 |
+
|
303 |
+
self.log_zero_guard_value = (
|
304 |
+
params['log_zero_guard_value'] if 'log_zero_guard_value' in params else 2 ** -24
|
305 |
+
)
|
306 |
+
if isinstance(self.log_zero_guard_value, str):
|
307 |
+
if self.log_zero_guard_value == "tiny":
|
308 |
+
self.log_zero_guard_value = torch.finfo(torch.float32).tiny
|
309 |
+
elif self.log_zero_guard_value == "eps":
|
310 |
+
self.log_zero_guard_value = torch.finfo(torch.float32).eps
|
311 |
+
else:
|
312 |
+
raise ValueError(
|
313 |
+
f"{self} received {self.log_zero_guard_value} for the log_zero_guard_type parameter."
|
314 |
+
f"It must be either a number, 'tiny', or 'eps'"
|
315 |
+
)
|
316 |
+
|
317 |
+
self.mag_power = params['mag_power'] if 'mag_power' in params else 2
|
318 |
+
if self.mag_power != 1.0 and self.mag_power != 2.0:
|
319 |
+
raise ValueError(
|
320 |
+
f"{self} received {self.mag_power} for the mag_power parameter." f" It must be either 1.0 or 2.0."
|
321 |
+
)
|
322 |
+
|
323 |
+
self.pad_to = max(params['pad_to'], 1) if 'pad_to' in params else 16
|
324 |
+
self.pad_value = params['pad_value'] if 'pad_value' in params else 0.0
|
325 |
+
|
326 |
+
with self.pipe:
|
327 |
+
if audio_tar_filepaths is None and audio_tar_index_filepaths is None:
|
328 |
+
audio, indices = dali.fn.readers.nemo_asr(
|
329 |
+
name="Reader",
|
330 |
+
manifest_filepaths=manifest_filepath.split(','),
|
331 |
+
dtype=dali.types.FLOAT,
|
332 |
+
downmix=True,
|
333 |
+
sample_rate=float(self.sample_rate),
|
334 |
+
min_duration=min_duration,
|
335 |
+
max_duration=max_duration,
|
336 |
+
read_sample_rate=False,
|
337 |
+
read_text=False,
|
338 |
+
read_idxs=True,
|
339 |
+
random_shuffle=shuffle,
|
340 |
+
shard_id=self.shard_id,
|
341 |
+
num_shards=self.num_shards,
|
342 |
+
pad_last_batch=True,
|
343 |
+
)
|
344 |
+
|
345 |
+
self.is_tarred_dataset = False
|
346 |
+
|
347 |
+
elif audio_tar_filepaths is not None and audio_tar_index_filepaths is not None:
|
348 |
+
audio_tar_filepaths = expand_sharded_filepaths(
|
349 |
+
audio_tar_filepaths, shard_strategy=shard_strategy, world_size=world_size, global_rank=global_rank
|
350 |
+
)
|
351 |
+
audio_tar_index_filepaths = expand_sharded_filepaths(
|
352 |
+
audio_tar_index_filepaths,
|
353 |
+
shard_strategy=shard_strategy,
|
354 |
+
world_size=world_size,
|
355 |
+
global_rank=global_rank,
|
356 |
+
)
|
357 |
+
|
358 |
+
if len(audio_tar_filepaths) != len(audio_tar_index_filepaths) and len(audio_tar_index_filepaths) != 0:
|
359 |
+
raise ValueError(
|
360 |
+
f"Number of filepaths provided for `audio_tar_filepaths` must match "
|
361 |
+
f"`audio_tar_index_filepaths`. Got {len(audio_tar_filepaths)} audio_tar_filepaths and "
|
362 |
+
f"{len(audio_tar_index_filepaths)} audio_tar_index_filepaths."
|
363 |
+
)
|
364 |
+
|
365 |
+
tar_file = dali.fn.readers.webdataset(
|
366 |
+
paths=audio_tar_filepaths,
|
367 |
+
index_paths=audio_tar_index_filepaths,
|
368 |
+
name="Reader",
|
369 |
+
ext=["wav"],
|
370 |
+
missing_component_behavior="error",
|
371 |
+
random_shuffle=shuffle,
|
372 |
+
shard_id=self.shard_id,
|
373 |
+
num_shards=self.num_shards,
|
374 |
+
pad_last_batch=True,
|
375 |
+
)
|
376 |
+
audio, _ = dali.fn.decoders.audio(
|
377 |
+
tar_file, dtype=dali.types.FLOAT, downmix=True, sample_rate=float(self.sample_rate),
|
378 |
+
)
|
379 |
+
indices = dali.fn.get_property(tar_file, key="source_info")
|
380 |
+
indices = dali.fn.pad(indices)
|
381 |
+
|
382 |
+
self.is_tarred_dataset = True
|
383 |
+
|
384 |
+
else:
|
385 |
+
raise RuntimeError(
|
386 |
+
"When using DALI datasets, either `audio_tar_filepaths` "
|
387 |
+
"and `audio_tar_index_filepaths` should either both be None (sequential dataset)"
|
388 |
+
"or provided (tarred dataset)."
|
389 |
+
)
|
390 |
+
|
391 |
+
# Extract nonsilent region, if necessary
|
392 |
+
if trim:
|
393 |
+
# Need to extract non-silent region before moving to the GPU
|
394 |
+
roi_start, roi_len = dali.fn.nonsilent_region(audio, cutoff_db=-60)
|
395 |
+
audio = audio.gpu() if self.device == 'gpu' else audio
|
396 |
+
audio = dali.fn.slice(
|
397 |
+
audio, roi_start, roi_len, normalized_anchor=False, normalized_shape=False, axes=[0]
|
398 |
+
)
|
399 |
+
else:
|
400 |
+
audio = audio.gpu() if self.device == 'gpu' else audio
|
401 |
+
|
402 |
+
if not has_preprocessor:
|
403 |
+
# No preprocessing, the output is the audio signal
|
404 |
+
audio_len = dali.fn.shapes(dali.fn.reshape(audio, shape=[-1]))
|
405 |
+
audio = dali.fn.pad(audio)
|
406 |
+
self.pipe.set_outputs(audio, audio_len, indices)
|
407 |
+
else:
|
408 |
+
# Additive gaussian noise (dither)
|
409 |
+
if self.dither > 0.0:
|
410 |
+
gaussian_noise = dali.fn.random.normal(audio)
|
411 |
+
audio = audio + self.dither * gaussian_noise
|
412 |
+
|
413 |
+
# Preemphasis filter
|
414 |
+
if self.preemph > 0.0:
|
415 |
+
audio = dali.fn.preemphasis_filter(audio, preemph_coeff=self.preemph, border='zero')
|
416 |
+
|
417 |
+
# Power spectrogram
|
418 |
+
spec = dali.fn.spectrogram(
|
419 |
+
audio,
|
420 |
+
nfft=self.n_fft,
|
421 |
+
window_length=self.window_size,
|
422 |
+
window_step=self.window_stride,
|
423 |
+
window_fn=self.window,
|
424 |
+
)
|
425 |
+
|
426 |
+
if feature_type == 'mel_spectrogram' or feature_type == 'mfcc':
|
427 |
+
# Spectrogram to Mel Spectrogram
|
428 |
+
spec = dali.fn.mel_filter_bank(
|
429 |
+
spec,
|
430 |
+
sample_rate=self.sample_rate,
|
431 |
+
nfilter=self.n_mels,
|
432 |
+
normalize=True,
|
433 |
+
freq_low=self.freq_low,
|
434 |
+
freq_high=self.freq_high,
|
435 |
+
)
|
436 |
+
# Mel Spectrogram to MFCC
|
437 |
+
if feature_type == 'mfcc':
|
438 |
+
spec = dali.fn.mfcc(spec, n_mfcc=self.n_mfcc)
|
439 |
+
|
440 |
+
# Logarithm
|
441 |
+
if self.log_zero_guard_type == 'add':
|
442 |
+
spec = spec + self.log_zero_guard_value
|
443 |
+
|
444 |
+
spec = dali.fn.to_decibels(
|
445 |
+
spec, multiplier=math.log(10), reference=1.0, cutoff_db=math.log(self.log_zero_guard_value)
|
446 |
+
)
|
447 |
+
|
448 |
+
# Normalization
|
449 |
+
spec = dali.fn.normalize(spec, axes=self.normalization_axes, epsilon=1e-5 ** 2, ddof=1)
|
450 |
+
|
451 |
+
# Extracting the length of the spectrogram
|
452 |
+
spec_len = dali.fn.slice(dali.fn.shapes(spec), 1, 1, axes=(0,))
|
453 |
+
|
454 |
+
# Pads feature dimension to be a multiple of `pad_to` and the temporal dimension to be as big as the largest sample (shape -1)
|
455 |
+
spec = dali.fn.pad(spec, fill_value=self.pad_value, axes=(0, 1), align=(self.pad_to, 1), shape=(1, -1))
|
456 |
+
self.pipe.set_outputs(spec, spec_len, indices)
|
457 |
+
|
458 |
+
x = time.time()
|
459 |
+
# Building DALI pipeline
|
460 |
+
self.pipe.build()
|
461 |
+
y = time.time()
|
462 |
+
|
463 |
+
logging.info(f"Time for pipe.build() : {(y - x)} seconds")
|
464 |
+
|
465 |
+
if has_preprocessor:
|
466 |
+
output_names = ['processed_signal', 'processed_signal_len', 'manifest_indices']
|
467 |
+
else:
|
468 |
+
output_names = ['audio', 'audio_len', 'manifest_indices']
|
469 |
+
|
470 |
+
x = time.time()
|
471 |
+
last_batch_policy = LastBatchPolicy.DROP if drop_last else LastBatchPolicy.PARTIAL
|
472 |
+
self._iter = DALIPytorchIterator(
|
473 |
+
[self.pipe],
|
474 |
+
output_map=output_names,
|
475 |
+
reader_name="Reader",
|
476 |
+
last_batch_policy=last_batch_policy,
|
477 |
+
dynamic_shape=True,
|
478 |
+
auto_reset=True,
|
479 |
+
)
|
480 |
+
y = time.time()
|
481 |
+
logging.info(f"Time for DALIPytorchIterator to initialize : {(y - x)} seconds")
|
482 |
+
|
483 |
+
# TODO come up with a better solution
|
484 |
+
class DummyDataset:
|
485 |
+
def __init__(self, parent):
|
486 |
+
self.parent = parent
|
487 |
+
|
488 |
+
def __len__(self):
|
489 |
+
return self.parent.size
|
490 |
+
|
491 |
+
self.dataset = DummyDataset(self) # Used by NeMo
|
492 |
+
|
493 |
+
x = time.time()
|
494 |
+
self.manifest_processor = ASRManifestProcessor(
|
495 |
+
manifest_filepath=manifest_filepath,
|
496 |
+
parser=parser,
|
497 |
+
max_duration=max_duration,
|
498 |
+
min_duration=min_duration,
|
499 |
+
max_utts=0,
|
500 |
+
bos_id=bos_id,
|
501 |
+
eos_id=eos_id,
|
502 |
+
pad_id=pad_id,
|
503 |
+
index_by_file_id=self.is_tarred_dataset,
|
504 |
+
)
|
505 |
+
y = time.time()
|
506 |
+
logging.info(f"Time to build nemo manifest processor - {(y - x)} seconds")
|
507 |
+
|
508 |
+
def reset(self):
|
509 |
+
self._iter.reset()
|
510 |
+
|
511 |
+
def __iter__(self):
|
512 |
+
return self
|
513 |
+
|
514 |
+
def next(self):
|
515 |
+
return self.__next__()
|
516 |
+
|
517 |
+
@property
|
518 |
+
def size(self):
|
519 |
+
return self._iter.size
|
520 |
+
|
521 |
+
def __len__(self):
|
522 |
+
return len(self._iter)
|
523 |
+
|
524 |
+
def __next__(self):
|
525 |
+
outputs = self._iter.next()
|
526 |
+
assert len(outputs) == 1
|
527 |
+
dali_out = outputs[0]
|
528 |
+
manifest_indices = dali_out['manifest_indices'].numpy()
|
529 |
+
|
530 |
+
out = {}
|
531 |
+
out_names = ['processed_signal', 'processed_signal_len', 'audio', 'audio_len']
|
532 |
+
for out_name in out_names:
|
533 |
+
if out_name in dali_out:
|
534 |
+
out[out_name] = dali_out[out_name].detach().clone()
|
535 |
+
|
536 |
+
text_tokens = []
|
537 |
+
text_tokens_len = []
|
538 |
+
max_len = 0
|
539 |
+
batch_size = manifest_indices.shape[0]
|
540 |
+
for i, manifest_index in enumerate(manifest_indices):
|
541 |
+
|
542 |
+
if not self.is_tarred_dataset:
|
543 |
+
# Loose-file dataset. Index is integer based.
|
544 |
+
manifest_index = manifest_index[0]
|
545 |
+
text, text_length = self.manifest_processor.process_text_by_id(manifest_index)
|
546 |
+
else:
|
547 |
+
# Tarred-file dataset. Index is filename based.
|
548 |
+
resolved_manifest_indices = manifest_index.tobytes().decode().split(":")
|
549 |
+
resolved_manifest_index = resolved_manifest_indices[2] # we require just the filename segment
|
550 |
+
resolved_manifest_index = os.path.splitext(resolved_manifest_index)[0] # we don't need the file extension
|
551 |
+
text, text_length = self.manifest_processor.process_text_by_file_id(resolved_manifest_index)
|
552 |
+
|
553 |
+
text_tokens_len.append(text_length)
|
554 |
+
text_tokens.append(text)
|
555 |
+
if text_length > max_len:
|
556 |
+
max_len = text_length
|
557 |
+
|
558 |
+
transcript_out = torch.full([batch_size, max_len], fill_value=self.manifest_processor.pad_id, dtype=torch.long)
|
559 |
+
for i, n in enumerate(text_tokens_len):
|
560 |
+
transcript_out[i, :n] = torch.tensor(text_tokens[i], dtype=torch.long)
|
561 |
+
transcript_len_out = torch.tensor(text_tokens_len, dtype=torch.long)
|
562 |
+
|
563 |
+
out['transcript'] = transcript_out
|
564 |
+
out['transcript_len'] = transcript_len_out
|
565 |
+
return DALIOutputs(out)
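# --- Illustrative sketch: the transcript padding performed in __next__ above ---
# Variable-length token lists are written into a pad_id-filled [batch, max_len] LongTensor,
# mirroring the loop above; the token ids here are arbitrary placeholders.
def _pad_token_batch(token_lists, pad_id=0):
    lengths = torch.tensor([len(t) for t in token_lists], dtype=torch.long)
    padded = torch.full((len(token_lists), int(lengths.max())), fill_value=pad_id, dtype=torch.long)
    for i, tokens in enumerate(token_lists):
        padded[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
    return padded, lengths

# _pad_token_batch([[5, 6, 7], [8]]) -> (tensor([[5, 6, 7], [8, 0, 0]]), tensor([3, 1]))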
|
566 |
+
|
567 |
+
|
568 |
+
class AudioToCharDALIDataset(_AudioTextDALIDataset):
|
569 |
+
"""
|
570 |
+
Character-based NVIDIA DALI pipeline that loads tensors via one or more manifest files where each line contains a
|
571 |
+
sample descriptor in JSON, including audio files, transcripts, and durations (in seconds).
|
572 |
+
Here's an example:
|
573 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
|
574 |
+
...
|
575 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
|
576 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
577 |
+
|
578 |
+
Args:
|
579 |
+
manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths.
|
580 |
+
device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'.
|
581 |
+
batch_size (int): Number of samples in a batch.
|
582 |
+
labels (Union[str, List[str]]): String or list containing all the possible characters to map to.
|
583 |
+
sample_rate (int): Sample rate to resample loaded audio to.
|
584 |
+
num_threads (int): Number of CPU processing threads to be created by the DALI pipeline.
|
585 |
+
max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files.
|
586 |
+
min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files.
|
587 |
+
blank_index (int): blank character index, default = -1
|
588 |
+
unk_index (int): unknown character index, default = -1
|
589 |
+
normalize (bool): whether to normalize transcript text. Defaults to True.
|
590 |
+
bos_id (int): Id of beginning of sequence symbol to append if not None
|
591 |
+
eos_id (int): Id of end of sequence symbol to append if not None
|
592 |
+
pad_id (int): Id used to pad the input. Defaults to 0 if not provided.
|
593 |
+
trim (bool): If True, it will extract the nonsilent region of the loaded audio signal.
|
594 |
+
shuffle (bool): If set to True, the dataset will be shuffled after loading.
|
595 |
+
drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size.
|
596 |
+
If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
|
597 |
+
parser (str, callable): A str for an inbuilt parser, or a callable with signature f(str) -> List[int].
|
598 |
+
device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
|
599 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
600 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
|
601 |
+
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
|
602 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
|
603 |
+
"""
|
604 |
+
|
605 |
+
def __init__(
|
606 |
+
self,
|
607 |
+
manifest_filepath: str,
|
608 |
+
device: str,
|
609 |
+
batch_size: int,
|
610 |
+
labels: Union[str, List[str]],
|
611 |
+
sample_rate: int = 16000,
|
612 |
+
audio_tar_filepaths: Optional[Union[str, List[str]]] = None,
|
613 |
+
audio_tar_index_filepaths: Optional[Union[str, List[str]]] = None,
|
614 |
+
num_threads: int = 4,
|
615 |
+
max_duration: float = 0.0,
|
616 |
+
min_duration: float = 0.0,
|
617 |
+
blank_index: int = -1,
|
618 |
+
unk_index: int = -1,
|
619 |
+
normalize: bool = True,
|
620 |
+
bos_id: Optional[int] = None,
|
621 |
+
eos_id: Optional[int] = None,
|
622 |
+
pad_id: int = 0,
|
623 |
+
trim: bool = False,
|
624 |
+
shuffle: bool = False,
|
625 |
+
drop_last: bool = False,
|
626 |
+
parser: Union[str, Callable] = 'en',
|
627 |
+
shard_strategy: str = "scatter",
|
628 |
+
device_id: int = 0,
|
629 |
+
global_rank: int = 0,
|
630 |
+
world_size: int = 1,
|
631 |
+
preprocessor_cfg: DictConfig = None,
|
632 |
+
return_sample_id: bool = False,
|
633 |
+
):
|
634 |
+
self.labels = labels
|
635 |
+
|
636 |
+
parser = parsers.make_parser(
|
637 |
+
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
|
638 |
+
)
|
639 |
+
|
640 |
+
super().__init__(
|
641 |
+
manifest_filepath=manifest_filepath,
|
642 |
+
device=device,
|
643 |
+
batch_size=batch_size,
|
644 |
+
audio_tar_filepaths=audio_tar_filepaths,
|
645 |
+
audio_tar_index_filepaths=audio_tar_index_filepaths,
|
646 |
+
sample_rate=sample_rate,
|
647 |
+
num_threads=num_threads,
|
648 |
+
max_duration=max_duration,
|
649 |
+
min_duration=min_duration,
|
650 |
+
bos_id=bos_id,
|
651 |
+
eos_id=eos_id,
|
652 |
+
pad_id=pad_id,
|
653 |
+
trim=trim,
|
654 |
+
shuffle=shuffle,
|
655 |
+
drop_last=drop_last,
|
656 |
+
parser=parser,
|
657 |
+
shard_strategy=shard_strategy,
|
658 |
+
device_id=device_id,
|
659 |
+
global_rank=global_rank,
|
660 |
+
world_size=world_size,
|
661 |
+
preprocessor_cfg=preprocessor_cfg,
|
662 |
+
return_sample_id=return_sample_id,
|
663 |
+
)
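# --- Illustrative usage (assumption: hypothetical manifest path, placeholder labels) ---
# A minimal sketch of constructing the character-level DALI data layer with a mel-spectrogram
# preprocessor config; only keys actually read in _AudioTextDALIDataset.__init__ are set,
# everything else falls back to the defaults handled above.
from omegaconf import OmegaConf

_example_preproc_cfg = OmegaConf.create(
    {
        "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
        "sample_rate": 16000,
        "window_size": 0.02,
        "window_stride": 0.01,
        "features": 80,             # number of mel bands
        "normalize": "per_feature",
        "window": "hann",
    }
)

_example_layer = AudioToCharDALIDataset(
    manifest_filepath="train_manifest.json",
    device="gpu" if torch.cuda.is_available() else "cpu",
    batch_size=8,
    labels=[" ", "a", "b", "c"],    # placeholder character set
    sample_rate=16000,
    shuffle=True,
    preprocessor_cfg=_example_preproc_cfg,
)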
|
664 |
+
|
665 |
+
|
666 |
+
class AudioToBPEDALIDataset(_AudioTextDALIDataset):
|
667 |
+
"""
|
668 |
+
Subword-based NVIDIA DALI pipeline that loads tensors via one or more manifest files where each line contains a
|
669 |
+
sample descriptor in JSON, including audio files, transcripts, and durations (in seconds).
|
670 |
+
Here's an example:
|
671 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
|
672 |
+
...
|
673 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
|
674 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
675 |
+
|
676 |
+
Args:
|
677 |
+
manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths.
|
678 |
+
tokenizer (TokenizerSpec): A TokenizerSpec implementation that wraps a tokenization implementation.
|
679 |
+
device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'.
|
680 |
+
batch_size (int): Number of samples in a batch.
|
681 |
+
sample_rate (int): Sample rate to resample loaded audio to.
|
682 |
+
num_threads (int): Number of CPU processing threads to be created by the DALI pipeline.
|
683 |
+
max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files.
|
684 |
+
min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files.
|
685 |
+
bos_id (int): Id of beginning of sequence symbol to append if not None. Injected from the tokenizer.
|
686 |
+
eos_id (int): Id of end of sequence symbol to append if not None. Injected from the tokenizer.
|
687 |
+
pad_id (int): Id used to pad the input. Defaults to 0 if not provided. Injected from the tokenizer.
|
688 |
+
trim (bool): If True, it will extract the nonsilent region of the loaded audio signal.
|
689 |
+
shuffle (bool): If set to True, the dataset will shuffled after loading.
|
690 |
+
drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size.
|
691 |
+
If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
|
692 |
+
|
693 |
+
device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
|
694 |
+
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
695 |
+
world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
|
696 |
+
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
|
697 |
+
use_start_end_token (bool): Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and
|
698 |
+
ending of speech respectively.
|
699 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
|
700 |
+
"""
|
701 |
+
|
702 |
+
def __init__(
|
703 |
+
self,
|
704 |
+
manifest_filepath: str,
|
705 |
+
tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
|
706 |
+
device: str,
|
707 |
+
batch_size: int,
|
708 |
+
sample_rate: int = 16000,
|
709 |
+
audio_tar_filepaths: Optional[Union[str, List[str]]] = None,
|
710 |
+
audio_tar_index_filepaths: Optional[Union[str, List[str]]] = None,
|
711 |
+
num_threads: int = 4,
|
712 |
+
max_duration: float = 0.0,
|
713 |
+
min_duration: float = 0.0,
|
714 |
+
trim: bool = False,
|
715 |
+
shuffle: bool = False,
|
716 |
+
drop_last: bool = False,
|
717 |
+
shard_strategy: str = "scatter",
|
718 |
+
device_id: int = 0,
|
719 |
+
global_rank: int = 0,
|
720 |
+
world_size: int = 1,
|
721 |
+
preprocessor_cfg: DictConfig = None,
|
722 |
+
use_start_end_token: bool = True,
|
723 |
+
return_sample_id: bool = False,
|
724 |
+
):
|
725 |
+
|
726 |
+
if use_start_end_token and hasattr(tokenizer, 'bos_token'):
|
727 |
+
bos_id = tokenizer.bos_id
|
728 |
+
else:
|
729 |
+
bos_id = None
|
730 |
+
|
731 |
+
if use_start_end_token and hasattr(tokenizer, 'eos_token'):
|
732 |
+
eos_id = tokenizer.eos_id
|
733 |
+
else:
|
734 |
+
eos_id = None
|
735 |
+
|
736 |
+
if hasattr(tokenizer, 'pad_token'):
|
737 |
+
pad_id = tokenizer.pad_id
|
738 |
+
else:
|
739 |
+
pad_id = 0
|
740 |
+
|
741 |
+
class TokenizerWrapper:
|
742 |
+
def __init__(self, tokenizer):
|
743 |
+
self._tokenizer = tokenizer
|
744 |
+
|
745 |
+
def __call__(self, text):
|
746 |
+
t = self._tokenizer.text_to_ids(text)
|
747 |
+
return t
|
748 |
+
|
749 |
+
super().__init__(
|
750 |
+
manifest_filepath=manifest_filepath,
|
751 |
+
device=device,
|
752 |
+
batch_size=batch_size,
|
753 |
+
sample_rate=sample_rate,
|
754 |
+
audio_tar_filepaths=audio_tar_filepaths,
|
755 |
+
audio_tar_index_filepaths=audio_tar_index_filepaths,
|
756 |
+
num_threads=num_threads,
|
757 |
+
max_duration=max_duration,
|
758 |
+
min_duration=min_duration,
|
759 |
+
bos_id=bos_id,
|
760 |
+
eos_id=eos_id,
|
761 |
+
pad_id=pad_id,
|
762 |
+
trim=trim,
|
763 |
+
shuffle=shuffle,
|
764 |
+
drop_last=drop_last,
|
765 |
+
parser=TokenizerWrapper(tokenizer),
|
766 |
+
shard_strategy=shard_strategy,
|
767 |
+
device_id=device_id,
|
768 |
+
global_rank=global_rank,
|
769 |
+
world_size=world_size,
|
770 |
+
preprocessor_cfg=preprocessor_cfg,
|
771 |
+
return_sample_id=return_sample_id,
|
772 |
+
)
|
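A minimal usage sketch for the DALI pipelines defined above; it assumes NVIDIA DALI is installed, and the manifest path and SentencePiece model below are hypothetical placeholders rather than files in this commit:

# Hypothetical sketch -- manifest and tokenizer model paths are placeholders.
from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer(model_path="tokenizer.model")  # assumed pre-trained tokenizer
train_ds = AudioToBPEDALIDataset(
    manifest_filepath="train_manifest.json",  # JSON-lines manifest in the format described in the docstring
    tokenizer=tokenizer,
    device='gpu',        # or 'cpu' when no GPU is available
    batch_size=16,
    sample_rate=16000,
    shuffle=True,
)
# NeMo's ASR models use an object like this in place of a regular DataLoader: the DALI
# pipeline itself handles audio decoding, resampling and preprocessing and yields batches.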
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dataset.py
ADDED
@@ -0,0 +1,950 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import random
from math import isclose
from typing import Any, List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf.listconfig import ListConfig
from pytorch_lightning.callbacks import BasePredictionWriter
from torch.utils.data import ChainDataset

from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.dataset import CodeSwitchedDataset, ConcatDataset
from nemo.utils import logging


def inject_dataloader_value_from_model_config(model_cfg: dict, dataloader_cfg: DictConfig, key: str):
    """
    Extracts the label set provided at the top level of the model, and propagates it to the dataloader
    config.

    Args:
        model_cfg: A DictConfig representing the model's config.
        dataloader_cfg: A DictConfig representing the individual data loader.
        key: A str value representing a key in the model_cfg whose value will be propagated to the
            dataloader config.
    """
    if key not in model_cfg:
        logging.info(
            f"Model level config does not contain `{key}`, please explicitly provide `{key}` to the dataloaders."
        )
        return

    if not isinstance(dataloader_cfg, DictConfig):
        dataloader_cfg = DictConfig(dataloader_cfg)

    # If key exists in the data loader config (either set explicitly or as a placeholder (via None))
    if key in dataloader_cfg:
        # Dataloader `labels` is provided and is non-null
        if dataloader_cfg[key] is not None and model_cfg[key] != dataloader_cfg[key]:
            # Model level `labels` don't match Dataloader level `labels`
            logging.warning(
                f'`{key}` is explicitly provided to the data loader, and is different from '
                f'the `{key}` provided at the model level config.\n'
                f'If this is incorrect, please set the dataloader\'s `{key}` to None.'
            )

        else:
            # Dataloader `key` is None or values match
            # Propagate from model level `key` (even if they match)
            with open_dict(dataloader_cfg):
                dataloader_cfg[key] = model_cfg[key]

    else:
        # If the key doesn't even exist in dataloader_cfg, inject it explicitly
        with open_dict(dataloader_cfg):
            dataloader_cfg[key] = model_cfg[key]


def get_concat_char_dataset(
    config: dict, global_rank: int, world_size: int, augmentor: Optional['AudioAugmentor'] = None
) -> ConcatDataset:
    """
    Instantiates an instance of ConcatDataset containing one or more instances of
    Character Encoding based AudioToCharDataset.

    Args:
        config: Config of the AudioToCharDataset.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of ConcatDataset containing one or more instances of AudioToCharDataset.
    """
    if 'labels' not in config:
        logging.warning(f"dataset does not have explicitly defined labels")

    manifest_filepaths = config['manifest_filepath']
    datasets = []

    # needed to support validation Concat Datasets that arrive here as
    # [[dataset1,dataset2]] otherwise ModelPT would interfere
    if len(manifest_filepaths) == 1 and not isinstance(manifest_filepaths[0], str):
        logging.info(f"removing an extra nesting level from {manifest_filepaths}")
        manifest_filepaths = config['manifest_filepath'][0]

    for manifest_filepath in manifest_filepaths:
        conf = copy.deepcopy(config)
        conf['manifest_filepath'] = manifest_filepath

        dataset = get_char_dataset(config=conf, augmentor=augmentor)
        datasets.append(dataset)

    dataset = ConcatDataset(
        datasets,
        sampling_technique=config.get('concat_sampling_technique', 'temperature'),
        sampling_temperature=config.get('concat_sampling_temperature', 5),
        sampling_scale=config.get('concat_sampling_scale', 1),
        sampling_probabilities=config.get('concat_sampling_probabilities', None),
        shuffle=config.get('concat_shuffle', True),
        seed=config.get('concat_sampling_seed', None),
        global_rank=global_rank,
        world_size=world_size,
    )
    return dataset


def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None) -> audio_to_text.AudioToCharDataset:
    """
    Instantiates a Character Encoding based AudioToCharDataset.

    Args:
        config: Config of the AudioToCharDataset.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of AudioToCharDataset.
    """
    if 'labels' not in config:
        logging.warning(f"dataset does not have explicitly defined labels")

    dataset = audio_to_text.AudioToCharDataset(
        manifest_filepath=config['manifest_filepath'],
        labels=config.get('labels', None),
        sample_rate=config['sample_rate'],
        int_values=config.get('int_values', False),
        augmentor=augmentor,
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        blank_index=config.get('blank_index', -1),
        unk_index=config.get('unk_index', -1),
        normalize=config.get('normalize_transcripts', False),
        trim=config.get('trim_silence', False),
        parser=config.get('parser', 'en'),
        return_sample_id=config.get('return_sample_id', False),
        channel_selector=config.get('channel_selector', None),
    )
    return dataset


def get_concat_bpe_dataset(
    config: dict,
    tokenizer: 'TokenizerSpec',
    global_rank: int,
    world_size: int,
    augmentor: Optional['AudioAugmentor'] = None,
) -> ConcatDataset:
    """
    Instantiates a ConcatDataset based on several Byte Pair Encoding / Word Piece Encoding based AudioToBPEDatasets.

    Args:
        config: Config of the AudioToBPEDataset.
        tokenizer: An instance of a TokenizerSpec object.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of ConcatDataset containing several instances of AudioToBPEDataset.
    """
    manifest_filepaths = config['manifest_filepath']
    datasets = []

    # needed to support validation Concat Datasets that arrive here as
    # [[dataset1,dataset2]] otherwise ModelPT would interfere
    if len(manifest_filepaths) == 1 and not isinstance(manifest_filepaths[0], str):
        logging.info(f"removing an extra nesting level from {manifest_filepaths}")
        manifest_filepaths = config['manifest_filepath'][0]

    for manifest_filepath in manifest_filepaths:
        conf = copy.deepcopy(config)
        conf['manifest_filepath'] = manifest_filepath
        dataset = get_bpe_dataset(config=conf, tokenizer=tokenizer, augmentor=augmentor)
        datasets.append(dataset)

    dataset = ConcatDataset(
        datasets,
        sampling_technique=config.get('concat_sampling_technique', 'temperature'),
        sampling_temperature=config.get('concat_sampling_temperature', 5),
        sampling_scale=config.get('concat_sampling_scale', 1),
        sampling_probabilities=config.get('concat_sampling_probabilities', None),
        shuffle=config.get('concat_shuffle', True),
        seed=config.get('concat_sampling_seed', None),
        global_rank=global_rank,
        world_size=world_size,
    )
    return dataset


def get_bpe_dataset(
    config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['AudioAugmentor'] = None
) -> audio_to_text.AudioToBPEDataset:
    """
    Instantiates a Byte Pair Encoding / Word Piece Encoding based AudioToBPEDataset.

    Args:
        config: Config of the AudioToBPEDataset.
        tokenizer: An instance of a TokenizerSpec object.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of AudioToBPEDataset.
    """
    dataset = audio_to_text.AudioToBPEDataset(
        manifest_filepath=config['manifest_filepath'],
        tokenizer=tokenizer,
        sample_rate=config['sample_rate'],
        int_values=config.get('int_values', False),
        augmentor=augmentor,
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        trim=config.get('trim_silence', False),
        use_start_end_token=config.get('use_start_end_token', True),
        return_sample_id=config.get('return_sample_id', False),
        channel_selector=config.get('channel_selector', None),
    )
    return dataset


def get_concat_tarred_dataset(
    config: dict,
    shuffle_n: int,
    global_rank: int,
    world_size: int,
    tokenizer: Optional['TokenizerSpec'] = None,
    augmentor: Optional['AudioAugmentor'] = None,
) -> ConcatDataset:
    """
    Instantiates a ConcatDataset containing multiple Word Piece/BPE Encoding based TarredAudioToBPEDataset or a char based TarredAudioToCharDataset.

    Args:
        config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset.
        shuffle_n: How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
        tokenizer: An instance of a TokenizerSpec object if a BPE dataset is needed.
            Passing None would return a char-based dataset.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of ConcatDataset containing one or more TarredAudioToBPEDatasets or TarredAudioToCharDatasets.
    """

    tarred_audio_filepaths = config['tarred_audio_filepaths']
    manifest_filepaths = config['manifest_filepath']
    datasets = []
    for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
        zip(tarred_audio_filepaths, manifest_filepaths)
    ):
        conf = copy.deepcopy(config)
        conf['manifest_filepath'] = manifest_filepath
        conf['tarred_audio_filepaths'] = tarred_audio_filepath
        dataset = get_tarred_dataset(
            config=conf,
            tokenizer=tokenizer,
            shuffle_n=shuffle_n,
            global_rank=global_rank,
            world_size=world_size,
            augmentor=augmentor,
        )
        datasets.append(dataset)

    dataset = ConcatDataset(
        datasets,
        sampling_technique=config.get('concat_sampling_technique', 'temperature'),
        sampling_temperature=config.get('concat_sampling_temperature', 5),
        sampling_scale=config.get('concat_sampling_scale', 1),
        sampling_probabilities=config.get('concat_sampling_probabilities', None),
        shuffle=config.get('concat_shuffle', True),
        seed=config.get('concat_sampling_seed', None),
        global_rank=global_rank,
        world_size=world_size,
    )
    return dataset


def get_tarred_dataset(
    config: dict,
    shuffle_n: int,
    global_rank: int,
    world_size: int,
    tokenizer: Optional['TokenizerSpec'] = None,
    augmentor: Optional['AudioAugmentor'] = None,
) -> Union[audio_to_text.TarredAudioToBPEDataset, audio_to_text.TarredAudioToCharDataset]:
    """
    Instantiates a Word Piece/BPE Encoding based TarredAudioToBPEDataset or a char based TarredAudioToCharDataset.

    Args:
        config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset.
        shuffle_n: How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
        tokenizer: An instance of a TokenizerSpec object if a BPE dataset is needed.
            Passing None would return a char-based dataset.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of TarredAudioToBPEDataset or TarredAudioToCharDataset.
    """
    tarred_audio_filepaths = config['tarred_audio_filepaths']
    manifest_filepaths = config['manifest_filepath']
    datasets = []
    tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
    manifest_filepaths = convert_to_config_list(manifest_filepaths)

    bucketing_weights = config.get('bucketing_weights', None)  # For upsampling buckets
    if bucketing_weights:
        for idx, weight in enumerate(bucketing_weights):
            if not isinstance(weight, int) or weight <= 0:
                raise ValueError(f"bucket weights must be positive integers")

    if len(manifest_filepaths) != len(tarred_audio_filepaths):
        raise ValueError(
            f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
        )

    if 'labels' not in config:
        logging.warning(f"dataset does not have explicitly defined labels")

    if 'max_utts' in config:
        raise ValueError('"max_utts" parameter is not supported for tarred datasets')

    for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
        zip(tarred_audio_filepaths, manifest_filepaths)
    ):
        if len(tarred_audio_filepath) == 1:
            tarred_audio_filepath = tarred_audio_filepath[0]
        if len(manifest_filepath) == 1:
            manifest_filepath = manifest_filepath[0]

        if tokenizer is None:
            dataset = audio_to_text.TarredAudioToCharDataset(
                audio_tar_filepaths=tarred_audio_filepath,
                manifest_filepath=manifest_filepath,
                labels=config.get('labels', None),
                sample_rate=config['sample_rate'],
                int_values=config.get('int_values', False),
                augmentor=augmentor,
                shuffle_n=shuffle_n,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                blank_index=config.get('blank_index', -1),
                unk_index=config.get('unk_index', -1),
                normalize=config.get('normalize_transcripts', False),
                trim=config.get('trim_silence', False),
                parser=config.get('parser', 'en'),
                shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
                shard_manifests=config.get('shard_manifests', False),
                global_rank=global_rank,
                world_size=world_size,
                return_sample_id=config.get('return_sample_id', False),
            )
        else:
            dataset = audio_to_text.TarredAudioToBPEDataset(
                audio_tar_filepaths=tarred_audio_filepath,
                manifest_filepath=manifest_filepath,
                tokenizer=tokenizer,
                sample_rate=config['sample_rate'],
                int_values=config.get('int_values', False),
                augmentor=augmentor,
                shuffle_n=shuffle_n,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                trim=config.get('trim_silence', False),
                use_start_end_token=config.get('use_start_end_token', True),
                shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
                shard_manifests=config.get('shard_manifests', False),
                global_rank=global_rank,
                world_size=world_size,
                return_sample_id=config.get('return_sample_id', False),
            )
        if bucketing_weights:
            [datasets.append(dataset) for _ in range(bucketing_weights[dataset_idx])]
        else:
            datasets.append(dataset)

    return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)


def get_code_switched_dataset(
    config: dict,
    shuffle_n: int,
    global_rank: int,
    world_size: int,
    tokenizer: Optional['TokenizerSpec'] = None,
    augmentor: Optional['AudioAugmentor'] = None,
) -> CodeSwitchedDataset:

    if 'manifest_filepath' not in config:
        raise ValueError("`manifest_filepath` must be provided in the dataset config if `is_code_switched=True`")
    if 'code_switched' not in config:
        raise ValueError("`code_switched` param group must be in the dataset config if `is_code_switched=True`")

    manifest_filepaths = config['manifest_filepath']
    tarred_audio_filepaths = config.get('tarred_audio_filepaths', None)

    cs_config = OmegaConf.to_container(config['code_switched'])

    # needed to support validation Datasets that arrive here as
    # [[dataset1,dataset2]] otherwise ModelPT would interfere
    if len(manifest_filepaths) == 1 and not isinstance(manifest_filepaths[0], str):
        manifest_filepaths = config['manifest_filepath'][0]
    if tarred_audio_filepaths is None:
        tarred_audio_filepaths = [None] * len(manifest_filepaths)

    if len(manifest_filepaths) != len(tarred_audio_filepaths):
        raise ValueError(
            f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of items."
        )

    datasets = []
    for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
        zip(tarred_audio_filepaths, manifest_filepaths)
    ):
        conf = copy.deepcopy(config)
        conf['manifest_filepath'] = manifest_filepath
        with open_dict(conf):
            conf['tarred_audio_filepaths'] = tarred_audio_filepath
        if tarred_audio_filepath is None or len(tarred_audio_filepath) == 0:
            if tokenizer is None:
                dataset = get_char_dataset(config=conf, augmentor=None)
            else:
                dataset = get_bpe_dataset(config=conf, tokenizer=tokenizer, augmentor=None)
        else:
            dataset = get_tarred_dataset(
                config=conf,
                tokenizer=tokenizer,
                shuffle_n=shuffle_n,
                global_rank=global_rank,
                world_size=world_size,
                augmentor=None,
            )
        datasets.append(dataset)

    config = OmegaConf.to_container(config)

    dataset = CodeSwitchedDataset(
        datasets,
        shuffle=cs_config.get('shuffle', True),
        min_duration=cs_config.get('min_duration', 4),
        max_duration=cs_config.get('max_duration', 20),
        min_monolingual=cs_config.get('min_monolingual', 0.3),
        lang_probs=cs_config.get('probs', None),
        db_norm=cs_config.get('db_norm', -25.0),
        pause_start=cs_config.get('pause_start', 0),
        pause_join=cs_config.get('pause_join', 0),
        pause_end=cs_config.get('pause_end', 0),
        sampling_scales=cs_config.get('sampling_scales', None),
        seed=cs_config.get('seed', None),
        global_rank=global_rank,
        world_size=world_size,
        pure_random=cs_config.get('pure_random', False),
        force_monochannel=cs_config.get('force_monochannel', True),
        infinity_mode=cs_config.get('infinity_mode', False),
        sample_rate=config['sample_rate'],
        augmentor=augmentor,
    )

    return dataset


def get_dali_char_dataset(
    config: dict,
    shuffle: bool,
    device_id: int,
    global_rank: int,
    world_size: int,
    preprocessor_cfg: Optional[DictConfig] = None,
) -> audio_to_text_dali.AudioToCharDALIDataset:
    """
    Instantiates a Character Encoding based AudioToCharDALIDataset.

    Args:
        config: Config of the AudioToCharDALIDataset.
        shuffle: Bool flag whether to shuffle the dataset.
        device_id: Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        preprocessor_cfg: Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.

    Returns:
        An instance of AudioToCharDALIDataset.
    """
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    dataset = audio_to_text_dali.AudioToCharDALIDataset(
        manifest_filepath=config['manifest_filepath'],
        device=device,
        batch_size=config['batch_size'],
        labels=config['labels'],
        sample_rate=config['sample_rate'],
        audio_tar_filepaths=config.get('tarred_audio_filepaths', None),
        audio_tar_index_filepaths=config.get('tarred_audio_index_filepaths', None),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        blank_index=config.get('blank_index', -1),
        unk_index=config.get('unk_index', -1),
        normalize=config.get('normalize_transcripts', False),
        trim=config.get('trim_silence', False),
        parser=config.get('parser', 'en'),
        shuffle=shuffle,
        shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
        device_id=device_id,
        global_rank=global_rank,
        world_size=world_size,
        preprocessor_cfg=preprocessor_cfg,
        return_sample_id=config.get('return_sample_id', False),
    )
    return dataset


def get_dali_bpe_dataset(
    config: dict,
    tokenizer,
    shuffle: bool,
    device_id: int,
    global_rank: int,
    world_size: int,
    preprocessor_cfg: Optional[DictConfig] = None,
) -> audio_to_text_dali.AudioToBPEDALIDataset:
    """
    Instantiates a Subword Encoding based AudioToBPEDALIDataset.

    Args:
        config: Config of the AudioToBPEDALIDataset.
        tokenizer: An implementation of NeMo TokenizerSpec.
        shuffle: Bool flag whether to shuffle the dataset.
        device_id: Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        preprocessor_cfg: Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.

    Returns:
        An instance of AudioToBPEDALIDataset.
    """
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    dataset = audio_to_text_dali.AudioToBPEDALIDataset(
        manifest_filepath=config['manifest_filepath'],
        tokenizer=tokenizer,
        device=device,
        batch_size=config['batch_size'],
        sample_rate=config['sample_rate'],
        audio_tar_filepaths=config.get('tarred_audio_filepaths', None),
        audio_tar_index_filepaths=config.get('tarred_audio_index_filepaths', None),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        trim=config.get('trim_silence', False),
        use_start_end_token=config.get('use_start_end_token', True),
        shuffle=shuffle,
        shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
        device_id=device_id,
        global_rank=global_rank,
        world_size=world_size,
        preprocessor_cfg=preprocessor_cfg,
        return_sample_id=config.get('return_sample_id', False),
    )
    return dataset


def get_audio_to_text_char_dataset_from_config(
    config, local_rank: int, global_rank: int, world_size: int, preprocessor_cfg: Optional[DictConfig] = None
):
    """
    Construct Audio-To-Text Char dataset from a config.
    Args:
        config: dataset config
        local_rank: model local rank
        global_rank: model global rank
        world_size: world size
        preprocessor_cfg: preprocessor config, for DALI dataset

    Returns:
        constructed dataset or None if dataset config is invalid or nothing to load
    """
    if 'augmentor' in config:
        augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size)
    else:
        augmentor = None

    is_concat = config.get('is_concat', False)
    if is_concat:
        if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
            logging.warning(
                f"Concat dataset requires `concat_sampling_technique` but it was not provided. Config: {config}"
            )
            return None
        if config['concat_sampling_technique'] == 'random':
            if not 'concat_sampling_probabilities' in config:
                logging.warning(f"Concat dataset requires `concat_sampling_probabilities` list. Config: {config}")
                return None
            else:
                if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
                    logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
                    return None

    shuffle = config['shuffle']
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    if config.get('use_dali', False):
        device_id = local_rank if device == 'gpu' else None
        dataset = get_dali_char_dataset(
            config=config,
            shuffle=shuffle,
            device_id=device_id,
            global_rank=global_rank,
            world_size=world_size,
            preprocessor_cfg=preprocessor_cfg,
        )
        return dataset

    # Instantiate a code-switched dataset if config is present
    if config.get('is_code_switched', False):
        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None
        if not ('code_switched' in config and config['code_switched'] is not None):
            logging.warning(
                f"Code switched dataset requires `*_ds.code_switched.*` dict but it was not provided. Config: {config}"
            )
            return None
        if (
            ('probs' in config['code_switched'])
            and (config['code_switched']['probs'] is not None)
            and (not isclose(sum(config['code_switched']['probs']), 1, abs_tol=1e-6))
        ):
            logging.warning(f"`.code_switched.probs` need to sum to 1. Config: {config['code_switched']}")
            return None

        shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
        dataset = get_code_switched_dataset(
            config=config,
            shuffle_n=shuffle_n,
            global_rank=global_rank,
            world_size=world_size,
            tokenizer=None,
            augmentor=augmentor,
        )
    # Instantiate tarred dataset loader or normal dataset loader
    elif config.get('is_tarred', False):
        if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
            'manifest_filepath' in config and config['manifest_filepath'] is None
        ):
            logging.warning(
                "Could not load dataset as `manifest_filepath` was None or "
                f"`tarred_audio_filepaths` is None. Provided config : {config}"
            )
            return None

        shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
        if is_concat:
            dataset = get_concat_tarred_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=global_rank,
                world_size=world_size,
                augmentor=augmentor,
            )
        else:
            dataset = get_tarred_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=global_rank,
                world_size=world_size,
                augmentor=augmentor,
            )
    else:
        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None
        if is_concat:
            dataset = get_concat_char_dataset(
                config=config, global_rank=global_rank, world_size=world_size, augmentor=augmentor
            )
        else:
            dataset = get_char_dataset(config=config, augmentor=augmentor)
    return dataset


def get_audio_to_text_bpe_dataset_from_config(
    config,
    local_rank: int,
    global_rank: int,
    world_size: int,
    tokenizer,
    preprocessor_cfg: Optional[DictConfig] = None,
):
    """
    Construct Audio-To-Text BPE dataset from a config.
    Args:
        config: BPE dataset config
        local_rank: model local rank
        global_rank: model global rank
        world_size: world size
        tokenizer: BPE tokenizer
        preprocessor_cfg: preprocessor config, for DALI BPE dataset

    Returns:
        constructed dataset or None if dataset config is invalid or nothing to load
    """
    if 'augmentor' in config:
        augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size)
    else:
        augmentor = None

    is_concat = config.get('is_concat', False)
    if is_concat:
        if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
            logging.warning(
                f"Concat dataset requires `concat_sampling_technique` but it was not provided. Config: {config}"
            )
            return None

        if config['concat_sampling_technique'] == 'random':
            if not 'concat_sampling_probabilities' in config:
                logging.warning(f"Concat dataset requires `concat_sampling_probabilities` list. Config: {config}")
                return None
            else:
                if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
                    logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
                    return None

    shuffle = config['shuffle']
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    if config.get('use_dali', False):
        device_id = local_rank if device == 'gpu' else None
        dataset = get_dali_bpe_dataset(
            config=config,
            tokenizer=tokenizer,
            shuffle=shuffle,
            device_id=device_id,
            global_rank=global_rank,
            world_size=world_size,
            preprocessor_cfg=preprocessor_cfg,
        )
        return dataset

    # Instantiate a code-switched dataset if config is present
    if config.get('is_code_switched', False):
        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None
        if not ('code_switched' in config and config['code_switched'] is not None):
            logging.warning(
                f"Code switched dataset requires `*_ds.code_switched.*` dict but it was not provided. Config: {config}"
            )
            return None
        if (
            ('probs' in config['code_switched'])
            and (config['code_switched']['probs'] is not None)
            and (not isclose(sum(config['code_switched']['probs']), 1, abs_tol=1e-6))
        ):
            logging.warning(f"`.code_switched.probs` need to sum to 1. Config: {config['code_switched']}")
            return None

        shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
        dataset = get_code_switched_dataset(
            config=config,
            shuffle_n=shuffle_n,
            global_rank=global_rank,
            world_size=world_size,
            tokenizer=tokenizer,
            augmentor=augmentor,
        )
    # Instantiate tarred dataset loader or normal dataset loader
    elif config.get('is_tarred', False):
        if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
            'manifest_filepath' in config and config['manifest_filepath'] is None
        ):
            logging.warning(
                "Could not load dataset as `manifest_filepath` was None or "
                f"`tarred_audio_filepaths` is None. Provided config : {config}"
            )
            return None

        shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
        if is_concat:
            dataset = get_concat_tarred_dataset(
                config=config,
                tokenizer=tokenizer,
                shuffle_n=shuffle_n,
                global_rank=global_rank,
                world_size=world_size,
                augmentor=augmentor,
            )
        else:
            dataset = get_tarred_dataset(
                config=config,
                tokenizer=tokenizer,
                shuffle_n=shuffle_n,
                global_rank=global_rank,
                world_size=world_size,
                augmentor=augmentor,
            )
    else:
        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None
        if is_concat:
            dataset = get_concat_bpe_dataset(
                config=config,
                global_rank=global_rank,
                world_size=world_size,
                tokenizer=tokenizer,
                augmentor=augmentor,
            )
        else:
            dataset = get_bpe_dataset(config=config, tokenizer=tokenizer, augmentor=augmentor)
    return dataset


class ASRPredictionWriter(BasePredictionWriter):
    def __init__(self, dataset, output_file: str):
        super().__init__(write_interval="batch")
        self.outf = open(output_file, 'w', encoding='utf-8')
        self.dataset = dataset
        self.samples_num = 0

    def write_on_batch_end(
        self,
        trainer,
        pl_module: 'LightningModule',
        prediction: Any,
        batch_indices: List[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ):
        for sample_id, transcribed_text in prediction:
            item = {}
            sample = self.dataset.get_manifest_sample(sample_id)
            item["audio_filepath"] = sample.audio_file
            item["offset"] = sample.offset
            item["duration"] = sample.duration
            item["text"] = sample.text_raw
            item["pred_text"] = transcribed_text
            self.outf.write(json.dumps(item) + "\n")
            self.samples_num += 1
        return

    def close_output_file(self):
        self.outf.close()
        return self.samples_num


def convert_to_config_list(initial_list):
    if type(initial_list) is str:
        initial_list = initial_list.split(",")
    if initial_list is None or initial_list == []:
        raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.")
    if not isinstance(initial_list, ListConfig):
        initial_list = ListConfig([initial_list])

    for list_idx, list_val in enumerate(initial_list):
        if type(list_val) != type(initial_list[0]):
            raise ValueError(
                "manifest_filepaths and tarred_audio_filepaths need to be a list of lists for bucketing or just a list of strings"
            )
    if type(initial_list[0]) is not ListConfig:
        initial_list = ListConfig([initial_list])
    return initial_list


def get_chain_dataset(datasets, ds_config, rank=0):
    if len(datasets) > 1:
        if ds_config.get('bucketing_batch_size', None) is not None:
            bucketing_batch_sizes = calc_bucketing_batch_sizes(ds_config, len(datasets))
            logging.info(
                f"Batch bucketing is enabled for {len(datasets)} buckets with adaptive batch sizes of {bucketing_batch_sizes}!"
            )
            for idx, dataset in enumerate(datasets):
                datasets[idx] = audio_to_text.BucketingDataset(
                    dataset=dataset, bucketing_batch_size=bucketing_batch_sizes[idx]
                )
        else:
            logging.info(
                f"Batch bucketing is enabled for {len(datasets)} buckets with fixed batch size of {ds_config['batch_size']}!"
            )

    if len(datasets) == 1:
        return datasets[0]
    bucketing_strategy = ds_config.get('bucketing_strategy', 'synced_randomized')
    if bucketing_strategy == 'fixed_order':
        return ChainDataset(datasets)
    elif bucketing_strategy == 'synced_randomized':
        return audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=0)
    elif bucketing_strategy == 'fully_randomized':
        return audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=random.randint(0, 30000) + rank)
    else:
        raise ValueError(
            f'bucketing_strategy={bucketing_strategy} is not supported! Supported strategies are [fixed_order, fully_randomized, synced_randomized].'
        )


def calc_bucketing_batch_sizes(ds_config, datasets_len):
    bucketing_batch_size = ds_config['bucketing_batch_size']
    bucketing_weights = ds_config.get('bucketing_weights', None)  # To adjust for upsampled buckets

    bucketing_batch_sizes = []

    if ds_config['batch_size'] != 1:
        raise ValueError(
            f"batch_size should be set to one when bucketing_batch_size is set and adaptive bucketing is enabled (batch_size={ds_config['batch_size']})!"
        )
    if type(bucketing_batch_size) == int:  # linear scaling
        if bucketing_weights:  # Want the same batch size for the same duplicated bucket
            for idx, weight in enumerate(bucketing_weights):
                scale_factor = datasets_len - idx
                [bucketing_batch_sizes.append(scale_factor * bucketing_batch_size) for _ in range(weight)]
        else:
            for idx in range(datasets_len):
                scale_factor = datasets_len - idx
                bucketing_batch_sizes.append(scale_factor * bucketing_batch_size)
    elif isinstance(bucketing_batch_size, ListConfig) or isinstance(
        bucketing_batch_size, list
    ):  # assigned bucket sizes
        if bucketing_weights:  # Want the same batch size for the same duplicated bucket
            for idx, weight in enumerate(bucketing_weights):
                [bucketing_batch_sizes.append(bucketing_batch_size[idx]) for _ in range(weight)]
        else:
            bucketing_batch_sizes = bucketing_batch_size
    else:
        raise ValueError(
            f"bucketing_batch_size should be an integer or a list (bucketing_batch_size={bucketing_batch_size})!"
        )

    if len(bucketing_batch_sizes) != datasets_len:
        raise ValueError(
            f"batch_size should have the same length as the number of buckets ({len(bucketing_batch_sizes)}!={datasets_len}) "
        )
    return bucketing_batch_sizes
SoundScribe/SpeakerID/nemo/collections/asr/data/data_simulation.py
ADDED
The diff for this file is too large to render.
See raw diff
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label.py
ADDED
@@ -0,0 +1,497 @@
1 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Dict, List, Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader
|
19 |
+
from nemo.collections.common.parts.preprocessing import collections
|
20 |
+
from nemo.core.classes import Dataset
|
21 |
+
from nemo.core.neural_types import AcousticEncodedRepresentation, LabelsType, LengthsType, NeuralType
|
22 |
+
from nemo.utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
def _feature_collate_fn(batch):
|
26 |
+
"""collate batch of feat sig, feat len, labels, labels len, assuming all features have the same shape.
|
27 |
+
Args:
|
28 |
+
batch (FloatTensor, LongTensor, LongTensor, LongTensor): A tuple of tuples of feature, feature lengths,
|
29 |
+
encoded labels, and encoded labels length.
|
30 |
+
"""
|
31 |
+
packed_batch = list(zip(*batch))
|
32 |
+
if len(packed_batch) == 5:
|
33 |
+
_, feat_lengths, _, labels_lengths, sample_ids = packed_batch
|
34 |
+
elif len(packed_batch) == 4:
|
35 |
+
sample_ids = None
|
36 |
+
_, feat_lengths, _, labels_lengths = packed_batch
|
37 |
+
else:
|
38 |
+
raise ValueError("Expects 4 or 5 tensors in the batch!")
|
39 |
+
|
40 |
+
features, labels = [], []
|
41 |
+
for b in batch:
|
42 |
+
feat_i, labels_i = b[0], b[2]
|
43 |
+
features.append(feat_i)
|
44 |
+
labels.append(labels_i)
|
45 |
+
|
46 |
+
features = torch.stack(features)
|
47 |
+
feat_lengths = torch.stack(feat_lengths)
|
48 |
+
|
49 |
+
labels = torch.stack(labels)
|
50 |
+
labels_lengths = torch.stack(labels_lengths)
|
51 |
+
|
52 |
+
if sample_ids is None:
|
53 |
+
return features, feat_lengths, labels, labels_lengths
|
54 |
+
else:
|
55 |
+
sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
|
56 |
+
return features, feat_lengths, labels, labels_lengths, sample_ids
|
57 |
+
|
58 |
+
|
59 |
+
def _audio_feature_collate_fn(batch, feat_pad_val, label_pad_id):
|
60 |
+
"""collate batch of audio feature, audio len, labels, labels len
|
61 |
+
Args:
|
62 |
+
batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
|
63 |
+
LongTensor): A tuple of tuples of feature, feature lengths,
|
64 |
+
labels, and label lengths. This collate func assumes the
|
65 |
+
features are torch tensors of Log-Melspectrogram (i.e. [N_MEL, T]).
|
66 |
+
"""
|
67 |
+
packed_batch = list(zip(*batch))
|
68 |
+
if len(packed_batch) == 5:
|
69 |
+
_, feat_lengths, _, labels_lengths, sample_ids = packed_batch
|
70 |
+
elif len(packed_batch) == 4:
|
71 |
+
sample_ids = None
|
72 |
+
_, feat_lengths, _, labels_lengths = packed_batch
|
73 |
+
else:
|
74 |
+
raise ValueError("Expects 4 or 5 tensors in the batch!")
|
75 |
+
max_feat_len = 0
|
76 |
+
has_feat = feat_lengths[0] is not None
|
77 |
+
if has_feat:
|
78 |
+
max_feat_len = max(feat_lengths).item()
|
79 |
+
max_labels_len = max(labels_lengths).item()
|
80 |
+
|
81 |
+
features, labels = [], []
|
82 |
+
for b in batch:
|
83 |
+
feat_i, feat_i_len, label_i, label_i_len = b[0], b[1], b[2], b[3]
|
84 |
+
|
85 |
+
if has_feat:
|
86 |
+
feat_i_len = feat_i_len.item()
|
87 |
+
if feat_i_len < max_feat_len:
|
88 |
+
pad = (0, max_feat_len - feat_i_len)
|
89 |
+
feat_i = torch.nn.functional.pad(feat_i, pad, value=feat_pad_val)
|
90 |
+
features.append(feat_i)
|
91 |
+
|
92 |
+
label_i_len = label_i_len.item()
|
93 |
+
if label_i_len < max_labels_len:
|
94 |
+
pad = (0, max_labels_len - label_i_len)
|
95 |
+
label_i = torch.nn.functional.pad(label_i, pad, value=label_pad_id)
|
96 |
+
labels.append(label_i)
|
97 |
+
|
98 |
+
if has_feat:
|
99 |
+
features = torch.stack(features)
|
100 |
+
feature_lengths = torch.stack(feat_lengths)
|
101 |
+
else:
|
102 |
+
features, feat_lengths = None, None
|
103 |
+
labels = torch.stack(labels)
|
104 |
+
labels_lengths = torch.stack(labels_lengths)
|
105 |
+
|
106 |
+
if sample_ids is None:
|
107 |
+
return features, feature_lengths, labels, labels_lengths
|
108 |
+
else:
|
109 |
+
sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
|
110 |
+
return features, feature_lengths, labels, labels_lengths, sample_ids
|
111 |
+
|
112 |
+
|
113 |
+
def _vad_feature_segment_collate_fn(batch, window_length_in_sec, shift_length_in_sec, frame_unit_in_sec):
    """Collate a batch of (audio features, feature lengths, tokens, token lengths) tuples for VAD.

    Args:
        batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
            LongTensor): A tuple of tuples of signal, signal lengths,
            encoded tokens, and encoded token lengths. This collate func
            assumes the signals are 1d torch tensors (i.e. mono audio) and
            that the batch size equals 1.
    """
    slice_length = int(window_length_in_sec / frame_unit_in_sec)
    audio_features, feat_lengths, _, tokens_lengths = zip(*batch)

    slice_length = int(min(slice_length, max(feat_lengths)))
    shift = int(shift_length_in_sec / frame_unit_in_sec)
    has_audio = feat_lengths[0] is not None

    f_dim = audio_features[0].shape[0]
    audio_features, num_slices, tokens, feat_lengths = [], [], [], []
    append_len_start = torch.div(slice_length, 2, rounding_mode='trunc')
    append_len_end = slice_length - torch.div(slice_length, 2, rounding_mode='trunc')
    for feat_i, feat_i_len, tokens_i, _ in batch:
        start = torch.zeros(f_dim, append_len_start)
        end = torch.zeros(f_dim, append_len_end)
        feat_i = torch.cat((start, feat_i, end), dim=1)
        feat_i_len += slice_length

        if has_audio:
            slices = max(1, torch.div(feat_i_len - slice_length, shift, rounding_mode='trunc'))

            for slice_id in range(slices):
                start_idx = slice_id * shift
                end_idx = start_idx + slice_length
                feat_slice = feat_i[:, start_idx:end_idx]
                audio_features.append(feat_slice)

            num_slices.append(slices)
            tokens.extend([tokens_i] * slices)
            feat_lengths.extend([slice_length] * slices)

    if has_audio:
        audio_features = torch.stack(audio_features)
        feat_lengths = torch.tensor(feat_lengths)
    else:
        audio_features, feat_lengths = None, None

    tokens = torch.stack(tokens)
    tokens_lengths = torch.tensor(num_slices)
    return audio_features, feat_lengths, tokens, tokens_lengths

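# Illustrative usage sketch (not part of the original NeMo file): the VAD segment
# collate pads a single [D, T] feature matrix by one window length and cuts it into
# fixed-size windows. With window_length_in_sec=0.63, shift_length_in_sec=0.01 and
# frame_unit_in_sec=0.01 (i.e. ~63-frame windows with a 1-frame shift), a [80, 200]
# feature yields roughly 200 windows of about 63 frames each, and the token is
# repeated once per window.
def _example_vad_segment_collate():
    feat = torch.randn(80, 200)
    batch = [(feat, torch.tensor(200), torch.tensor(1), torch.tensor(1))]
    feats, feat_lens, tokens, num_slices = _vad_feature_segment_collate_fn(
        batch, window_length_in_sec=0.63, shift_length_in_sec=0.01, frame_unit_in_sec=0.01
    )
    # feats: [num_windows, 80, window_frames]
    return feats, feat_lens, tokens, num_slices
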
class _FeatureSeqSpeakerLabelDataset(Dataset):
    """
    Dataset that loads tensors via a json file containing paths to feature files and sequences of labels.
    Each new line is a different sample. JSON files should be of the following format:
        {"feature_filepath": "/path/to/feature_0.p", "seq_label": speakerA speakerB SpeakerA ....} \
        ...
        {"feature_filepath": "/path/to/feature_n.p", "seq_label": target_seq_label_n}
    target_seq_label_n is the string of sequence of speaker labels, separated by spaces.

    Args:
        manifest_filepath (str): Dataset parameter. Path to JSON containing data.
        labels (Optional[list]): Dataset parameter. List of unique labels collected from all samples.
        feature_loader: Dataset parameter. Feature loader to load (external) features.
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
        """
        # TODO output type for external features
        output_types = {
            'external_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            'feat_length': NeuralType(tuple('B'), LengthsType()),
        }

        if self.is_speaker_emb:
            output_types.update(
                {
                    'embs': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                    'embs_length': NeuralType(tuple('B'), LengthsType()),
                    'label': NeuralType(('B', 'T'), LabelsType()),
                    'label_length': NeuralType(tuple('B'), LengthsType()),
                }
            )
        else:
            output_types.update(
                {'label': NeuralType(('B', 'T'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),}
            )

        return output_types

    def __init__(
        self, *, manifest_filepath: str, labels: List[str], feature_loader, is_speaker_emb: bool = False,
    ):
        super().__init__()
        self.collection = collections.ASRFeatureSequenceLabel(manifests_files=manifest_filepath.split(','),)

        self.feature_loader = feature_loader
        self.labels = labels if labels else self.collection.uniq_labels
        self.is_speaker_emb = is_speaker_emb

        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

    def __len__(self):
        return len(self.collection)

    def __getitem__(self, index):
        sample = self.collection[index]

        features = self.feature_loader.process(sample.feature_file)
        f, fl = features, torch.tensor(features.shape[0]).long()

        t = torch.tensor(sample.seq_label).float()
        tl = torch.tensor(len(sample.seq_label)).long()

        return f, fl, t, tl

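# Illustrative note (not part of the original NeMo file): for a manifest entry such as
#   {"feature_filepath": "/path/to/feature_0.p", "seq_label": "speakerA speakerB speakerA"}
# an item from `_FeatureSeqSpeakerLabelDataset.__getitem__` is a 4-tuple:
#   f  - feature tensor returned by `feature_loader.process(sample.feature_file)`
#   fl - number of feature frames, i.e. features.shape[0]
#   t  - float tensor of the encoded speaker-label sequence (sample.seq_label)
#   tl - length of that sequence
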
class FeatureToSeqSpeakerLabelDataset(_FeatureSeqSpeakerLabelDataset):
    """
    Dataset that loads tensors via a json file containing paths to feature
    files and sequences of speakers. Each new line is a
    different sample. Example below:
        {"feature_filepath": "/path/to/feature_0.p", "seq_label": speakerA speakerB SpeakerA ....} \
        ...
        {"feature_filepath": "/path/to/feature_n.p", "seq_label": target_seq_label_n}
    target_seq_label_n is the string of sequence of speaker labels, separated by spaces.

    Args:
        manifest_filepath (str): Path to manifest json as described above. Can be comma-separated paths.
        labels (Optional[list]): String containing all the possible labels to map to;
            if None, then automatically picked from the ASRFeatureSequenceLabel collection.
        feature_loader: Feature loader to load (external) features.

    """

    def _collate_fn(self, batch):
        return _feature_collate_fn(batch)

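# Illustrative usage sketch (not part of the original NeMo file); the manifest path
# and the label list below are hypothetical placeholders.
def _example_feature_to_seq_speaker_label_dataset():
    feature_loader = ExternalFeatureLoader(augmentor=None)
    dataset = FeatureToSeqSpeakerLabelDataset(
        manifest_filepath='/path/to/seq_label_manifest.json',
        labels=['speakerA', 'speakerB'],
        feature_loader=feature_loader,
    )
    # Each item is (feature, feature_length, seq_label, seq_label_length);
    # batches are built with `_feature_collate_fn` via `_collate_fn`.
    return dataset
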
class FeatureToLabelDataset(Dataset):
    """
    Dataset that loads tensors via a json file containing paths to feature files and their labels.
    Each new line is a different sample. JSON files should be of the following format:
        {"feature_filepath": "/path/to/audio_feature.pt", "label": "1"}
        ...
        {"feature_filepath": "/path/to/audio_feature.pt", "label": "0"}
    Args:
        manifest_filepath (str): Path to JSON containing data.
        labels (Optional[list]): List of unique labels collected from all samples.
        augmentor (Optional): feature augmentation
        window_length_in_sec (float): Window length in seconds.
        shift_length_in_sec (float): Shift length in seconds.
        is_regression_task (bool): if True, the labels are treated as for a regression task.
        cal_labels_occurrence (bool): if True, the labels occurrence will be calculated.
        zero_spec_db_val (float): Value to replace non-speech signals in log-melspectrogram.
        min_duration (float): Minimum duration of the audio file in seconds.
        max_duration (float): Maximum duration of the audio file in seconds.
    """

    ZERO_LEVEL_SPEC_DB_VAL = -16.635  # Log-Melspectrogram value for zero signal
    FRAME_UNIT_TIME_SECS = 0.01

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
        """
        output_types = {
            'audio_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            'feat_length': NeuralType(tuple('B'), LengthsType()),
            'labels': NeuralType(('B'), LabelsType()),
            'labels_length': NeuralType(tuple('B'), LengthsType()),
        }

        return output_types

    def __init__(
        self,
        *,
        manifest_filepath: str,
        labels: List[str] = None,
        augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
        window_length_in_sec: float = 0.63,
        shift_length_in_sec: float = 0.01,
        is_regression_task: bool = False,
        cal_labels_occurrence: Optional[bool] = False,
        zero_spec_db_val: float = -16.635,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
    ):
        super().__init__()
        self.window_length_in_sec = window_length_in_sec
        self.shift_length_in_sec = shift_length_in_sec
        self.zero_spec_db_val = zero_spec_db_val

        if isinstance(manifest_filepath, str):
            manifest_filepath = manifest_filepath.split(',')

        self.collection = collections.ASRFeatureLabel(
            manifests_files=manifest_filepath,
            is_regression_task=is_regression_task,
            cal_labels_occurrence=cal_labels_occurrence,
            min_duration=min_duration,
            max_duration=max_duration,
        )

        self.feature_loader = ExternalFeatureLoader(augmentor=augmentor)
        self.labels = labels if labels else self.collection.uniq_labels

        self.is_regression_task = is_regression_task

        if not is_regression_task:
            self.labels = labels if labels else self.collection.uniq_labels
            self.num_classes = len(self.labels) if self.labels is not None else 1
            self.label2id, self.id2label = {}, {}
            self.id2occurrence, self.labels_occurrence = {}, []

            for label_id, label in enumerate(self.labels):
                self.label2id[label] = label_id
                self.id2label[label_id] = label
                if cal_labels_occurrence:
                    self.id2occurrence[label_id] = self.collection.labels_occurrence[label]

            if cal_labels_occurrence:
                self.labels_occurrence = [self.id2occurrence[k] for k in sorted(self.id2occurrence)]

            for idx in range(len(self.labels[:5])):
                logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
        else:
            self.labels = []
            self.num_classes = 1

    def __len__(self):
        return len(self.collection)

    def __getitem__(self, index):
        sample = self.collection[index]

        features = self.feature_loader.process(sample.feature_file)
        f, fl = features, torch.tensor(features.shape[1]).long()

        t = torch.tensor(self.label2id[sample.label])
        tl = torch.tensor(1).long()

        return f, fl, t, tl

    def _collate_fn(self, batch):
        return _audio_feature_collate_fn(batch, self.zero_spec_db_val, 0)

    def _vad_segment_collate_fn(self, batch):
        return _vad_feature_segment_collate_fn(
            batch, self.window_length_in_sec, self.shift_length_in_sec, self.FRAME_UNIT_TIME_SECS
        )

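# Illustrative usage sketch (not part of the original NeMo file): wiring
# FeatureToLabelDataset into a PyTorch DataLoader with the dataset's own collate
# function. The manifest path and labels are hypothetical placeholders.
def _example_feature_to_label_dataloader():
    from torch.utils.data import DataLoader

    dataset = FeatureToLabelDataset(
        manifest_filepath='/path/to/feature_label_manifest.json', labels=['0', '1'],
    )
    # `_collate_fn` pads features with `zero_spec_db_val` and labels with 0.
    return DataLoader(dataset, batch_size=4, collate_fn=dataset._collate_fn)
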
class FeatureToMultiLabelDataset(Dataset):
    """
    Dataset that loads tensors via a json file containing paths to feature files and their labels.
    Each new line is a different sample. JSON files should be of the following format:
        {"feature_filepath": "/path/to/audio_feature.pt", "label": "1 1 0 0 1"}
        ...
        {"feature_filepath": "/path/to/audio_feature.pt", "label": "0 1 0 0"}
    Args:
        manifest_filepath (str): Path to JSON containing data.
        labels (Optional[list]): List of unique labels collected from all samples.
        augmentor (Optional): feature augmentation
        delimiter (str): delimiter to split the labels.
        is_regression_task (bool): if True, the labels are treated as for a regression task.
        cal_labels_occurrence (bool): if True, the labels occurrence will be calculated.
        zero_spec_db_val (float): Value to replace non-speech signals in log-melspectrogram.
        min_duration (float): Minimum duration of the audio file in seconds.
        max_duration (float): Maximum duration of the audio file in seconds.
    """

    ZERO_LEVEL_SPEC_DB_VAL = -16.635  # Log-Melspectrogram value for zero signal

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
        """
        output_types = {
            'audio_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            'feat_length': NeuralType(tuple('B'), LengthsType()),
            'labels': NeuralType(('B', 'T'), LabelsType()),
            'labels_length': NeuralType(tuple('B'), LengthsType()),
        }

        return output_types

    def __init__(
        self,
        *,
        manifest_filepath: str,
        labels: List[str] = None,
        augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
        delimiter: Optional[str] = None,
        is_regression_task: bool = False,
        cal_labels_occurrence: Optional[bool] = False,
        zero_spec_db_val: float = -16.635,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
    ):
        super().__init__()
        self.delimiter = delimiter
        self.zero_spec_db_val = zero_spec_db_val

        if isinstance(manifest_filepath, str):
            manifest_filepath = manifest_filepath.split(',')

        self.collection = collections.ASRFeatureLabel(
            manifests_files=manifest_filepath,
            is_regression_task=is_regression_task,
            cal_labels_occurrence=cal_labels_occurrence,
            delimiter=delimiter,
            min_duration=min_duration,
            max_duration=max_duration,
        )

        self.is_regression_task = is_regression_task
        self.feature_loader = ExternalFeatureLoader(augmentor=augmentor)
        self.labels = labels if labels else self.collection.uniq_labels

        self.label2id, self.id2label = {}, {}
        if not is_regression_task:
            self.labels = labels if labels else self._get_label_set()
            self.num_classes = len(self.labels) if self.labels is not None else 1
            self.label2id, self.id2label = {}, {}
            self.id2occurrence, self.labels_occurrence = {}, []
            for label_id, label in enumerate(self.labels):
                self.label2id[label] = label_id
                self.id2label[label_id] = label
                if cal_labels_occurrence:
                    self.id2occurrence[label_id] = self.collection.labels_occurrence[label]
                    self.labels_occurrence.append(self.id2occurrence[label_id])

            for idx in range(len(self.labels[:5])):
                logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
        else:
            self.labels = []
            self.num_classes = 1

    def _get_label_set(self):
        labels = []
        for sample in self.collection:
            label_str = sample.label
            if label_str:
                label_str_list = label_str.split(self.delimiter) if self.delimiter else label_str.split()
                labels.extend(label_str_list)
        return sorted(set(labels))

    def _label_str_to_tensor(self, label_str: str):
        labels = label_str.split(self.delimiter) if self.delimiter else label_str.split()

        if self.is_regression_task:
            labels = [float(s) for s in labels]
            labels = torch.tensor(labels).float()
        else:
            labels = [self.label2id[s] for s in labels]
            labels = torch.tensor(labels).long()
        return labels

    def __len__(self):
        return len(self.collection)

    def __getitem__(self, index):
        sample = self.collection[index]

        features = self.feature_loader.process(sample.feature_file)
        f, fl = features, torch.tensor(features.shape[1]).long()

        t = self._label_str_to_tensor(sample.label)
        tl = torch.tensor(t.size(0)).long()

        return f, fl, t, tl

    def _collate_fn(self, batch):
        return _audio_feature_collate_fn(batch, self.zero_spec_db_val, 0)
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label_dataset.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from nemo.collections.asr.data import feature_to_label


def get_feature_seq_speakerlabel_dataset(
    feature_loader, config: dict
) -> feature_to_label.FeatureToSeqSpeakerLabelDataset:
    """
    Instantiates a FeatureToSeqSpeakerLabelDataset.
    Args:
        feature_loader: Feature loader to load (external) features.
        config: Config of the FeatureToSeqSpeakerLabelDataset.

    Returns:
        An instance of FeatureToSeqSpeakerLabelDataset.
    """
    dataset = feature_to_label.FeatureToSeqSpeakerLabelDataset(
        manifest_filepath=config['manifest_filepath'], labels=config['labels'], feature_loader=feature_loader,
    )
    return dataset


def get_feature_label_dataset(
    config: dict, augmentor: Optional['FeatureAugmentor'] = None
) -> feature_to_label.FeatureToLabelDataset:
    dataset = feature_to_label.FeatureToLabelDataset(
        manifest_filepath=config['manifest_filepath'],
        labels=config['labels'],
        augmentor=augmentor,
        window_length_in_sec=config.get("window_length_in_sec", 0.63),
        shift_length_in_sec=config.get("shift_length_in_sec", 0.08),
        is_regression_task=config.get("is_regression_task", False),
        cal_labels_occurrence=config.get("cal_labels_occurrence", False),
        zero_spec_db_val=config.get("zero_spec_db_val", -16.635),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
    )
    return dataset


def get_feature_multi_label_dataset(
    config: dict, augmentor: Optional['FeatureAugmentor'] = None
) -> feature_to_label.FeatureToMultiLabelDataset:
    dataset = feature_to_label.FeatureToMultiLabelDataset(
        manifest_filepath=config['manifest_filepath'],
        labels=config['labels'],
        augmentor=augmentor,
        delimiter=config.get('delimiter', None),
        is_regression_task=config.get("is_regression_task", False),
        cal_labels_occurrence=config.get("cal_labels_occurrence", False),
        zero_spec_db_val=config.get("zero_spec_db_val", -16.635),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
    )
    return dataset
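# Illustrative usage sketch (not part of the original NeMo file): a minimal config for
# `get_feature_label_dataset`. Only 'manifest_filepath' and 'labels' are read directly;
# every other key falls back to the config.get() defaults above. The manifest path and
# label values are hypothetical placeholders.
def _example_get_feature_label_dataset():
    config = {
        'manifest_filepath': '/path/to/feature_label_manifest.json',
        'labels': ['0', '1'],
        'window_length_in_sec': 0.63,
        'shift_length_in_sec': 0.08,
    }
    return get_feature_label_dataset(config=config, augmentor=None)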
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text.py
ADDED
@@ -0,0 +1,488 @@
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from nemo.collections.asr.data.feature_to_label import _audio_feature_collate_fn
|
20 |
+
from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader
|
21 |
+
from nemo.collections.asr.parts.preprocessing.features import normalize_batch
|
22 |
+
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
|
23 |
+
from nemo.collections.asr.parts.utils.vad_utils import load_speech_segments_from_rttm
|
24 |
+
from nemo.collections.common import tokenizers
|
25 |
+
from nemo.collections.common.parts.preprocessing import collections, parsers
|
26 |
+
from nemo.core.classes import Dataset
|
27 |
+
from nemo.core.neural_types import AcousticEncodedRepresentation, LabelsType, LengthsType, NeuralType
|
28 |
+
|
29 |
+
|
30 |
+
class ASRFeatureManifestProcessor:
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
manifest_filepath: str,
|
34 |
+
parser: Union[str, Callable],
|
35 |
+
max_duration: Optional[float] = None,
|
36 |
+
min_duration: Optional[float] = None,
|
37 |
+
max_utts: int = 0,
|
38 |
+
bos_id: Optional[int] = None,
|
39 |
+
eos_id: Optional[int] = None,
|
40 |
+
pad_id: int = 0,
|
41 |
+
index_by_file_id: bool = False,
|
42 |
+
):
|
43 |
+
self.parser = parser
|
44 |
+
self.collection = collections.ASRFeatureText(
|
45 |
+
manifests_files=manifest_filepath,
|
46 |
+
parser=parser,
|
47 |
+
min_duration=min_duration,
|
48 |
+
max_duration=max_duration,
|
49 |
+
max_number=max_utts,
|
50 |
+
index_by_file_id=index_by_file_id,
|
51 |
+
)
|
52 |
+
|
53 |
+
self.eos_id = eos_id
|
54 |
+
self.bos_id = bos_id
|
55 |
+
self.pad_id = pad_id
|
56 |
+
|
57 |
+
def process_text_by_id(self, index: int) -> Tuple[List[int], int]:
|
58 |
+
sample = self.collection[index]
|
59 |
+
return self.process_text_by_sample(sample)
|
60 |
+
|
61 |
+
def process_text_by_file_id(self, file_id: str) -> Tuple[List[int], int]:
|
62 |
+
manifest_idx = self.collection.mapping[file_id][0]
|
63 |
+
sample = self.collection[manifest_idx]
|
64 |
+
return self.process_text_by_sample(sample)
|
65 |
+
|
66 |
+
def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -> Tuple[List[int], int]:
|
67 |
+
t, tl = sample.text_tokens, len(sample.text_tokens)
|
68 |
+
|
69 |
+
if self.bos_id is not None:
|
70 |
+
t = [self.bos_id] + t
|
71 |
+
tl += 1
|
72 |
+
if self.eos_id is not None:
|
73 |
+
t = t + [self.eos_id]
|
74 |
+
tl += 1
|
75 |
+
|
76 |
+
return t, tl
|
77 |
+
|
78 |
+
|
79 |
+
class _FeatureTextDataset(Dataset):
|
80 |
+
"""
|
81 |
+
Dataset that loads tensors via a json file containing paths to audio feature files, transcripts,
|
82 |
+
durations (in seconds) and optional RTTM files. Each new line is a different sample. Example below:
|
83 |
+
{"feature_filepath": "/path/to/audio_feature.pt", "text_filepath": "/path/to/audio.txt",
|
84 |
+
"rttm_filepath": "/path/to/audio_rttm.rttm", "duration": 23.147}
|
85 |
+
...
|
86 |
+
{"feature_filepath": "/path/to/audio_feature.pt", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
|
87 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
88 |
+
Args:
|
89 |
+
manifest_filepath (str): Path to manifest json as described above. Can be comma-separated paths.
|
90 |
+
parser: Str for a language specific preprocessor or a callable.
|
91 |
+
normalize (str): whether and where to normalize the features; must be one of [None, "post_norm", "pre_norm"]
|
92 |
+
normalize_type (Union[str, dict]): how to normalize feature, see `nemo.collections.asr.parts.preprocessing.features.normalize_batch`
|
93 |
+
use_rttm (bool): whether to use RTTM files if there is any, default to False
|
94 |
+
rttm_mode (str): how to use RTTM files, must be one of ['mask', 'drop'], default to 'mask'
|
95 |
+
feat_min_len (int): minimum length of the features when rttm_mode='drop', default to 4.
|
96 |
+
feat_mask_val (Optional[float]): value used to mask features with RTTM files, default to None to use the zero mel-spectrogram value
|
97 |
+
frame_unit_time_secs (float): time in seconds for each frame
|
98 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
99 |
+
int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
|
100 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor object used to augment loaded audio
|
101 |
+
max_duration (float): If audio exceeds this length, do not include in dataset
|
102 |
+
min_duration (float): If audio is less than this length, do not include in dataset
|
103 |
+
max_utts (int): Limit number of utterances
|
104 |
+
trim (bool): whether or not to trim silence. Defaults to False
|
105 |
+
bos_id (int): Id of beginning of sequence symbol to append if not None
|
106 |
+
eos_id (int): Id of end of sequence symbol to append if not None
|
107 |
+
pad_id (int): Id of pad symbol. Defaults to 0
|
108 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
109 |
+
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
|
110 |
+
"""
|
111 |
+
|
112 |
+
ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
|
113 |
+
NORM_MODES = ["pre_norm", "post_norm"]
|
114 |
+
RTTM_MODES = ["mask", "drop"]
|
115 |
+
|
116 |
+
@property
|
117 |
+
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
118 |
+
"""Returns definitions of module output ports.
|
119 |
+
"""
|
120 |
+
return {
|
121 |
+
'features': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
|
122 |
+
'feature_length': NeuralType(tuple('B'), LengthsType()),
|
123 |
+
'transcripts': NeuralType(('B', 'T'), LabelsType()),
|
124 |
+
'transcript_length': NeuralType(tuple('B'), LengthsType()),
|
125 |
+
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
|
126 |
+
}
|
127 |
+
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
manifest_filepath: str,
|
131 |
+
parser: Union[str, Callable],
|
132 |
+
normalize: Optional[str] = "post_norm",
|
133 |
+
normalize_type: Union[str, dict] = "per_feature",
|
134 |
+
use_rttm: bool = False,
|
135 |
+
rttm_mode: str = "mask",
|
136 |
+
feat_min_len: int = 4,
|
137 |
+
feat_mask_val: Optional[float] = None,
|
138 |
+
frame_unit_time_secs: float = 0.01,
|
139 |
+
sample_rate: Optional[int] = 16000,
|
140 |
+
augmentor: 'nemo.collections.asr.parts.perturb.FeatureAugmentor' = None,
|
141 |
+
max_duration: Optional[int] = None,
|
142 |
+
min_duration: Optional[int] = None,
|
143 |
+
max_utts: int = 0,
|
144 |
+
trim: bool = False,
|
145 |
+
bos_id: Optional[int] = None,
|
146 |
+
eos_id: Optional[int] = None,
|
147 |
+
pad_id: int = 0,
|
148 |
+
return_sample_id: bool = False,
|
149 |
+
channel_selector: Optional[ChannelSelectorType] = None,
|
150 |
+
):
|
151 |
+
if type(manifest_filepath) == str:
|
152 |
+
manifest_filepath = manifest_filepath.split(",")
|
153 |
+
|
154 |
+
self.sample_rate = sample_rate
|
155 |
+
self.normalize = normalize
|
156 |
+
self.normalize_type = normalize_type
|
157 |
+
self.use_rttm = use_rttm
|
158 |
+
self.rttm_mode = rttm_mode
|
159 |
+
if self.use_rttm and self.rttm_mode not in self.RTTM_MODES:
|
160 |
+
raise ValueError(f"`rttm_mode` must be one of {self.RTTM_MODES}, got `{rttm_mode}` instead")
|
161 |
+
|
162 |
+
self.feat_min_len = feat_min_len
|
163 |
+
if feat_mask_val is not None:
|
164 |
+
self.feat_mask_val = feat_mask_val
|
165 |
+
elif normalize == "pre_norm":
|
166 |
+
self.feat_mask_val = 0.0 # similar to SpectralAugmentation
|
167 |
+
else:
|
168 |
+
self.feat_mask_val = self.ZERO_LEVEL_SPEC_DB_VAL
|
169 |
+
|
170 |
+
if normalize is not None and normalize not in self.NORM_MODES:
|
171 |
+
raise ValueError(f"`normalize` must be one of {self.NORM_MODES}, got `{normalize}` instead")
|
172 |
+
|
173 |
+
self.frame_unit_time_secs = frame_unit_time_secs
|
174 |
+
|
175 |
+
self.manifest_processor = ASRFeatureManifestProcessor(
|
176 |
+
manifest_filepath=manifest_filepath,
|
177 |
+
parser=parser,
|
178 |
+
max_duration=max_duration,
|
179 |
+
min_duration=min_duration,
|
180 |
+
max_utts=max_utts,
|
181 |
+
bos_id=bos_id,
|
182 |
+
eos_id=eos_id,
|
183 |
+
pad_id=pad_id,
|
184 |
+
)
|
185 |
+
self.featurizer = ExternalFeatureLoader(augmentor=augmentor)
|
186 |
+
self.trim = trim
|
187 |
+
self.return_sample_id = return_sample_id
|
188 |
+
self.channel_selector = channel_selector
|
189 |
+
|
190 |
+
def get_manifest_sample(self, sample_id):
|
191 |
+
return self.manifest_processor.collection[sample_id]
|
192 |
+
|
193 |
+
def __getitem__(self, index):
|
194 |
+
sample = self.manifest_processor.collection[index]
|
195 |
+
offset = sample.offset
|
196 |
+
|
197 |
+
if offset is None:
|
198 |
+
offset = 0
|
199 |
+
|
200 |
+
features = self.featurizer.process(sample.feature_file)
|
201 |
+
|
202 |
+
f, fl = features, torch.tensor(features.shape[1]).long()
|
203 |
+
|
204 |
+
t, tl = self.manifest_processor.process_text_by_sample(sample=sample)
|
205 |
+
|
206 |
+
# Feature normalization
|
207 |
+
if self.normalize is None:
|
208 |
+
if self.use_rttm and sample.rttm_file:
|
209 |
+
f = self.process_features_with_rttm(f, offset, sample.rttm_file, self.feat_mask_val)
|
210 |
+
elif self.normalize == "post_norm":
|
211 |
+
# (Optional) Masking based on RTTM file
|
212 |
+
if self.use_rttm and sample.rttm_file:
|
213 |
+
f = self.process_features_with_rttm(f, offset, sample.rttm_file, self.feat_mask_val)
|
214 |
+
|
215 |
+
f = self.normalize_feature(f)
|
216 |
+
else: # pre-norm
|
217 |
+
f = self.normalize_feature(f)
|
218 |
+
# (Optional) Masking based on RTTM file
|
219 |
+
if self.use_rttm and sample.rttm_file:
|
220 |
+
f = self.process_features_with_rttm(f, offset, sample.rttm_file, self.feat_mask_val)
|
221 |
+
|
222 |
+
if self.return_sample_id:
|
223 |
+
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
|
224 |
+
else:
|
225 |
+
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
226 |
+
|
227 |
+
return output
|
228 |
+
|
229 |
+
def process_features_with_rttm(self, features, offset, rttm_file, mask_val):
|
230 |
+
segments = load_speech_segments_from_rttm(rttm_file)
|
231 |
+
new_features = features.clone()
|
232 |
+
sid, fid = 0, 0
|
233 |
+
for i in range(features.size(1)):
|
234 |
+
t = offset + i * self.frame_unit_time_secs
|
235 |
+
while sid < len(segments) - 1 and segments[sid][1] < t:
|
236 |
+
sid += 1
|
237 |
+
if segments[sid][1] == 0 or t < segments[sid][0] or t > segments[sid][1]:
|
238 |
+
# not in speech segment
|
239 |
+
if self.rttm_mode == "drop":
|
240 |
+
# drop the frame
|
241 |
+
continue
|
242 |
+
else:
|
243 |
+
# mask the frame with specified value
|
244 |
+
new_features[:, i] = mask_val
|
245 |
+
fid += 1
|
246 |
+
else:
|
247 |
+
# in speech segment
|
248 |
+
new_features[:, fid] = features[:, i]
|
249 |
+
fid += 1
|
250 |
+
|
251 |
+
if fid < self.feat_min_len and self.rttm_mode == "drop":
|
252 |
+
new_features[:, : self.feat_min_len] = mask_val
|
253 |
+
return new_features[:, : self.feat_min_len]
|
254 |
+
return new_features[:, :fid]
|
255 |
+
|
256 |
+
def __len__(self):
|
257 |
+
return len(self.manifest_processor.collection)
|
258 |
+
|
259 |
+
def _collate_fn(self, batch):
|
260 |
+
return _audio_feature_collate_fn(
|
261 |
+
batch, feat_pad_val=self.feat_mask_val, label_pad_id=self.manifest_processor.pad_id
|
262 |
+
)
|
263 |
+
|
264 |
+
def normalize_feature(self, feat):
|
265 |
+
"""
|
266 |
+
Args:
|
267 |
+
feat: feature tensor of shape [M, T]
|
268 |
+
"""
|
269 |
+
feat = feat.unsqueeze(0) # add batch dim
|
270 |
+
feat, _, _ = normalize_batch(feat, torch.tensor([feat.size(-1)]), self.normalize_type)
|
271 |
+
return feat.squeeze(0) # delete batch dim
|
272 |
+
|
273 |
+
|
274 |
+
class FeatureToCharDataset(_FeatureTextDataset):
|
275 |
+
"""
|
276 |
+
Dataset that loads tensors via a json file containing paths to audio feature
|
277 |
+
files, transcripts, durations (in seconds) and optional RTTM files. Each new line is a
|
278 |
+
different sample. Example below:
|
279 |
+
{"feature_filepath": "/path/to/audio_feature.pt", "text_filepath":
|
280 |
+
"/path/to/audio.txt", "duration": 23.147, "rttm_filepath": "/path/to/audio_rttm.rttm",}
|
281 |
+
...
|
282 |
+
{"feature_filepath": "/path/to/audio_feature.pt", "text": "the
|
283 |
+
transcription", "offset": 301.75, "duration": 0.82, "utt":
|
284 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
285 |
+
|
286 |
+
Args:
|
287 |
+
manifest_filepath (str): Path to manifest json as described above. Can
|
288 |
+
be comma-separated paths.
|
289 |
+
labels (str): String containing all the possible characters to map to
|
290 |
+
normalize (str): how to normalize feature, must be one of [None, "post_norm", "pre_norm"]
|
291 |
+
normalize_type (Union[str, dict]): how to normalize feature, see `nemo.collections.asr.parts.preprocessing.features.normalize_batch`
|
292 |
+
use_rttm (bool): whether to use RTTM files if there is any, default to False
|
293 |
+
rttm_mode (str): how to use RTTM files, must be one of ['mask', 'drop'], default to 'mask'
|
294 |
+
feat_min_len (int): minimum length of feature, default to 4
|
295 |
+
feat_mask_val (Optional[float]): value used to mask features with RTTM files, default to None to use the zero mel-spectrogram value
|
296 |
+
frame_unit_time_secs: time in seconds for each frame
|
297 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
298 |
+
int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
|
299 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
|
300 |
+
object used to augment loaded audio
|
301 |
+
max_duration: If audio exceeds this length, do not include in dataset
|
302 |
+
min_duration: If audio is less than this length, do not include
|
303 |
+
in dataset
|
304 |
+
max_utts: Limit number of utterances
|
305 |
+
blank_index: blank character index, default = -1
|
306 |
+
unk_index: unk_character index, default = -1
|
307 |
+
bos_id: Id of beginning of sequence symbol to append if not None
|
308 |
+
eos_id: Id of end of sequence symbol to append if not None
|
309 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
310 |
+
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
|
311 |
+
"""
|
312 |
+
|
313 |
+
def __init__(
|
314 |
+
self,
|
315 |
+
manifest_filepath: str,
|
316 |
+
labels: Union[str, List[str]],
|
317 |
+
normalize: Optional[str] = "post_norm",
|
318 |
+
normalize_type: Union[str, dict] = "per_feature",
|
319 |
+
use_rttm: bool = False,
|
320 |
+
rttm_mode: str = "mask",
|
321 |
+
feat_min_len: int = 4,
|
322 |
+
feat_mask_val: Optional[float] = None,
|
323 |
+
frame_unit_time_secs: float = 0.01,
|
324 |
+
sample_rate: Optional[int] = 16000,
|
325 |
+
augmentor: 'nemo.collections.asr.parts.perturb.FeatureAugmentor' = None,
|
326 |
+
max_duration: Optional[int] = None,
|
327 |
+
min_duration: Optional[int] = None,
|
328 |
+
max_utts: int = 0,
|
329 |
+
blank_index: int = -1,
|
330 |
+
unk_index: int = -1,
|
331 |
+
trim: bool = False,
|
332 |
+
bos_id: Optional[int] = None,
|
333 |
+
eos_id: Optional[int] = None,
|
334 |
+
pad_id: int = 0,
|
335 |
+
parser: Union[str, Callable] = 'en',
|
336 |
+
return_sample_id: bool = False,
|
337 |
+
channel_selector: Optional[ChannelSelectorType] = None,
|
338 |
+
):
|
339 |
+
self.labels = labels
|
340 |
+
|
341 |
+
parser = parsers.make_parser(
|
342 |
+
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
|
343 |
+
)
|
344 |
+
|
345 |
+
super().__init__(
|
346 |
+
manifest_filepath=manifest_filepath,
|
347 |
+
parser=parser,
|
348 |
+
normalize=normalize,
|
349 |
+
normalize_type=normalize_type,
|
350 |
+
use_rttm=use_rttm,
|
351 |
+
rttm_mode=rttm_mode,
|
352 |
+
feat_min_len=feat_min_len,
|
353 |
+
feat_mask_val=feat_mask_val,
|
354 |
+
frame_unit_time_secs=frame_unit_time_secs,
|
355 |
+
sample_rate=sample_rate,
|
356 |
+
augmentor=augmentor,
|
357 |
+
max_duration=max_duration,
|
358 |
+
min_duration=min_duration,
|
359 |
+
max_utts=max_utts,
|
360 |
+
trim=trim,
|
361 |
+
bos_id=bos_id,
|
362 |
+
eos_id=eos_id,
|
363 |
+
pad_id=pad_id,
|
364 |
+
return_sample_id=return_sample_id,
|
365 |
+
channel_selector=channel_selector,
|
366 |
+
)
|
367 |
+
|
368 |
+
|
369 |
+
class FeatureToBPEDataset(_FeatureTextDataset):
|
370 |
+
"""
|
371 |
+
Dataset that loads tensors via a json file containing paths to audio feature
|
372 |
+
files, transcripts, durations (in seconds) and optional RTTM files. Each new line is a different sample.
|
373 |
+
Example below:
|
374 |
+
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
375 |
+
"/path/to/audio.txt", "duration": 23.147, "rttm_filepath": "/path/to/audio_rttm.rttm",}
|
376 |
+
...
|
377 |
+
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
378 |
+
transcription", "offset": 301.75, "duration": 0.82, "utt":
|
379 |
+
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
380 |
+
|
381 |
+
In practice, the dataset and manifest used for character encoding and byte pair encoding
|
382 |
+
are exactly the same. The only difference lies in how the dataset tokenizes the text in
|
383 |
+
the manifest.
|
384 |
+
|
385 |
+
Args:
|
386 |
+
manifest_filepath (str): Path to manifest json as described above. Can
|
387 |
+
be comma-separated paths.
|
388 |
+
tokenizer: A subclass of the Tokenizer wrapper found in the common collection,
|
389 |
+
nemo.collections.common.tokenizers.TokenizerSpec. ASR Models support a subset of
|
390 |
+
all available tokenizers.
|
391 |
+
normalize (str): how to normalize feature, must be one of [None, "post_norm", "pre_norm"]
|
392 |
+
normalize_type (Union[str, dict]): how to normalize feature, see `nemo.collections.asr.parts.preprocessing.features.normalize_batch`
|
393 |
+
use_rttm (bool): whether to use RTTM files if there is any, default to False
|
394 |
+
rttm_mode (str): how to use RTTM files, must be one of ['mask', 'drop'], default to 'mask'
|
395 |
+
feat_min_len (int): minimum length of feature, default to 4
|
396 |
+
feat_mask_val (Optional[float]): value used to mask features with RTTM files, default to None to use the zero mel-spectrogram value
|
397 |
+
frame_unit_time_secs: time in seconds for each frame
|
398 |
+
sample_rate (int): Sample rate to resample loaded audio to
|
399 |
+
int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
|
400 |
+
augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
|
401 |
+
object used to augment loaded audio
|
402 |
+
max_duration: If audio exceeds this length, do not include in dataset
|
403 |
+
min_duration: If audio is less than this length, do not include
|
404 |
+
in dataset
|
405 |
+
max_utts: Limit number of utterances
|
406 |
+
trim: Whether to trim silence segments
|
407 |
+
use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
|
408 |
+
tokens to beginning and ending of speech respectively.
|
409 |
+
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
410 |
+
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
|
411 |
+
"""
|
412 |
+
|
413 |
+
def __init__(
|
414 |
+
self,
|
415 |
+
manifest_filepath: str,
|
416 |
+
tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
|
417 |
+
normalize: Optional[str] = "post_norm",
|
418 |
+
normalize_type: Union[str, dict] = "per_feature",
|
419 |
+
use_rttm: bool = False,
|
420 |
+
rttm_mode: str = "mask",
|
421 |
+
feat_min_len: int = 4,
|
422 |
+
feat_mask_val: Optional[float] = None,
|
423 |
+
frame_unit_time_secs: float = 0.01,
|
424 |
+
sample_rate: Optional[int] = 16000,
|
425 |
+
augmentor: 'nemo.collections.asr.parts.perturb.FeatureAugmentor' = None,
|
426 |
+
max_duration: Optional[int] = None,
|
427 |
+
min_duration: Optional[int] = None,
|
428 |
+
max_utts: int = 0,
|
429 |
+
use_start_end_token: bool = True,
|
430 |
+
trim: bool = False,
|
431 |
+
return_sample_id: bool = False,
|
432 |
+
channel_selector: Optional[ChannelSelectorType] = None,
|
433 |
+
):
|
434 |
+
if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
|
435 |
+
bos_id = tokenizer.bos_id
|
436 |
+
else:
|
437 |
+
bos_id = None
|
438 |
+
|
439 |
+
if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0:
|
440 |
+
eos_id = tokenizer.eos_id
|
441 |
+
else:
|
442 |
+
eos_id = None
|
443 |
+
|
444 |
+
if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0:
|
445 |
+
pad_id = tokenizer.pad_id
|
446 |
+
else:
|
447 |
+
pad_id = 0
|
448 |
+
|
449 |
+
class TokenizerWrapper:
|
450 |
+
def __init__(self, tokenizer):
|
451 |
+
if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer):
|
452 |
+
self.is_aggregate = True
|
453 |
+
else:
|
454 |
+
self.is_aggregate = False
|
455 |
+
self._tokenizer = tokenizer
|
456 |
+
|
457 |
+
def __call__(self, *args):
|
458 |
+
if isinstance(args[0], List) and self.is_aggregate:
|
459 |
+
t = []
|
460 |
+
for span in args[0]:
|
461 |
+
t.extend(self._tokenizer.text_to_ids(span['str'], span['lang']))
|
462 |
+
return t
|
463 |
+
|
464 |
+
t = self._tokenizer.text_to_ids(*args)
|
465 |
+
return t
|
466 |
+
|
467 |
+
super().__init__(
|
468 |
+
manifest_filepath=manifest_filepath,
|
469 |
+
parser=TokenizerWrapper(tokenizer),
|
470 |
+
normalize=normalize,
|
471 |
+
normalize_type=normalize_type,
|
472 |
+
use_rttm=use_rttm,
|
473 |
+
rttm_mode=rttm_mode,
|
474 |
+
feat_min_len=feat_min_len,
|
475 |
+
feat_mask_val=feat_mask_val,
|
476 |
+
frame_unit_time_secs=frame_unit_time_secs,
|
477 |
+
sample_rate=sample_rate,
|
478 |
+
augmentor=augmentor,
|
479 |
+
max_duration=max_duration,
|
480 |
+
min_duration=min_duration,
|
481 |
+
max_utts=max_utts,
|
482 |
+
trim=trim,
|
483 |
+
bos_id=bos_id,
|
484 |
+
eos_id=eos_id,
|
485 |
+
pad_id=pad_id,
|
486 |
+
return_sample_id=return_sample_id,
|
487 |
+
channel_selector=channel_selector,
|
488 |
+
)
|
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text_dataset.py
ADDED
@@ -0,0 +1,94 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nemo.collections.asr.data.feature_to_text import FeatureToBPEDataset, FeatureToCharDataset
from nemo.utils import logging


def get_char_dataset(config: dict, augmentor: Optional['FeatureAugmentor'] = None) -> FeatureToCharDataset:
    """
    Instantiates a Character Encoding based FeatureToCharDataset.

    Args:
        config: Config of the FeatureToCharDataset.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of FeatureToCharDataset.
    """
    if 'labels' not in config:
        logging.warning("dataset does not have explicitly defined labels")

    dataset = FeatureToCharDataset(
        manifest_filepath=config['manifest_filepath'],
        labels=config.get('labels', None),
        normalize=config.get('normalize', 'post_norm'),
        normalize_type=config.get('normalize_type', 'per_feature'),
        use_rttm=config.get('use_rttm', False),
        rttm_mode=config.get('rttm_mode', 'mask'),
        feat_min_len=config.get('feat_min_len', 4),
        feat_mask_val=config.get('feat_mask_val', None),
        frame_unit_time_secs=config.get('frame_unit_time_secs', 0.01),
        sample_rate=config.get('sample_rate', 16000),
        augmentor=augmentor,
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        blank_index=config.get('blank_index', -1),
        unk_index=config.get('unk_index', -1),
        trim=config.get('trim_silence', False),
        parser=config.get('parser', 'en'),
        return_sample_id=config.get('return_sample_id', False),
        channel_selector=config.get('channel_selector', None),
    )
    return dataset


def get_bpe_dataset(
    config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['FeatureAugmentor'] = None
) -> FeatureToBPEDataset:
    """
    Instantiates a Byte Pair Encoding / Word Piece Encoding based FeatureToBPEDataset.

    Args:
        config: Config of the FeatureToBPEDataset.
        tokenizer: An instance of a TokenizerSpec object.
        augmentor: Optional FeatureAugmentor object for augmentations on audio features.

    Returns:
        An instance of FeatureToBPEDataset.
    """
    dataset = FeatureToBPEDataset(
        manifest_filepath=config['manifest_filepath'],
        tokenizer=tokenizer,
        normalize=config.get('normalize', 'post_norm'),
        normalize_type=config.get('normalize_type', 'per_feature'),
        use_rttm=config.get('use_rttm', False),
        rttm_mode=config.get('rttm_mode', 'mask'),
        feat_min_len=config.get('feat_min_len', 4),
        feat_mask_val=config.get('feat_mask_val', None),
        frame_unit_time_secs=config.get('frame_unit_time_secs', 0.01),
        sample_rate=config.get('sample_rate', 16000),
        augmentor=augmentor,
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        trim=config.get('trim_silence', False),
        use_start_end_token=config.get('use_start_end_token', True),
        return_sample_id=config.get('return_sample_id', False),
        channel_selector=config.get('channel_selector', None),
    )
    return dataset
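# Illustrative usage sketch (not part of the original NeMo file): a minimal config for
# `get_char_dataset`. The manifest path and label set are hypothetical placeholders;
# unspecified keys fall back to the config.get() defaults above.
def _example_get_char_dataset():
    config = {
        'manifest_filepath': '/path/to/feature_text_manifest.json',
        'labels': [" ", "a", "b", "c"],
        'normalize': 'post_norm',
        'use_rttm': False,
    }
    return get_char_dataset(config=config, augmentor=None)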
SoundScribe/SpeakerID/nemo/collections/asr/data/text_to_text.py
ADDED
@@ -0,0 +1,482 @@
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import concurrent.futures
|
18 |
+
import copy
|
19 |
+
import gc
|
20 |
+
import json
|
21 |
+
import math
|
22 |
+
import random
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Union
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
import torch.utils.data
|
29 |
+
from torch.nn.utils.rnn import pad_sequence
|
30 |
+
from tqdm.auto import tqdm
|
31 |
+
|
32 |
+
from nemo.collections.asr.data.audio_to_text import _speech_collate_fn
|
33 |
+
from nemo.collections.common.tokenizers import TokenizerSpec
|
34 |
+
from nemo.core.classes import Dataset, IterableDataset
|
35 |
+
from nemo.utils import logging
|
36 |
+
|
37 |
+
try:
|
38 |
+
from nemo_text_processing.text_normalization.normalize import Normalizer
|
39 |
+
except Exception as e:
|
40 |
+
pass # Normalizer imported only for annotation purposes, error can be ignored
|
41 |
+
|
42 |
+
AnyPath = Union[Path, str]
|
43 |
+
|
44 |
+
|
45 |
+
class TextToTextItem(NamedTuple):
|
46 |
+
tts_text: torch.Tensor # normalized and tokenized text for TTS
|
47 |
+
transcript: torch.Tensor # tokenized text for ASR
|
48 |
+
speaker: int # speaker id for multi-speaker TTS
|
49 |
+
|
50 |
+
|
51 |
+
class TextToTextBatch(NamedTuple):
|
52 |
+
tts_texts: torch.Tensor # tokenized texts for tts
|
53 |
+
tts_text_lengths: torch.Tensor
|
54 |
+
transcripts: torch.Tensor # tokenized texts for ASR
|
55 |
+
transcript_lengths: torch.Tensor
|
56 |
+
speakers: torch.Tensor # speaker ids for multi-speaker TTS
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def collate_fn(batch: List[TextToTextItem], asr_pad_id: int, tts_text_pad_id: int) -> TextToTextBatch:
|
60 |
+
return TextToTextBatch(
|
61 |
+
tts_texts=pad_sequence([item.tts_text for item in batch], batch_first=True, padding_value=tts_text_pad_id),
|
62 |
+
tts_text_lengths=torch.tensor([item.tts_text.shape[0] for item in batch]).long(),
|
63 |
+
transcripts=pad_sequence([item.transcript for item in batch], batch_first=True, padding_value=asr_pad_id),
|
64 |
+
transcript_lengths=torch.tensor([item.transcript.shape[0] for item in batch]).long(),
|
65 |
+
speakers=torch.tensor([item.speaker for item in batch]).long(),
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
class TextOrAudioToTextBatch(NamedTuple):
    audio_signals: torch.Tensor
    audio_signal_lengths: torch.Tensor
    tts_texts: torch.Tensor
    tts_text_lengths: torch.Tensor
    speakers: torch.Tensor
    transcripts: torch.Tensor
    transcript_lengths: torch.Tensor

    @staticmethod
    def collate_fn(
        batch: List[Union[TextToTextItem, tuple]], tts_text_pad_id: int, asr_pad_id: int
    ) -> Union[TextToTextBatch, TextOrAudioToTextBatch, tuple]:
        """
        Collate function for the dataloader.
        Accepts a mixed batch of text-to-text items and audio-text items (typical for ASR).
        """
        text_items: List[TextToTextItem] = [item for item in batch if isinstance(item, TextToTextItem)]
        if not text_items:
            # pure audio-text batch
            return _speech_collate_fn(batch=batch, pad_id=asr_pad_id)

        asr_items = [item for item in batch if not isinstance(item, TextToTextItem)]

        if not asr_items:
            # pure text-to-text batch
            return TextToTextBatch.collate_fn(batch=text_items, asr_pad_id=asr_pad_id, tts_text_pad_id=tts_text_pad_id)

        # mixed batch

        # each asr item is a tuple:
        # audio_signal (0), audio_length (1), transcript (2), transcript_length (3), sample_id (4, optional)
        audio_signals = pad_sequence([item[0] for item in asr_items], batch_first=True, padding_value=0.0)
        audio_signal_lengths = torch.tensor([item[1] for item in asr_items]).long()

        tts_texts = pad_sequence(
            [item.tts_text for item in text_items], batch_first=True, padding_value=tts_text_pad_id
        )
        tts_text_lengths = torch.tensor([item.tts_text.shape[0] for item in text_items]).long()
        speakers = torch.tensor([item.speaker for item in text_items]).long()

        transcripts = pad_sequence(
            [item.transcript for item in text_items] + [item[2] for item in asr_items],
            batch_first=True,
            padding_value=asr_pad_id,
        )
        transcript_lengths = torch.tensor(
            [item.transcript.shape[0] for item in text_items] + [item[3] for item in asr_items]
        ).long()

        return TextOrAudioToTextBatch(
            audio_signals=audio_signals,
            audio_signal_lengths=audio_signal_lengths,
            tts_texts=tts_texts,
            tts_text_lengths=tts_text_lengths,
            speakers=speakers,
            transcripts=transcripts,
            transcript_lengths=transcript_lengths,
        )


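# Minimal sketch of the mixed-batch path handled above: one plain ASR tuple
# (audio_signal, audio_length, transcript, transcript_length) plus one TextToTextItem.
# All values are made up for illustration; lengths may also arrive as 0-dim tensors in practice.
def _example_collate_mixed() -> TextOrAudioToTextBatch:
    asr_sample = (torch.zeros(16000), 16000, torch.tensor([1, 2, 3]), 3)
    text_sample = TextToTextItem(tts_text=torch.tensor([4, 5]), transcript=torch.tensor([6]), speaker=0)
    # Transcripts from both sources are padded into a single tensor; audio fields cover only ASR items.
    return TextOrAudioToTextBatch.collate_fn(batch=[text_sample, asr_sample], tts_text_pad_id=0, asr_pad_id=0)

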
def _asr_text_to_tokens(text: str) -> np.ndarray:
    """
    Helper function for ASR tokenization, used only with a multiprocessing pool.
    Must be defined at module (top) level.
    Expects asr_tokenizer_global, asr_bos_id_global, asr_eos_id_global to exist in the current pool process.
    """
    ids = asr_tokenizer_global.text_to_ids(text)
    if asr_bos_id_global is not None:
        ids = [asr_bos_id_global] + ids
    if asr_eos_id_global is not None:
        ids.append(asr_eos_id_global)
    return np.asarray(ids)


def _tts_text_to_tokens(text: str) -> np.ndarray:
    """
    Helper function for TTS tokenization, used only with a multiprocessing pool.
    Must be defined at module (top) level.
    Expects tts_tokenizer_global to exist in the current pool process.
    """
    return np.asarray(tts_tokenizer_global(text))


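# Minimal sketch of the pool pattern the helpers above rely on: an initializer stores a deep copy
# of the tokenizer in a process-global variable, so pool.map only has to pickle the input strings.
# The helper below is illustrative; it assumes the default "fork" start method (as on Linux),
# otherwise the nested initializer would itself need to live at module scope.
def _example_parallel_tts_tokenization(texts: List[str], tts_parser: Callable, workers: int = 2) -> List[np.ndarray]:
    def _init(parser):
        global tts_tokenizer_global
        tts_tokenizer_global = copy.deepcopy(parser)

    with concurrent.futures.ProcessPoolExecutor(
        initializer=_init, initargs=(tts_parser,), max_workers=workers
    ) as pool:
        return list(pool.map(_tts_text_to_tokens, texts, chunksize=100))

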
def _iterate_manifest(filepath: AnyPath) -> Iterable[Dict[str, Any]]:
    """
    Helper function to iterate over a manifest file (one JSON record per line).
    """
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            yield record


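# Sketch of the manifest fields this module relies on (values are illustrative; real NeMo
# manifests typically carry extra keys, e.g. audio paths, which are simply ignored here):
#     {"text": "twenty seven dollars", "tts_text": "$27", "tts_text_normalized": "twenty seven dollars"}
# "text" is tokenized for ASR; "tts_text_normalized" is preferred for TTS, and a bare "tts_text"
# triggers slow on-the-fly normalization in TextToTextDatasetBase below.

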
class TextToTextDatasetBase:
    """
    Base class for loading text-to-text manifests.
    Map-style and iterable datasets should inherit from this class.
    """

    asr_pad_id: int
    tts_text_pad_id: int
    asr_bos_id: Optional[int] = None
    asr_eos_id: Optional[int] = None
    data: List[Dict[str, Any]]

    def __init__(
        self,
        manifest_filepath: Union[AnyPath, List[AnyPath]],
        speakers_filepath: Union[AnyPath, List[AnyPath]],
        asr_tokenizer: TokenizerSpec,
        asr_use_start_end_token: bool,
        tts_parser: Callable,
        tts_text_pad_id: int,
        tts_text_normalizer: "Normalizer",
        tts_text_normalizer_call_kwargs: Dict,
        min_words: int = 1,
        max_words: int = 1_000_000,
        tokenizer_workers: int = 1,
        num_parts: int = 1,
        current_part_index: int = 0,
    ):
        super().__init__()
        # ASR tokenizer setup
        if asr_use_start_end_token and hasattr(asr_tokenizer, 'bos_token'):
            self.asr_bos_id = asr_tokenizer.bos_id

        if asr_use_start_end_token and hasattr(asr_tokenizer, 'eos_token'):
            self.asr_eos_id = asr_tokenizer.eos_id

        if hasattr(asr_tokenizer, 'pad_token'):
            self.asr_pad_id = asr_tokenizer.pad_id
        else:
            self.asr_pad_id = 0

        self.asr_tokenizer = asr_tokenizer

        # TTS tokenizer setup
        self.tts_parser = tts_parser
        self.tts_normalizer = tts_text_normalizer
        self.tts_normalizer_kwargs = tts_text_normalizer_call_kwargs
        self.tts_text_pad_id = tts_text_pad_id

        # Load speakers
        if isinstance(speakers_filepath, str):
            speakers_filepath = speakers_filepath.split(",")
        elif isinstance(speakers_filepath, Path):
            speakers_filepath = [speakers_filepath]
        speakers: Set[int] = set()
        for filepath in speakers_filepath:
            with open(Path(filepath).expanduser(), "r") as f:
                speakers.update(map(int, f.read().split()))
        self.speakers = np.asarray(sorted(speakers))
        logging.info(f"Loaded {len(self.speakers)} speakers")

        # Load manifest
        if isinstance(manifest_filepath, str):
            manifest_filepath = manifest_filepath.split(",")
        elif isinstance(manifest_filepath, Path):
            manifest_filepath = [manifest_filepath]
        self.manifest_paths = [Path(filepath) for filepath in manifest_filepath]

        num_skipped_words = 0
        num_skipped_utterances = 0
        asr_texts = []
        tts_texts = []
        need_normalization = False

        for manifest_path in self.manifest_paths:
            for tmp_item in tqdm(_iterate_manifest(manifest_path)):
                text = tmp_item["text"]
                num_words = len(text.split())
                # skip if the number of words is not in the desired range
                # TODO: maybe it would be valuable to sample sub-utterances from long utterances
                if not (min_words <= num_words <= max_words):
                    num_skipped_words += num_words
                    num_skipped_utterances += 1
                    continue
                asr_texts.append(tmp_item["text"])
                if "tts_text_normalized" in tmp_item:
                    tts_texts.append(tmp_item["tts_text_normalized"])
                else:
                    tts_texts.append(tmp_item["tts_text"])
                    need_normalization = True

        if need_normalization:
            logging.warning("TTS normalization is extremely slow! It is recommended to normalize TTS text in advance")

        if num_skipped_utterances:
            logging.warning(f"Skipped {num_skipped_utterances} utterances with {num_skipped_words} words")

        num_utterances = len(asr_texts)
        # preprocessing is very costly; if only a part of the data is needed, remove unnecessary utterances
        if num_parts > 1:
            # NB: floor division, the full dataset can contain fewer utterances than the original, like in a tarred dataset
            num_utterances_part = num_utterances // num_parts
            start = num_utterances_part * current_part_index
            end = start + num_utterances_part
            logging.info(
                f"Taking part {current_part_index} of {num_parts} of the dataset: utterances {start} to {end}"
            )
            asr_texts = asr_texts[start:end]
            tts_texts = tts_texts[start:end]
            num_utterances = num_utterances_part

        self.data = [dict() for _ in range(num_utterances)]

        if len(asr_texts) == 0:
            # no data was loaded
            logging.warning("Text-to-text dataset is empty")
            return

        if tokenizer_workers == 1:
            logging.warning(
                "Preprocessing large text with tokenizer_workers=1 may be slow with TTS tokenizer. "
                "Prefer tokenizer_workers=(num_cpu_cores/num_gpus_per_node)"
            )
            for i, tokenized_text in enumerate(
                tqdm((self._asr_text_to_tokens(text) for text in asr_texts), total=len(asr_texts))
            ):
                self.data[i]["asr_text_tokens"] = tokenized_text
        else:
            # Multiprocessing hack: use global variables for every worker process (per-process, not truly global to the program)
            def _init_asr_tokenize_process(tokenizer, bos_id, eos_id):
                global asr_tokenizer_global, asr_bos_id_global, asr_eos_id_global  # process-global
                # deepcopy to avoid serialization of parent models
                asr_tokenizer_global = copy.deepcopy(tokenizer)
                asr_bos_id_global = copy.deepcopy(bos_id)
                asr_eos_id_global = copy.deepcopy(eos_id)

            with concurrent.futures.ProcessPoolExecutor(
                initializer=_init_asr_tokenize_process,
                initargs=(asr_tokenizer, self.asr_bos_id, self.asr_eos_id),
                max_workers=tokenizer_workers,
            ) as pool:
                # chunk size for pool map is empirically chosen as a trade-off between speed and responsiveness
                for i, tokenized_text in enumerate(
                    tqdm(pool.map(_asr_text_to_tokens, asr_texts, chunksize=1000), total=len(asr_texts))
                ):
                    self.data[i]["asr_text_tokens"] = tokenized_text
        # force free memory
        del asr_texts
        gc.collect()

        if tokenizer_workers == 1:
            logging.warning(
                "Preprocessing large text with tokenizer_workers=1 may be slow with TTS tokenizer. "
                "Prefer tokenizer_workers=(num_cpu_cores/num_gpus_per_node)"
            )
            for i, tokenized_text in enumerate(
                tqdm(
                    (self._tts_text_to_tokens(text, normalize=need_normalization) for text in tts_texts),
                    total=len(tts_texts),
                )
            ):
                self.data[i]["tts_text_tokens"] = tokenized_text
        else:
            if need_normalization:
                # TODO: implement, if we really need normalization in place
                raise NotImplementedError(
                    "Normalization with tokenizer_workers > 1 is not implemented. "
                    "It is not recommended to use normalization on the fly at all, since it's extremely slow"
                )

            def _init_tts_tokenize_process(tokenizer):
                global tts_tokenizer_global  # process-global
                tts_tokenizer_global = copy.deepcopy(tokenizer)

            with concurrent.futures.ProcessPoolExecutor(
                initializer=_init_tts_tokenize_process, initargs=(tts_parser,), max_workers=tokenizer_workers,
            ) as pool:
                # chunk size for pool map is empirically chosen as a trade-off between speed and responsiveness
                for i, tokenized_text in enumerate(
                    tqdm(pool.map(_tts_text_to_tokens, tts_texts, chunksize=1000), total=len(tts_texts))
                ):
                    self.data[i]["tts_text_tokens"] = tokenized_text
        # force free memory
        del tts_texts
        gc.collect()

    def _asr_text_to_tokens(self, text: str) -> np.ndarray:
        ids = self.asr_tokenizer.text_to_ids(text)
        if self.asr_bos_id is not None:
            ids = [self.asr_bos_id] + ids
        if self.asr_eos_id is not None:
            ids.append(self.asr_eos_id)
        return np.asarray(ids)

    def _tts_text_to_tokens(self, text: str, normalize=True) -> np.ndarray:
        if normalize:
            text = self.tts_normalizer.normalize(text, **self.tts_normalizer_kwargs)
        tokens = self.tts_parser(text)
        return np.asarray(tokens)

    def __getitem__(self, index):
        item = self.data[index]
        return TextToTextItem(
            transcript=torch.from_numpy(item["asr_text_tokens"]).long(),
            tts_text=torch.from_numpy(item["tts_text_tokens"]).long(),
            speaker=random.choice(self.speakers),
        )

    def __len__(self):
        return len(self.data)


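# Sketch of how one might choose tokenizer_workers, following the warning emitted above
# ("Prefer tokenizer_workers=(num_cpu_cores/num_gpus_per_node)"). The helper name and the
# fallbacks are illustrative, not part of any NeMo API.
def _example_num_tokenizer_workers() -> int:
    import os

    num_cpus = os.cpu_count() or 1
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    return max(1, num_cpus // max(1, num_gpus))

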
class TextToTextDataset(TextToTextDatasetBase, Dataset):
    """Text-to-Text Map-style Dataset for hybrid ASR-TTS models"""

    def __init__(
        self,
        manifest_filepath: Union[AnyPath, List[AnyPath]],
        speakers_filepath: Union[AnyPath, List[AnyPath]],
        asr_tokenizer: TokenizerSpec,
        asr_use_start_end_token: bool,
        tts_parser: Callable,
        tts_text_pad_id: int,
        tts_text_normalizer: "Normalizer",
        tts_text_normalizer_call_kwargs: Dict,
        min_words: int = 1,
        max_words: int = 1_000_000,
        tokenizer_workers: int = 1,
    ):
        super().__init__(
            manifest_filepath=manifest_filepath,
            speakers_filepath=speakers_filepath,
            asr_tokenizer=asr_tokenizer,
            asr_use_start_end_token=asr_use_start_end_token,
            tts_parser=tts_parser,
            tts_text_pad_id=tts_text_pad_id,
            tts_text_normalizer=tts_text_normalizer,
            tts_text_normalizer_call_kwargs=tts_text_normalizer_call_kwargs,
            min_words=min_words,
            max_words=max_words,
            tokenizer_workers=tokenizer_workers,
            num_parts=1,
        )

    def collate_fn(
        self, batch: List[Union[TextToTextItem, tuple]]
    ) -> Union[TextToTextBatch, TextOrAudioToTextBatch, tuple]:
        """
        Collate function for the dataloader.
        Accepts a mixed batch of text-to-text items and audio-text items (typical for ASR).
        """
        return TextOrAudioToTextBatch.collate_fn(
            batch=batch, asr_pad_id=self.asr_pad_id, tts_text_pad_id=self.tts_text_pad_id
        )


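# Minimal sketch of wiring the map-style dataset into a DataLoader; "dataset" stands for an
# already constructed TextToTextDataset and the batch size is arbitrary. The key detail is
# passing dataset.collate_fn so that mixed text/audio batches are padded consistently.
def _example_dataloader(dataset: TextToTextDataset) -> torch.utils.data.DataLoader:
    return torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=dataset.collate_fn)

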
class TextToTextIterableDataset(TextToTextDatasetBase, IterableDataset):
    """
    Text-to-Text Iterable Dataset for hybrid ASR-TTS models.
    Only the part necessary for the current process should be loaded and stored.
    """

    def __init__(
        self,
        manifest_filepath: Union[AnyPath, List[AnyPath]],
        speakers_filepath: Union[AnyPath, List[AnyPath]],
        asr_tokenizer: TokenizerSpec,
        asr_use_start_end_token: bool,
        tts_parser: Callable,
        tts_text_pad_id: int,
        tts_text_normalizer: "Normalizer",
        tts_text_normalizer_call_kwargs: Dict,
        min_words: int = 1,
        max_words: int = 1_000_000,
        tokenizer_workers: int = 1,
        num_parts: int = 1,
        current_part_index: int = 0,
    ):
        super().__init__(
            manifest_filepath=manifest_filepath,
            speakers_filepath=speakers_filepath,
            asr_tokenizer=asr_tokenizer,
            asr_use_start_end_token=asr_use_start_end_token,
            tts_parser=tts_parser,
            tts_text_pad_id=tts_text_pad_id,
            tts_text_normalizer=tts_text_normalizer,
            tts_text_normalizer_call_kwargs=tts_text_normalizer_call_kwargs,
            min_words=min_words,
            max_words=max_words,
            tokenizer_workers=tokenizer_workers,
            num_parts=num_parts,
            current_part_index=current_part_index,
        )

    def __iter__(self):
        # Implementation based on docs: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            start = 0
            end = len(self)
        else:  # in a worker process
            # split the workload across workers
            per_worker = int(math.ceil(len(self) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            start = worker_id * per_worker
            end = min(start + per_worker, len(self))
        indices = np.arange(start, end)
        np.random.shuffle(indices)
        return map(self.__getitem__, indices)

    def collate_fn(
        self, batch: List[Union[TextToTextItem, tuple]]
    ) -> Union[TextToTextBatch, TextOrAudioToTextBatch, tuple]:
        """
        Collate function for the dataloader.
        Accepts a mixed batch of text-to-text items and audio-text items (typical for ASR).
        """
        return TextOrAudioToTextBatch.collate_fn(
            batch=batch, asr_pad_id=self.asr_pad_id, tts_text_pad_id=self.tts_text_pad_id
        )


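# Sketch of how the iterable variant is typically sharded: each data-parallel rank loads only its
# part via num_parts/current_part_index, and __iter__ then splits that part further across
# DataLoader workers. The rank/world-size lookups and the kwargs passthrough are illustrative.
def _example_rank_sharded_dataset(**common_kwargs) -> TextToTextIterableDataset:
    import torch.distributed as dist

    if dist.is_available() and dist.is_initialized():
        world_size, rank = dist.get_world_size(), dist.get_rank()
    else:
        world_size, rank = 1, 0
    return TextToTextIterableDataset(num_parts=world_size, current_part_index=rank, **common_kwargs)
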
SoundScribe/SpeakerID/nemo/collections/asr/losses/__init__.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
from nemo.collections.asr.losses.audio_losses import SDRLoss
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.losses.lattice_losses import LatticeLoss
from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss
from nemo.collections.asr.losses.ssl_losses.ctc import CTCLossForSSL
from nemo.collections.asr.losses.ssl_losses.mlm import MLMLoss
from nemo.collections.asr.losses.ssl_losses.rnnt import RNNTLossForSSL
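# The package __init__ above re-exports the loss classes, so downstream code can import them
# from the package root; a minimal sketch (constructor arguments intentionally omitted):
#
#     from nemo.collections.asr.losses import AngularSoftmaxLoss, CTCLoss
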
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (784 Bytes).
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/angularloss.cpython-310.pyc
ADDED
Binary file (2.43 kB).
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/audio_losses.cpython-310.pyc
ADDED
Binary file (10.9 kB).
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/ctc.cpython-310.pyc
ADDED
Binary file (2.29 kB).