crystal-technologies committed
Commit 2d8da09
1 parent: c1bb68d

Upload 1287 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. SoundScribe/SpeakerID/Dockerfile +140 -0
  2. SoundScribe/SpeakerID/Jenkinsfile +0 -0
  3. SoundScribe/SpeakerID/LICENSE +201 -0
  4. SoundScribe/SpeakerID/ci.groovy +119 -0
  5. SoundScribe/SpeakerID/external/get_collections.py +90 -0
  6. SoundScribe/SpeakerID/external/get_modules.py +159 -0
  7. SoundScribe/SpeakerID/nemo/README.md +9 -0
  8. SoundScribe/SpeakerID/nemo/__init__.py +28 -0
  9. SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-310.pyc +0 -0
  10. SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-39.pyc +0 -0
  11. SoundScribe/SpeakerID/nemo/__pycache__/constants.cpython-310.pyc +0 -0
  12. SoundScribe/SpeakerID/nemo/__pycache__/package_info.cpython-310.pyc +0 -0
  13. SoundScribe/SpeakerID/nemo/collections/__init__.py +13 -0
  14. SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-310.pyc +0 -0
  15. SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-39.pyc +0 -0
  16. SoundScribe/SpeakerID/nemo/collections/asr/__init__.py +25 -0
  17. SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-310.pyc +0 -0
  18. SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-39.pyc +0 -0
  19. SoundScribe/SpeakerID/nemo/collections/asr/data/__init__.py +13 -0
  20. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/__init__.cpython-310.pyc +0 -0
  21. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio.cpython-310.pyc +0 -0
  22. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio_dataset.cpython-310.pyc +0 -0
  23. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_diar_label.cpython-310.pyc +0 -0
  24. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label.cpython-310.pyc +0 -0
  25. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label_dataset.cpython-310.pyc +0 -0
  26. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text.cpython-310.pyc +0 -0
  27. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dali.cpython-310.pyc +0 -0
  28. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dataset.cpython-310.pyc +0 -0
  29. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label.cpython-310.pyc +0 -0
  30. SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label_dataset.cpython-310.pyc +0 -0
  31. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio.py +1136 -0
  32. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio_dataset.py +95 -0
  33. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_ctm_dataset.py +95 -0
  34. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_diar_label.py +853 -0
  35. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label.py +1294 -0
  36. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label_dataset.py +304 -0
  37. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text.py +1366 -0
  38. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dali.py +772 -0
  39. SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dataset.py +950 -0
  40. SoundScribe/SpeakerID/nemo/collections/asr/data/data_simulation.py +0 -0
  41. SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label.py +497 -0
  42. SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label_dataset.py +68 -0
  43. SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text.py +488 -0
  44. SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text_dataset.py +94 -0
  45. SoundScribe/SpeakerID/nemo/collections/asr/data/text_to_text.py +482 -0
  46. SoundScribe/SpeakerID/nemo/collections/asr/losses/__init__.py +22 -0
  47. SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/__init__.cpython-310.pyc +0 -0
  48. SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/angularloss.cpython-310.pyc +0 -0
  49. SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/audio_losses.cpython-310.pyc +0 -0
  50. SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/ctc.cpython-310.pyc +0 -0
SoundScribe/SpeakerID/Dockerfile ADDED
@@ -0,0 +1,140 @@
+ # syntax=docker/dockerfile:experimental
+
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.08-py3
+
+ # build an image that includes only the nemo dependencies, ensures that dependencies
+ # are included first for optimal caching, and useful for building a development
+ # image (by specifying build target as `nemo-deps`)
+ FROM ${BASE_IMAGE} as nemo-deps
+
+ # dependency flags; should be declared after FROM
+ # torchaudio: not required by default
+ ARG REQUIRE_TORCHAUDIO=false
+ # k2: not required by default
+ ARG REQUIRE_K2=false
+ # ais cli: not required by default, install only if required
+ ARG REQUIRE_AIS_CLI=false
+
+ # Ensure apt-get won't prompt for selecting options
+ ENV DEBIAN_FRONTEND=noninteractive
+ # libavdevice-dev required for latest torchaudio
+ RUN apt-get update && \
+ apt-get upgrade -y && \
+ apt-get install -y \
+ libsndfile1 sox \
+ libfreetype6 \
+ swig \
+ ffmpeg \
+ libavdevice-dev && \
+ rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /workspace/
+ # install megatron core, this can be removed once 0.3 pip package is released
+ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
+ cd Megatron-LM && \
+ git checkout ab0336a5c8eab77aa74ae604ba1e73decbf6d560 && \
+ pip install -e .
+
+ WORKDIR /tmp/
+
+ # Distributed Adam support for multiple dtypes
+ RUN git clone https://github.com/NVIDIA/apex.git && \
+ cd apex && \
+ git checkout 52e18c894223800cb611682dce27d88050edf1de && \
+ pip3 install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
+
+ # uninstall stuff from base container
+ RUN pip3 uninstall -y sacrebleu torchtext
+
+ # build torchaudio
+ WORKDIR /tmp/torchaudio_build
+ COPY scripts/installers /tmp/torchaudio_build/scripts/installers/
+ RUN INSTALL_MSG=$(/bin/bash /tmp/torchaudio_build/scripts/installers/install_torchaudio_latest.sh); INSTALL_CODE=$?; \
+ echo ${INSTALL_MSG}; \
+ if [ ${INSTALL_CODE} -ne 0 ]; then \
+ echo "torchaudio installation failed"; \
+ if [ "${REQUIRE_TORCHAUDIO}" = true ]; then \
+ exit ${INSTALL_CODE}; \
+ else echo "Skipping failed torchaudio installation"; fi \
+ else echo "torchaudio installed successfully"; fi
+
+ # install nemo dependencies
+ WORKDIR /tmp/nemo
+ COPY requirements .
+ RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done
+
+ # install flash attention dependencies
+ RUN pip install flash-attn
+ # pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
+ RUN pip install triton==2.0.0.dev20221202
+ # install numba for latest containers
+ RUN pip install numba>=0.57.1
+
+ # install k2, skip if installation fails
+ COPY scripts /tmp/nemo/scripts/
+ RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/speech_recognition/k2/setup.sh); INSTALL_CODE=$?; \
+ echo ${INSTALL_MSG}; \
+ if [ ${INSTALL_CODE} -ne 0 ]; then \
+ echo "k2 installation failed"; \
+ if [ "${REQUIRE_K2}" = true ]; then \
+ exit ${INSTALL_CODE}; \
+ else echo "Skipping failed k2 installation"; fi \
+ else echo "k2 installed successfully"; fi
+
+ # copy nemo source into a scratch image
+ FROM scratch as nemo-src
+ COPY . .
+
+ # start building the final container
+ FROM nemo-deps as nemo
+ ARG NEMO_VERSION=1.21.0
+
+ # Check that NEMO_VERSION is set. Build will fail without this. Expose NEMO and base container
+ # version information as runtime environment variable for introspection purposes
+ RUN /usr/bin/test -n "$NEMO_VERSION" && \
+ /bin/echo "export NEMO_VERSION=${NEMO_VERSION}" >> /root/.bashrc && \
+ /bin/echo "export BASE_IMAGE=${BASE_IMAGE}" >> /root/.bashrc
+
+ # Install NeMo
+ RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]"
+
+ # Check install
+ RUN python -c "import nemo.collections.nlp as nemo_nlp" && \
+ python -c "import nemo.collections.tts as nemo_tts" && \
+ python -c "import nemo_text_processing.text_normalization as text_normalization"
+
+
+ # copy scripts/examples/tests into container for end user
+ WORKDIR /workspace/nemo
+ COPY scripts /workspace/nemo/scripts
+ COPY examples /workspace/nemo/examples
+ COPY tests /workspace/nemo/tests
+ COPY tutorials /workspace/nemo/tutorials
+ # COPY README.rst LICENSE /workspace/nemo/
+
+ RUN printf "#!/bin/bash\njupyter lab --no-browser --allow-root --ip=0.0.0.0" >> start-jupyter.sh && \
+ chmod +x start-jupyter.sh
+
+ # If required, install AIS CLI
+ RUN if [ "${REQUIRE_AIS_CLI}" = true ]; then \
+ INSTALL_MSG=$(/bin/bash scripts/installers/install_ais_cli_latest.sh); INSTALL_CODE=$?; \
+ echo ${INSTALL_MSG}; \
+ if [ ${INSTALL_CODE} -ne 0 ]; then \
+ echo "AIS CLI installation failed"; \
+ exit ${INSTALL_CODE}; \
+ else echo "AIS CLI installed successfully"; fi \
+ else echo "Skipping AIS CLI installation"; fi
SoundScribe/SpeakerID/Jenkinsfile ADDED
The diff for this file is too large to render. See raw diff
 
SoundScribe/SpeakerID/LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
SoundScribe/SpeakerID/ci.groovy ADDED
@@ -0,0 +1,119 @@
+ @Library('blossom-github-lib@master')
+ import ipp.blossom.*
+
+ podTemplate(cloud:'sc-ipp-blossom-prod', yaml : """
+ apiVersion: v1
+ kind: Pod
+ metadata:
+ labels:
+ some-label: some-label-value
+ spec:
+ volumes:
+ - name: scratch
+ nfs:
+ server: ipp1-cdot01-col01
+ path: /vol/scratch1/scratch.okuchaiev_blossom
+ containers:
+ - name: latestdlfw
+ image: nvcr.io/nvidia/pytorch:23.02-py3
+ command:
+ - cat
+ volumeMounts:
+ - name: scratch
+ mountPath: /testdata
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ restartPolicy: Never
+ backoffLimit: 4
+ tty: true
+ shm-size: 32g
+ nodeSelector:
+ kubernetes.io/os: linux
+ nvidia.com/gpu_type: "Tesla_T4x4"
+ nvidia.com/node_type: gpu_tester
+ nvidia.com/driver_version: "510.20"
+ """
+ ) {
+ node(POD_LABEL) {
+ def githubHelper
+ stage('Get Token') {
+ withCredentials([usernamePassword(credentialsId: 'GHAtoken', passwordVariable: 'GIT_PASSWORD', usernameVariable: 'GIT_USERNAME')]) {
+ // create new instance of helper object
+ githubHelper = GithubHelper.getInstance("${GIT_PASSWORD}", githubData)
+ }
+
+ }
+ def stageName = ''
+ try {
+ currentBuild.description = githubHelper.getBuildDescription()
+ container('latestdlfw') {
+ stage('Code checkout') {
+ // update status on github
+ githubHelper.updateCommitStatus("$BUILD_URL", "$stageName Running", GitHubCommitState.PENDING)
+ checkout changelog: true, poll: true, scm: [$class: 'GitSCM', branches: [[name: "pr/"+githubHelper.getPRNumber()]],
+ doGenerateSubmoduleConfigurations: false,
+ submoduleCfg: [],
+ userRemoteConfigs: [[credentialsId: 'github-token', url: githubHelper.getCloneUrl(), refspec: '+refs/pull/*/head:refs/remotes/origin/pr/*']]]
+ }
+
+ stage('Code Style') {
+ sh "apt-get update && \
+ apt-get install -y bc && \
+ nvidia-smi && \
+ pip install -r requirements/requirements_test.txt && \
+ python setup.py style && ls -l /testdata/TestData && ln -s /testdata/TestData /home/TestData && \
+ ls -l /home && ls -l /home/TestData"
+ }
+
+ stage('Installation') {
+ sh "git config --global --add safe.directory '*' && nvidia-smi && ./reinstall.sh release"
+ }
+
+ stage('L0: GPU unit tests') {
+ sh "NEMO_NUMBA_MINVER=0.53 pytest -m 'not pleasefixme'"
+ }
+
+ parallel( //USE CUDA_VISIBLE_DEVICES to execute 2 single GPU tests in parallel here
+ [
+ "L1: NMT Training Pre-LN": { sh 'CUDA_VISIBLE_DEVICES=0 python examples/nlp/machine_translation/enc_dec_nmt.py \
+ --config-path=conf \
+ --config-name=aayn_base \
+ do_testing=true \
+ model.train_ds.src_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
+ model.train_ds.tgt_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \
+ model.validation_ds.src_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
+ model.validation_ds.tgt_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
+ model.test_ds.src_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
+ model.test_ds.tgt_file_name=/testdata/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
+ model.encoder_tokenizer.tokenizer_model=/testdata/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
+ model.decoder_tokenizer.tokenizer_model=/testdata/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
+ model.encoder.pre_ln=true \
+ model.decoder.pre_ln=true \
+ trainer.devices=[0] \
+ trainer.accelerator="gpu" \
+ +trainer.fast_dev_run=true \
+ +trainer.limit_test_batches=2 \
+ exp_manager=null \
+ '},
+ "L1: Speech to text": { sh 'CUDA_VISIBLE_DEVICES=1 python examples/asr/asr_ctc/speech_to_text_ctc.py \
+ model.train_ds.manifest_filepath=/testdata/TestData/an4_dataset/an4_train.json \
+ model.validation_ds.manifest_filepath=/testdata/TestData/an4_dataset/an4_val.json \
+ trainer.devices=[0] \
+ trainer.accelerator="gpu" \
+ +trainer.fast_dev_run=True \
+ exp_manager=null \
+ '}
+ ]
+ )//end of parallel
+ }
+ githubHelper.updateCommitStatus("$BUILD_URL", "Complete", GitHubCommitState.SUCCESS)
+ }
+ catch (Exception ex){
+ currentBuild.result = 'FAILURE'
+ println ex
+ githubHelper.updateCommitStatus("$BUILD_URL", "$stageName Failed", GitHubCommitState.FAILURE)
+ }
+
+ }
+ }
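The `parallel()` block above runs two single-GPU tests at once by pinning each one to a different device through `CUDA_VISIBLE_DEVICES`. A minimal sketch of the same idea in Python (not part of the commit; `nvidia-smi` stands in for the real test commands and a two-GPU machine is assumed):

```python
# Sketch only: one GPU per parallel job via CUDA_VISIBLE_DEVICES, mirroring ci.groovy.
import os
import subprocess

# In ci.groovy these are the enc_dec_nmt.py and speech_to_text_ctc.py commands;
# nvidia-smi is used here purely as a runnable placeholder.
jobs = {
    "L1: NMT Training Pre-LN": "nvidia-smi",
    "L1: Speech to text": "nvidia-smi",
}

procs = []
for gpu_id, (name, cmd) in enumerate(jobs.items()):
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))  # pin each job to its own GPU
    procs.append((name, subprocess.Popen(cmd, shell=True, env=env)))

for name, proc in procs:
    if proc.wait() != 0:
        raise RuntimeError(f"{name} failed with exit code {proc.returncode}")
```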
SoundScribe/SpeakerID/external/get_collections.py ADDED
@@ -0,0 +1,90 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ Script responsible for generation of a JSON file with list of NeMo collections. """
+
+ import argparse
+ import importlib
+ import json
+ import os
+
+ import nemo
+ from nemo.utils import logging
+
+
+ def process_collection(id, col):
+     """ Helper function processing the collection.
+
+     Args:
+         id: (short) name of the collection.
+         col: a collection (python module).
+     """
+     return {
+         "id": id,
+         "name": col.__name__,
+         "description": col.__description__,
+         "version": col.__version__,
+         "author": col.__author__,
+     }
+
+
+ def main():
+     """ Main function generating a JSON file with list of NeMo collections. """
+     # Parse filename.
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--filename', help='Name of the output JSON file', type=str, default="collections.json")
+     args = parser.parse_args()
+
+     # Get collections directory.
+     colletions_dir = os.path.dirname(nemo.collections.__file__)
+     logging.info('Analysing collections in `{}`'.format(colletions_dir))
+
+     # Generate list of NeMo collections - from the list of collection subfolders.
+     collections = {}
+     for sub_dir in os.listdir(colletions_dir):
+         # Skip cache.
+         if sub_dir == "__pycache__":
+             continue
+         # Check if it is a directory.
+         if os.path.isdir(os.path.join(colletions_dir, sub_dir)):
+             collections[sub_dir] = "nemo.collections." + sub_dir
+
+     output_list = []
+     # Iterate over all collections.
+     for key, val in collections.items():
+         # Try to get module specification.
+         module_spec = importlib.util.find_spec(val)
+         if module_spec is None:
+             logging.warning(" * Failed to process `{}`".format(val))
+         else:
+             try:
+                 # Import the module from the module specification.
+                 module = importlib.util.module_from_spec(module_spec)
+                 module_spec.loader.exec_module(module)
+                 # Add to list.
+                 output_list.append(process_collection(key, module))
+                 logging.info(" * Processed `{}`".format(val))
+             except AttributeError:
+                 logging.warning(" * Failed to process `{}`".format(val))
+
+     # Export to JSON.
+     with open(args.filename, 'w', encoding='utf-8') as outfile:
+         json.dump(output_list, outfile)
+
+     logging.info('Finished the analysis, results exported to `{}`.'.format(args.filename))
+
+
+ if __name__ == '__main__':
+     main()
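A hedged usage sketch (not part of the commit): run the script above and read back the JSON it writes. The record fields match those built by `process_collection()`; a working NeMo installation is assumed so that `import nemo` inside the script succeeds.

```python
# Sketch only: invoke get_collections.py and inspect its output file.
import json
import subprocess

subprocess.run(["python", "get_collections.py", "--filename", "collections.json"], check=True)

with open("collections.json", encoding="utf-8") as f:
    records = json.load(f)

# Each record has: id, name, description, version, author.
for record in records:
    print(f"{record['id']:10s} {record['version']:12s} {record['description']}")
```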
SoundScribe/SpeakerID/external/get_modules.py ADDED
@@ -0,0 +1,159 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ Script responsible for generation of a JSON file containing list of modules of a given collection. """
+
+ import argparse
+ import importlib
+ import inspect
+ import json
+ import os
+
+ import nemo
+ from nemo.utils import logging
+
+
+ def process_member(name, obj, module_list):
+     """ Helper function processing the passed object and, if ok, adding a record to the module list.
+
+     Args:
+         name: name of the member
+         obj: member (class/function etc.)
+         module_list: list of modules that (probably) will be expanded.
+     """
+     # It is not a class - skip it.
+     if not inspect.isclass(obj):
+         return
+
+     # Check inheritance - we know that all our datasets/modules/losses inherit from Serialization,
+     # Btw. Serialization is also required by this script.
+     if not issubclass(obj, nemo.core.Serialization):
+         return
+
+     logging.info(" * Processing `{}`".format(str(obj)))
+
+     module_list.append(
+         {
+             "name": name,
+             "cls": str(obj),
+             # Temporary solution: mockup arguments.
+             "arguments": [
+                 "jasper",
+                 "activation",
+                 "feat_in",
+                 "normalization_mode",
+                 "residual_mode",
+                 "norm_groups",
+                 "conv_mask",
+                 "frame_splicing",
+                 "init_mode",
+             ],
+             # Temporary solution: mockup input types.
+             "input_types": {
+                 "audio_signal": "axes: (batch, dimension, time); elements_type: MelSpectrogramType",
+                 "length": "axes: (batch,); elements_type: LengthType",
+             },
+             # Temporary solution: mockup output types.
+             "output_types": {
+                 "encoder_output": "axes: (batch, dimension, time); elements_type: AcousticEncodedRepresentation"
+             },
+         }
+     )
+
+
+ def main():
+     """ Main function analysing the indicated NeMo collection and generating a JSON file with module descriptions. """
+     # Parse filename.
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--collection', help='ID of the collection', type=str)
+     parser.add_argument('--filename', help='Name of the output JSON file', type=str, default="modules.json")
+     args = parser.parse_args()
+
+     # Get collections directory.
+     colletions_dir = os.path.dirname(nemo.collections.__file__)
+     logging.info('Analysing collections in `{}`'.format(colletions_dir))
+
+     # Generate list of NeMo collections - from the list of collection subfolders.
+     collections = {}
+     for sub_dir in os.listdir(colletions_dir):
+         # Skip cache.
+         if sub_dir == "__pycache__":
+             continue
+         # Check if it is a directory.
+         if os.path.isdir(os.path.join(colletions_dir, sub_dir)):
+             collections[sub_dir] = "nemo.collections." + sub_dir
+
+     # Check the collection.
+     if args.collection not in collections.keys():
+         logging.error("Couldn't process the indicated `{}` collection".format(args.collection))
+         logging.info(
+             "Please select one of the existing collections using `--collection [{}]`".format("|".join(collections))
+         )
+         exit(-1)
+
+     # Load the collection specification.
+     collection_spec = importlib.util.find_spec(collections[args.collection])
+     if collection_spec is None:
+         logging.error("Failed to load the `{}` collection".format(args.collection))
+
+     # Import the module from the module specification.
+     collection = importlib.util.module_from_spec(collection_spec)
+     collection_spec.loader.exec_module(collection)
+
+     module_list = []
+     # Iterate over the packages in the indicated collection.
+     logging.info("Analysing the `{}` collection".format(args.collection))
+
+     try:  # Datasets in dataset folder
+         logging.info("Analysing the 'data' package")
+         for name, obj in inspect.getmembers(collection.data):
+             process_member(name, obj, module_list)
+     except AttributeError as e:
+         logging.info(" * No datasets found")
+
+     try:  # Datasets in dataset folder
+         logging.info("Analysing the 'datasets' package")
+         for name, obj in inspect.getmembers(collection.datasets):
+             process_member(name, obj, module_list)
+     except AttributeError as e:
+         logging.info(" * No datasets found")
+
+     try:  # Modules
+         logging.info("Analysing the 'modules' package")
+         for name, obj in inspect.getmembers(collection.modules):
+             process_member(name, obj, module_list)
+     except AttributeError as e:
+         logging.info(" * No modules found")
+
+     try:  # Losses
+         logging.info("Analysing the 'losses' package")
+         for name, obj in inspect.getmembers(collection.losses):
+             process_member(name, obj, module_list)
+     except AttributeError as e:
+         logging.info(" * No losses found")
+
+     # Add prefix - only for default name.
+     filename = args.filename if args.filename != "modules.json" else args.collection + "_" + args.filename
+     # Export to JSON.
+     with open(filename, 'w', encoding='utf-8') as outfile:
+         json.dump(module_list, outfile)
+
+     logging.info(
+         'Finished analysis of the `{}` collection, results exported to `{}`.'.format(args.collection, filename)
+     )
+
+
+ if __name__ == '__main__':
+     main()
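The filter at the heart of `process_member()` is just `inspect.isclass` plus an `issubclass` check against `nemo.core.Serialization`. A minimal, self-contained sketch of that pattern (not part of the commit; the `Serialization` stand-in and toy classes below are hypothetical so the snippet runs without NeMo installed):

```python
# Sketch only: the reflection pattern used by process_member(), on a toy module.
import inspect
import types


class Serialization:  # stand-in for nemo.core.Serialization
    pass


class JasperEncoder(Serialization):  # subclass of the base -> would be kept
    pass


class Helper:  # not a Serialization subclass -> skipped
    pass


toy_module = types.ModuleType("toy_collection")
toy_module.JasperEncoder = JasperEncoder
toy_module.Helper = Helper

kept = []
for name, obj in inspect.getmembers(toy_module):
    if inspect.isclass(obj) and issubclass(obj, Serialization):
        kept.append(name)

print(kept)  # ['JasperEncoder']
```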
SoundScribe/SpeakerID/nemo/README.md ADDED
@@ -0,0 +1,9 @@
+ NeMo (**Ne**ural **Mo**dules) is a toolkit for creating AI applications built around **neural modules**, conceptual blocks of neural networks that take *typed* inputs and produce *typed* outputs.
+
+ **NeMo Core** provides common APIs all modules and models have to implement.
+
+ **NeMo Collections**
+
+ * ASR - collection of modules and models for building speech recognition networks
+ * TTS - collection of modules and models for building speech synthesis networks
+ * NLP - collection of modules and models for building NLP networks
SoundScribe/SpeakerID/nemo/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from nemo.package_info import (
+     __contact_emails__,
+     __contact_names__,
+     __description__,
+     __download_url__,
+     __homepage__,
+     __keywords__,
+     __license__,
+     __package_name__,
+     __repository_url__,
+     __shortversion__,
+     __version__,
+ )
SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (454 Bytes).

SoundScribe/SpeakerID/nemo/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (452 Bytes).

SoundScribe/SpeakerID/nemo/__pycache__/constants.cpython-310.pyc ADDED
Binary file (549 Bytes).

SoundScribe/SpeakerID/nemo/__pycache__/package_info.cpython-310.pyc ADDED
Binary file (909 Bytes).
 
SoundScribe/SpeakerID/nemo/collections/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes).

SoundScribe/SpeakerID/nemo/collections/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (150 Bytes).
 
SoundScribe/SpeakerID/nemo/collections/asr/__init__.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from nemo.collections.asr import data, losses, models, modules
+ from nemo.package_info import __version__
+
+ # Set collection version equal to NeMo version.
+ __version = __version__
+
+ # Authorship.
+ __author__ = "NVIDIA Corporation"
+
+ # Set collection name.
+ __description__ = "Automatic Speech Recognition collection"
SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (429 Bytes).

SoundScribe/SpeakerID/nemo/collections/asr/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (427 Bytes).
 
SoundScribe/SpeakerID/nemo/collections/asr/data/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (161 Bytes).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio.cpython-310.pyc ADDED
Binary file (37.9 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_audio_dataset.cpython-310.pyc ADDED
Binary file (2.42 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_diar_label.cpython-310.pyc ADDED
Binary file (34.5 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label.cpython-310.pyc ADDED
Binary file (50.4 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_label_dataset.cpython-310.pyc ADDED
Binary file (7.75 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text.cpython-310.pyc ADDED
Binary file (50.8 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dali.cpython-310.pyc ADDED
Binary file (24.9 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/audio_to_text_dataset.cpython-310.pyc ADDED
Binary file (23.8 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label.cpython-310.pyc ADDED
Binary file (16.1 kB).

SoundScribe/SpeakerID/nemo/collections/asr/data/__pycache__/feature_to_label_dataset.cpython-310.pyc ADDED
Binary file (1.78 kB).
 
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio.py ADDED
@@ -0,0 +1,1136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ import math
17
+ import random
18
+ from collections import OrderedDict, namedtuple
19
+ from dataclasses import dataclass
20
+ from typing import Callable, Dict, List, Optional, Tuple, Type, Union
21
+
22
+ import librosa
23
+ import numpy as np
24
+ import torch
25
+
26
+ from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
27
+ from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
28
+ from nemo.collections.common.parts.preprocessing import collections
29
+ from nemo.collections.common.parts.utils import flatten
30
+ from nemo.core.classes import Dataset
31
+ from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType
32
+ from nemo.utils import logging
33
+
34
+ __all__ = [
35
+ 'AudioToTargetDataset',
36
+ 'AudioToTargetWithReferenceDataset',
37
+ 'AudioToTargetWithEmbeddingDataset',
38
+ ]
39
+
40
+
41
+ def _audio_collate_fn(batch: List[dict]) -> Tuple[torch.Tensor]:
42
+ """Collate a batch of items returned by __getitem__.
43
+ Examples for each signal are zero padded to the same length
44
+ (batch_length), which is determined by the longest example.
45
+ Lengths of the original signals are returned in the output.
46
+
47
+ Args:
48
+ batch: List of dictionaries. Each element of the list
49
+ has the following format
50
+ ```
51
+ {
52
+ 'signal_0': 1D or 2D tensor,
53
+ 'signal_1': 1D or 2D tensor,
54
+ ...
55
+ 'signal_N': 1D or 2D tensor,
56
+ }
57
+ ```
58
+ 1D tensors have shape (num_samples,) and 2D tensors
59
+ have shape (num_channels, num_samples)
60
+
61
+ Returns:
62
+ A tuple containing signal tensor and signal length tensor (in samples)
63
+ for each signal.
64
+ The output has the following format:
65
+ ```
66
+ (signal_0, signal_0_length, signal_1, signal_1_length, ..., signal_N, signal_N_length)
67
+ ```
68
+ Note that the output format is obtained by interleaving signals and their length.
69
+ """
70
+ signals = batch[0].keys()
71
+
72
+ batched = tuple()
73
+
74
+ for signal in signals:
75
+ signal_length = [b[signal].shape[-1] for b in batch]
76
+ # Batch length is determined by the longest signal in the batch
77
+ batch_length = max(signal_length)
78
+ b_signal = []
79
+ for s_len, b in zip(signal_length, batch):
80
+ # check if padding is necessary
81
+ if s_len < batch_length:
82
+ if b[signal].ndim == 1:
83
+ # single-channel signal
84
+ pad = (0, batch_length - s_len)
85
+ elif b[signal].ndim == 2:
86
+ # multi-channel signal
87
+ pad = (0, batch_length - s_len, 0, 0)
88
+ else:
89
+ raise RuntimeError(
90
+ f'Signal {signal} has unsuported dimensions {signal.shape}. Currently, only 1D and 2D arrays are supported.'
91
+ )
92
+ b[signal] = torch.nn.functional.pad(b[signal], pad)
93
+ # append the current padded signal
94
+ b_signal.append(b[signal])
95
+ # (signal_batched, signal_length)
96
+ batched += (torch.stack(b_signal), torch.tensor(signal_length, dtype=torch.int32))
97
+
98
+ # Currently, outputs are expected to be in a tuple, where each element must correspond
99
+ # to the output type in the OrderedDict returned by output_types.
100
+ #
101
+ # Therefore, we return batched signals by interleaving signals and their length:
102
+ # (signal_0, signal_0_length, signal_1, signal_1_length, ...)
103
+ return batched
104
+
105
+
106
+ @dataclass
107
+ class SignalSetup:
108
+ signals: List[str] # signal names
109
+ duration: Optional[Union[float, list]] = None # duration for each signal
110
+ channel_selectors: Optional[List[ChannelSelectorType]] = None # channel selector for loading each signal
111
+
112
+
113
+ class ASRAudioProcessor:
114
+ """Class that processes an example from Audio collection and returns
115
+ a dictionary with prepared signals.
116
+
117
+ For example, the output dictionary may be the following
118
+ ```
119
+ {
120
+ 'input_signal': input_signal_tensor,
121
+ 'target_signal': target_signal_tensor,
122
+ 'reference_signal': reference_signal_tensor,
123
+ 'embedding_vector': embedding_vector
124
+ }
125
+ ```
126
+ Keys in the output dictionary are ordered with synchronous signals given first,
127
+ followed by asynchronous signals and embedding.
128
+
129
+ Args:
130
+ sample_rate: sample rate used for all audio signals
131
+ random_offset: If `True`, offset will be randomized when loading a subsegment
132
+ from a file.
133
+ """
134
+
135
+ def __init__(
136
+ self, sample_rate: float, random_offset: bool,
137
+ ):
138
+ self.sample_rate = sample_rate
139
+ self.random_offset = random_offset
140
+
141
+ self.sync_setup = None
142
+ self.async_setup = None
143
+ self.embedding_setup = None
144
+
145
+ @property
146
+ def sample_rate(self) -> float:
147
+ return self._sample_rate
148
+
149
+ @sample_rate.setter
150
+ def sample_rate(self, value: float):
151
+ if value <= 0:
152
+ raise ValueError(f'Sample rate must be positive, received {value}')
153
+
154
+ self._sample_rate = value
155
+
156
+ @property
157
+ def random_offset(self) -> bool:
158
+ return self._random_offset
159
+
160
+ @random_offset.setter
161
+ def random_offset(self, value: bool):
162
+ self._random_offset = value
163
+
164
+ @property
165
+ def sync_setup(self) -> SignalSetup:
166
+ """Return the current setup for synchronous signals.
167
+
168
+ Returns:
169
+ A dataclass containing the list of signals, their
170
+ duration and channel selectors.
171
+ """
172
+ return self._sync_setup
173
+
174
+ @sync_setup.setter
175
+ def sync_setup(self, value: Optional[SignalSetup]):
176
+ """Setup signals to be loaded synchronously.
177
+
178
+ Args:
179
+ value: An instance of SignalSetup with the following fields
180
+ - signals: list of signals (keys of example.audio_signals) which will be loaded
181
+ synchronously with the same start time and duration.
182
+ - duration: Duration for each signal to be loaded.
183
+ If duration is set to None, the whole file will be loaded.
184
+ - channel_selectors: A list of channel selector for each signal. If channel selector
185
+ is None, all channels in the audio file will be loaded.
186
+ """
187
+ if value is None or isinstance(value, SignalSetup):
188
+ self._sync_setup = value
189
+ else:
190
+ raise ValueError(f'Unexpected type {type(value)} for value {value}.')
191
+
192
+ @property
193
+ def async_setup(self) -> SignalSetup:
194
+ """Return the current setup for asynchronous signals.
195
+
196
+ Returns:
197
+ A dataclass containing the list of signals, their
198
+ duration and channel selectors.
199
+ """
200
+ return self._async_setup
201
+
202
+ @async_setup.setter
203
+ def async_setup(self, value: Optional[SignalSetup]):
204
+ """Setup signals to be loaded asynchronously.
205
+
206
+ Args:
207
+ Args:
208
+ value: An instance of SignalSetup with the following fields
209
+ - signals: list of signals (keys of example.audio_signals) which will be loaded
210
+ asynchronously with signals possibly having different start and duration
211
+ - duration: Duration for each signal to be loaded.
212
+ If duration is set to None, the whole file will be loaded.
213
+ - channel_selectors: A list of channel selector for each signal. If channel selector
214
+ is None, all channels in the audio file will be loaded.
215
+ """
216
+ if value is None or isinstance(value, SignalSetup):
217
+ self._async_setup = value
218
+ else:
219
+ raise ValueError(f'Unexpected type {type(value)} for value {value}.')
220
+
221
+ @property
222
+ def embedding_setup(self) -> SignalSetup:
223
+ """Setup signals corresponding to an embedding vector.
224
+ """
225
+ return self._embedding_setup
226
+
227
+ @embedding_setup.setter
228
+ def embedding_setup(self, value: SignalSetup):
229
+ """Setup signals corresponding to an embedding vector.
230
+
231
+ Args:
232
+ value: An instance of SignalSetup with the following fields
233
+ - signals: list of signals (keys of example.audio_signals) which will be loaded
234
+ as embedding vectors.
235
+ """
236
+ if value is None or isinstance(value, SignalSetup):
237
+ self._embedding_setup = value
238
+ else:
239
+ raise ValueError(f'Unexpected type {type(value)} for value {value}.')
240
+
241
+ def process(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
242
+ """Process an example from a collection of audio examples.
243
+
244
+ Args:
245
+ example: an example from Audio collection.
246
+
247
+ Returns:
248
+ An ordered dictionary of signals and their tensors.
249
+ For example, the output dictionary may be the following
250
+ ```
251
+ {
252
+ 'input_signal': input_signal_tensor,
253
+ 'target_signal': target_signal_tensor,
254
+ 'reference_signal': reference_signal_tensor,
255
+ 'embedding_vector': embedding_vector
256
+ }
257
+ ```
258
+ Keys in the output dictionary are ordered with synchronous signals given first,
259
+ followed by asynchronous signals and embedding.
260
+ """
261
+ audio = self.load_audio(example=example)
262
+ audio = self.process_audio(audio=audio)
263
+ return audio
264
+
265
+ def load_audio(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
266
+ """Given an example, load audio from `example.audio_files` and prepare
267
+ the output dictionary.
268
+
269
+ Args:
270
+ example: An example from an audio collection
271
+
272
+ Returns:
273
+ An ordered dictionary of signals and their tensors.
274
+ For example, the output dictionary may be the following
275
+ ```
276
+ {
277
+ 'input_signal': input_signal_tensor,
278
+ 'target_signal': target_signal_tensor,
279
+ 'reference_signal': reference_signal_tensor,
280
+ 'embedding_vector': embedding_vector
281
+ }
282
+ ```
283
+ Keys in the output dictionary are ordered with synchronous signals given first,
284
+ followed by asynchronous signals and embedding.
285
+ """
286
+ output = OrderedDict()
287
+
288
+ if self.sync_setup is not None:
289
+ # Load all signals with the same start and duration
290
+ sync_signals = self.load_sync_signals(example)
291
+ output.update(sync_signals)
292
+
293
+ if self.async_setup is not None:
294
+ # Load each signal independently
295
+ async_signals = self.load_async_signals(example)
296
+ output.update(async_signals)
297
+
298
+ # Load embedding vector
299
+ if self.embedding_setup is not None:
300
+ embedding = self.load_embedding(example)
301
+ output.update(embedding)
302
+
303
+ if not output:
304
+ raise RuntimeError('Output dictionary is empty. Please use `_setup` methods to set up the signals to be loaded')
305
+
306
+ return output
307
+
308
+ def process_audio(self, audio: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
309
+ """Process audio signals available in the input dictionary.
310
+
311
+ Args:
312
+ audio: A dictionary containing loaded signals `signal: tensor`
313
+
314
+ Returns:
315
+ An ordered dictionary of signals and their tensors.
316
+ """
317
+ # Currently, not doing any processing of the loaded signals.
318
+ return audio
319
+
320
+ def load_sync_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
321
+ """Load signals with the same start and duration.
322
+
323
+ Args:
324
+ example: an example from audio collection
325
+
326
+ Returns:
327
+ An ordered dictionary of signals and their tensors.
328
+ """
329
+ output = OrderedDict()
330
+ sync_audio_files = [example.audio_files[s] for s in self.sync_setup.signals]
331
+
332
+ sync_samples = self.get_samples_synchronized(
333
+ audio_files=sync_audio_files,
334
+ channel_selectors=self.sync_setup.channel_selectors,
335
+ sample_rate=self.sample_rate,
336
+ duration=self.sync_setup.duration,
337
+ fixed_offset=example.offset,
338
+ random_offset=self.random_offset,
339
+ )
340
+
341
+ for signal, samples in zip(self.sync_setup.signals, sync_samples):
342
+ output[signal] = torch.tensor(samples)
343
+
344
+ return output
345
+
346
+ def load_async_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
347
+ """Load each async signal independently, with no constraint that
349
+ the signals start at the same time.
349
+
350
+ Args:
351
+ example: an example from audio collection
352
+
353
+ Returns:
354
+ An ordered dictionary of signals and their tensors.
355
+ """
356
+ output = OrderedDict()
357
+ for idx, signal in enumerate(self.async_setup.signals):
358
+ samples = self.get_samples(
359
+ audio_file=example.audio_files[signal],
360
+ sample_rate=self.sample_rate,
361
+ duration=self.async_setup.duration[idx],
362
+ channel_selector=self.async_setup.channel_selectors[idx],
363
+ fixed_offset=example.offset,
364
+ random_offset=self.random_offset,
365
+ )
366
+ output[signal] = torch.tensor(samples)
367
+ return output
368
+
369
+ @classmethod
370
+ def get_samples(
371
+ cls,
372
+ audio_file: str,
373
+ sample_rate: int,
374
+ duration: Optional[float] = None,
375
+ channel_selector: ChannelSelectorType = None,
376
+ fixed_offset: float = 0,
377
+ random_offset: bool = False,
378
+ ) -> np.ndarray:
379
+ """Get samples from an audio file.
380
+ For a single-channel signal, the output is shape (num_samples,).
381
+ For a multi-channel signal, the output is shape (num_channels, num_samples).
382
+
383
+ Args:
384
+ audio_file: path to an audio file
385
+ sample_rate: desired sample rate for output samples
386
+ duration: Optional desired duration of output samples.
387
+ If `None`, the complete file will be loaded.
388
+ If set, a segment of `duration` seconds will be loaded.
389
+ channel_selector: Optional channel selector, for selecting a subset of channels.
390
+ fixed_offset: Optional fixed offset when loading samples.
391
+ random_offset: If `True`, offset will be randomized when loading a short segment
392
+ from a file. The value is randomized between fixed_offset and
393
+ max_offset (set depending on the duration and fixed_offset).
394
+
395
+ Returns:
396
+ Numpy array with samples from audio file.
397
+ The array has shape (num_samples,) for a single-channel signal
398
+ or (num_channels, num_samples) for a multi-channel signal.
399
+ """
400
+ output = cls.get_samples_synchronized(
401
+ audio_files=[audio_file],
402
+ sample_rate=sample_rate,
403
+ duration=duration,
404
+ channel_selectors=[channel_selector],
405
+ fixed_offset=fixed_offset,
406
+ random_offset=random_offset,
407
+ )
408
+
409
+ return output[0]
410
+
411
+ @classmethod
412
+ def get_samples_synchronized(
413
+ cls,
414
+ audio_files: List[str],
415
+ sample_rate: int,
416
+ duration: Optional[float] = None,
417
+ channel_selectors: Optional[List[ChannelSelectorType]] = None,
418
+ fixed_offset: float = 0,
419
+ random_offset: bool = False,
420
+ ) -> List[np.ndarray]:
421
+ """Get samples from multiple files with the same start and end point.
422
+
423
+ Args:
424
+ audio_files: list of paths to audio files
425
+ sample_rate: desired sample rate for output samples
426
+ duration: Optional desired duration of output samples.
427
+ If `None`, the complete files will be loaded.
428
+ If set, a segment of `duration` seconds will be loaded from
429
+ all files. Segment is synchronized across files, so that
430
+ start and end points are the same.
431
+ channel_selectors: Optional channel selector for each signal, for selecting
432
+ a subset of channels.
433
+ fixed_offset: Optional fixed offset when loading samples.
434
+ random_offset: If `True`, offset will be randomized when loading a short segment
435
+ from a file. The value is randomized between fixed_offset and
436
+ max_offset (set depending on the duration and fixed_offset).
437
+
438
+ Returns:
439
+ List with the same size as `audio_files` but containing numpy arrays
440
+ with samples from each audio file.
441
+ Each array has shape (num_samples,) or (num_channels, num_samples), for single-
442
+ or multi-channel signal, respectively.
443
+ For example, if `audio_files = [path/to/file_1.wav, path/to/file_2.wav]`,
444
+ the output will be a list `output = [samples_file_1, samples_file_2]`.
445
+ """
446
+ if channel_selectors is None:
447
+ channel_selectors = [None] * len(audio_files)
448
+
449
+ if duration is None:
450
+ # Load complete files starting from a fixed offset
451
+ offset = fixed_offset # fixed offset
452
+ num_samples = None # no constraint on the number of samples
453
+
454
+ else:
455
+ # Fixed duration of the output
456
+ audio_durations = cls.get_duration(audio_files)
457
+ min_audio_duration = min(audio_durations)
458
+ available_duration = min_audio_duration - fixed_offset
459
+
460
+ if available_duration <= 0:
461
+ raise ValueError(f'Fixed offset {fixed_offset}s is larger than shortest file {min_audio_duration}s.')
462
+
463
+ if duration + fixed_offset > min_audio_duration:
464
+ # The shortest file is shorter than the requested duration
465
+ logging.debug(
466
+ f'Shortest file ({min_audio_duration}s) is less than the desired duration {duration}s + fixed offset {fixed_offset}s. Returned signals will be shortened to {available_duration} seconds.'
467
+ )
468
+ offset = fixed_offset
469
+ duration = available_duration
470
+ elif random_offset:
471
+ # Randomize offset based on the shortest file
472
+ max_offset = min_audio_duration - duration
473
+ offset = random.uniform(fixed_offset, max_offset)
474
+ else:
475
+ # Fixed offset
476
+ offset = fixed_offset
477
+
478
+ # Fixed number of samples
479
+ num_samples = math.floor(duration * sample_rate)
480
+
481
+ output = []
482
+
483
+ # Prepare segments
484
+ for idx, audio_file in enumerate(audio_files):
485
+ segment_samples = cls.get_samples_from_file(
486
+ audio_file=audio_file,
487
+ sample_rate=sample_rate,
488
+ offset=offset,
489
+ num_samples=num_samples,
490
+ channel_selector=channel_selectors[idx],
491
+ )
492
+ output.append(segment_samples)
493
+
494
+ return output
495
+
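A hedged usage sketch of the synchronized loading described above; the file paths are placeholders, and both returned arrays cover the same 2-second window.
```
samples_a, samples_b = ASRAudioProcessor.get_samples_synchronized(
    audio_files=['a.wav', 'b.wav'],  # hypothetical paths to two aligned recordings
    sample_rate=16000,
    duration=2.0,         # both signals are trimmed to the same 2 s window
    fixed_offset=0.5,     # start no earlier than 0.5 s into the files
    random_offset=False,  # deterministic start point
)
```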
496
+ @classmethod
497
+ def get_samples_from_file(
498
+ cls,
499
+ audio_file: Union[str, List[str]],
500
+ sample_rate: int,
501
+ offset: float,
502
+ num_samples: Optional[int] = None,
503
+ channel_selector: Optional[ChannelSelectorType] = None,
504
+ ) -> np.ndarray:
505
+ """Get samples from a single or multiple files.
506
+ If loading samples from multiple files, they will
507
+ be concatenated along the channel dimension.
508
+
509
+ Args:
510
+ audio_file: path or a list of paths.
511
+ sample_rate: sample rate of the loaded samples
512
+ offset: fixed offset in seconds
513
+ num_samples: Optional, number of samples to load.
514
+ If `None`, all available samples will be loaded.
515
+ channel_selector: Select a subset of available channels.
516
+
517
+ Returns:
518
+ An array with shape (samples,) or (channels, samples)
519
+ """
520
+ if isinstance(audio_file, str):
521
+ # Load samples from a single file
522
+ segment_samples = cls.get_segment_from_file(
523
+ audio_file=audio_file,
524
+ sample_rate=sample_rate,
525
+ offset=offset,
526
+ num_samples=num_samples,
527
+ channel_selector=channel_selector,
528
+ )
529
+ elif isinstance(audio_file, list):
530
+ # Load samples from multiple files and form a multi-channel signal
531
+ segment_samples = []
532
+ for a_file in audio_file:
533
+ a_file_samples = cls.get_segment_from_file(
534
+ audio_file=a_file,
535
+ sample_rate=sample_rate,
536
+ offset=offset,
537
+ num_samples=num_samples,
538
+ channel_selector=channel_selector,
539
+ )
540
+ segment_samples.append(a_file_samples)
541
+ segment_samples = cls.list_to_multichannel(segment_samples)
542
+ elif audio_file is None:
543
+ # Support for inference, when the target signal is `None`
544
+ segment_samples = []
545
+ else:
546
+ raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}')
547
+ return segment_samples
548
+
549
+ @staticmethod
550
+ def get_segment_from_file(
551
+ audio_file: str,
552
+ sample_rate: int,
553
+ offset: float,
554
+ num_samples: Optional[int] = None,
555
+ channel_selector: Optional[ChannelSelectorType] = None,
556
+ ) -> np.ndarray:
557
+ """Get a segment of samples from a single audio file.
558
+
559
+ Args:
560
+ audio_file: path to an audio file
561
+ sample_rate: sample rate of the loaded samples
562
+ offset: fixed offset in seconds
563
+ num_samples: Optional, number of samples to load.
564
+ If `None`, all available samples will be loaded.
565
+ channel_selector: Select a subset of available channels.
566
+
567
+ Returns:
568
+ An array with shape (samples,) or (channels, samples)
569
+ """
570
+ if num_samples is None:
571
+ segment = AudioSegment.from_file(
572
+ audio_file=audio_file, target_sr=sample_rate, offset=offset, channel_selector=channel_selector,
573
+ )
574
+
575
+ else:
576
+ segment = AudioSegment.segment_from_file(
577
+ audio_file=audio_file,
578
+ target_sr=sample_rate,
579
+ n_segments=num_samples,
580
+ offset=offset,
581
+ channel_selector=channel_selector,
582
+ )
583
+
584
+ if segment.samples.ndim == 1:
585
+ # Single-channel signal
586
+ return segment.samples
587
+ elif segment.samples.ndim == 2:
588
+ # Use multi-channel format as (channels, samples)
589
+ return segment.samples.T
590
+ else:
591
+ raise RuntimeError(f'Unexpected samples shape: {segment.samples.shape}')
592
+
593
+ @staticmethod
594
+ def list_to_multichannel(signal: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
595
+ """Convert a list of signals into a multi-channel signal by concatenating
596
+ the elements of the list along the channel dimension.
597
+
598
+ If input is not a list, it is returned unmodified.
599
+
600
+ Args:
601
+ signal: list of arrays
602
+
603
+ Returns:
604
+ Numpy array obtained by concatenating the elements of the list
605
+ along the channel dimension (axis=0).
606
+ """
607
+ if not isinstance(signal, list):
608
+ # Nothing to do here
609
+ return signal
610
+ elif len(signal) == 0:
611
+ # Nothing to do, return as is
612
+ return signal
613
+ elif len(signal) == 1:
614
+ # Nothing to concatenate, return the original format
615
+ return signal[0]
616
+
617
+ # If multiple signals are provided in a list, we concatenate them along the channel dimension
618
+ if signal[0].ndim == 1:
619
+ # Single-channel individual files
620
+ mc_signal = np.stack(signal, axis=0)
621
+ elif signal[0].ndim == 2:
622
+ # Multi-channel individual files
623
+ mc_signal = np.concatenate(signal, axis=0)
624
+ else:
625
+ raise RuntimeError(f'Unexpected target with {signal[0].ndim} dimensions.')
626
+
627
+ return mc_signal
628
+
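A small worked example of the concatenation rule above; the arrays are toy values used only to show the resulting shape.
```
import numpy as np

a = np.zeros(4)  # single-channel signal, shape (4,)
b = np.ones(4)   # single-channel signal, shape (4,)
mc = ASRAudioProcessor.list_to_multichannel([a, b])
# mc.shape == (2, 4): one row per channel
```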
629
+ @staticmethod
630
+ def get_duration(audio_files: List[str]) -> List[float]:
631
+ """Get duration for each audio file in `audio_files`.
632
+
633
+ Args:
634
+ audio_files: list of paths to audio files
635
+
636
+ Returns:
637
+ List of durations in seconds.
638
+ """
639
+ duration = [librosa.get_duration(path=f) for f in flatten(audio_files)]
640
+ return duration
641
+
642
+ def load_embedding(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
643
+ """Given an example, load embedding from `example.audio_files[embedding]`
644
+ and return it in a dictionary.
645
+
646
+ Args:
647
+ example: An example from audio collection
648
+
649
+ Returns:
650
+ A dictionary of embedding keys and their tensors.
651
+ """
652
+ output = OrderedDict()
653
+ for idx, signal in enumerate(self.embedding_setup.signals):
654
+ embedding_file = example.audio_files[signal]
655
+ embedding = self.load_embedding_vector(embedding_file)
656
+ output[signal] = torch.tensor(embedding)
657
+ return output
658
+
659
+ @staticmethod
660
+ def load_embedding_vector(filepath: str) -> np.ndarray:
661
+ """Load an embedding vector from a file.
662
+
663
+ Args:
664
+ filepath: path to a file storing a vector.
665
+ Currently, it is assumed the file is a npy file.
666
+
667
+ Returns:
668
+ Array loaded from filepath.
669
+ """
670
+ if filepath.endswith('.npy'):
671
+ with open(filepath, 'rb') as f:
672
+ embedding = np.load(f)
673
+ else:
674
+ raise RuntimeError(f'Unknown embedding file format in file: {filepath}')
675
+
676
+ return embedding
677
+
678
+
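Putting the setters above together, here is a minimal configuration sketch; the signal names, duration and sample rate are illustrative choices that mirror how the datasets below configure the processor.
```
processor = ASRAudioProcessor(sample_rate=16000, random_offset=True)

# Signals cut synchronously: the same start and end point in every file.
processor.sync_setup = SignalSetup(
    signals=['input_signal', 'target_signal'],
    duration=4.0,
    channel_selectors=[None, None],  # keep all channels
)

# A signal loaded independently: whole file, all channels.
processor.async_setup = SignalSetup(
    signals=['reference_signal'],
    duration=[None],
    channel_selectors=[None],
)

# An embedding vector stored in a .npy file.
processor.embedding_setup = SignalSetup(signals=['embedding_vector'])
```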
679
+ class BaseAudioDataset(Dataset):
680
+ """Base class of audio datasets, providing common functionality
681
+ for other audio datasets.
682
+
683
+ Args:
684
+ collection: Collection of audio examples prepared from manifest files.
685
+ audio_processor: Used to process every example from the collection.
686
+ A callable with `process` method. For reference,
687
+ please check ASRAudioProcessor.
688
+ """
689
+
690
+ @property
691
+ @abc.abstractmethod
692
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
693
+ """Returns definitions of module output ports.
694
+ """
695
+
696
+ def __init__(self, collection: collections.Audio, audio_processor: Callable, output_type: Type[namedtuple]):
697
+ """Instantiates an audio dataset.
698
+ """
699
+ super().__init__()
700
+
701
+ self.collection = collection
702
+ self.audio_processor = audio_processor
703
+ self.output_type = output_type
704
+
705
+ def num_channels(self, signal_key) -> int:
706
+ """Returns the number of channels for a particular signal in
707
+ items prepared by this dataset.
708
+
709
+ More specifically, this will get the tensor from the first
710
+ item in the dataset, check if it's a one- or two-dimensional
711
+ tensor, and return the number of channels based on the size
712
+ of the first axis (shape[0]).
713
+
714
+ NOTE:
715
+ This assumes that all examples have the same number of channels.
716
+
717
+ Args:
718
+ signal_key: string, used to select a signal from the dictionary
719
+ output by __getitem__
720
+
721
+ Returns:
722
+ Number of channels for the selected signal.
723
+ """
724
+ # Assumption: whole dataset has the same number of channels
725
+ item = self.__getitem__(0)
726
+
727
+ if item[signal_key].ndim == 1:
728
+ return 1
729
+ elif item[signal_key].ndim == 2:
730
+ return item[signal_key].shape[0]
731
+ else:
732
+ raise RuntimeError(
733
+ f'Unexpected number of dimensions for signal {signal_key} with shape {item[signal_key].shape}'
734
+ )
735
+
736
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
737
+ """Return a single example from the dataset.
738
+
739
+ Args:
740
+ index: integer index of an example in the collection
741
+
742
+ Returns:
743
+ Dictionary providing mapping from signal to its tensor.
744
+ For example:
745
+ ```
746
+ {
747
+ 'input_signal': input_signal_tensor,
748
+ 'target_signal': target_signal_tensor,
749
+ }
750
+ ```
751
+ """
752
+ example = self.collection[index]
753
+ output = self.audio_processor.process(example=example)
754
+
755
+ return output
756
+
757
+ def __len__(self) -> int:
758
+ """Return the number of examples in the dataset.
759
+ """
760
+ return len(self.collection)
761
+
762
+ def _collate_fn(self, batch) -> Tuple[torch.Tensor]:
763
+ """Collate items in a batch.
764
+ """
765
+ return self.output_type(*_audio_collate_fn(batch))
766
+
767
+
768
+ AudioToTargetExample = namedtuple(
769
+ typename='AudioToTargetExample', field_names='input_signal input_length target_signal target_length'
770
+ )
771
+
772
+
773
+ class AudioToTargetDataset(BaseAudioDataset):
774
+ """A dataset for audio-to-audio tasks where the goal is to use
775
+ an input signal to recover the corresponding target signal.
776
+
777
+ Each line of the manifest file is expected to have the following format
778
+ ```
779
+ {
780
+ 'input_key': 'path/to/input.wav',
781
+ 'target_key': 'path/to/path_to_target.wav',
782
+ 'duration': duration_of_input,
783
+ }
784
+ ```
785
+
786
+ Additionally, multiple audio files may be provided for each key in the manifest, for example,
787
+ ```
788
+ {
789
+ 'input_key': 'path/to/input.wav',
790
+ 'target_key': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'],
791
+ 'duration': duration_of_input,
792
+ }
793
+ ```
794
+
795
+ Keys for input and target signals can be configured in the constructor (`input_key` and `target_key`).
796
+
797
+ Args:
798
+ manifest_filepath: Path to manifest file in a format described above.
799
+ sample_rate: Sample rate for loaded audio signals.
800
+ input_key: Key pointing to input audio files in the manifest
801
+ target_key: Key pointing to target audio files in manifest
802
+ audio_duration: Optional duration of each item returned by __getitem__.
803
+ If `None`, complete audio will be loaded.
804
+ If set, a random subsegment will be loaded synchronously from
805
+ target and audio, i.e., with the same start and end point.
806
+ random_offset: If `True`, offset will be randomized when loading a subsegment
807
+ from a file.
808
+ max_duration: If audio exceeds this length, do not include in dataset.
809
+ min_duration: If audio is less than this length, do not include in dataset.
810
+ max_utts: Limit number of utterances.
811
+ input_channel_selector: Optional, select subset of channels from each input audio file.
812
+ If `None`, all channels will be loaded.
813
+ target_channel_selector: Optional, select subset of channels from each target audio file.
814
+ If `None`, all channels will be loaded.
815
+ """
816
+
817
+ def __init__(
818
+ self,
819
+ manifest_filepath: str,
820
+ sample_rate: int,
821
+ input_key: str,
822
+ target_key: str,
823
+ audio_duration: Optional[float] = None,
824
+ random_offset: bool = False,
825
+ max_duration: Optional[float] = None,
826
+ min_duration: Optional[float] = None,
827
+ max_utts: Optional[int] = None,
828
+ input_channel_selector: Optional[int] = None,
829
+ target_channel_selector: Optional[int] = None,
830
+ ):
831
+ audio_to_manifest_key = {
832
+ 'input_signal': input_key,
833
+ 'target_signal': target_key,
834
+ }
835
+
836
+ collection = collections.AudioCollection(
837
+ manifest_files=manifest_filepath,
838
+ audio_to_manifest_key=audio_to_manifest_key,
839
+ min_duration=min_duration,
840
+ max_duration=max_duration,
841
+ max_number=max_utts,
842
+ )
843
+
844
+ audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
845
+ audio_processor.sync_setup = SignalSetup(
846
+ signals=['input_signal', 'target_signal'],
847
+ duration=audio_duration,
848
+ channel_selectors=[input_channel_selector, target_channel_selector],
849
+ )
850
+
851
+ super().__init__(collection=collection, audio_processor=audio_processor, output_type=AudioToTargetExample)
852
+
853
+ @property
854
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
855
+ """Returns definitions of module output ports.
856
+
857
+ Returns:
858
+ Ordered dictionary in the following form:
859
+ ```
860
+ {
861
+ 'input_signal': batched single- or multi-channel format,
862
+ 'input_length': batched original length of each input signal
863
+ 'target_signal': batched single- or multi-channel format,
864
+ 'target_length': batched original length of each target signal
865
+ }
866
+ ```
867
+ """
868
+ sc_audio_type = NeuralType(('B', 'T'), AudioSignal())
869
+ mc_audio_type = NeuralType(('B', 'C', 'T'), AudioSignal())
870
+
871
+ return OrderedDict(
872
+ input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
873
+ input_length=NeuralType(('B',), LengthsType()),
874
+ target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
875
+ target_length=NeuralType(('B',), LengthsType()),
876
+ )
877
+
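A hedged usage sketch for the dataset above; the manifest path and the key names are placeholders and must match the keys actually used in the manifest.
```
dataset = AudioToTargetDataset(
    manifest_filepath='manifest.json',  # hypothetical manifest path
    sample_rate=16000,
    input_key='input_filepath',         # manifest key for input audio
    target_key='target_filepath',       # manifest key for target audio
    audio_duration=4.0,                 # load synchronized 4 s subsegments
    random_offset=True,
)
item = dataset[0]  # {'input_signal': tensor, 'target_signal': tensor}
```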
878
+
879
+ AudioToTargetWithReferenceExample = namedtuple(
880
+ typename='AudioToTargetWithReferenceExample',
881
+ field_names='input_signal input_length target_signal target_length reference_signal reference_length',
882
+ )
883
+
884
+
885
+ class AudioToTargetWithReferenceDataset(BaseAudioDataset):
886
+ """A dataset for audio-to-audio tasks where the goal is to use
887
+ an input signal to recover the corresponding target signal and an
888
+ additional reference signal is available.
889
+
890
+ This can be used, for example, when a reference signal is
891
+ available from
892
+ - enrollment utterance for the target signal
893
+ - echo reference from playback
894
+ - reference from another sensor that correlates with the target signal
895
+
896
+ Each line of the manifest file is expected to have the following format
897
+ ```
898
+ {
899
+ 'input_key': 'path/to/input.wav',
900
+ 'target_key': 'path/to/path_to_target.wav',
901
+ 'reference_key': 'path/to/path_to_reference.wav',
902
+ 'duration': duration_of_input,
903
+ }
904
+ ```
905
+
906
+ Keys for input, target and reference signals can be configured in the constructor.
907
+
908
+ Args:
909
+ manifest_filepath: Path to manifest file in a format described above.
910
+ sample_rate: Sample rate for loaded audio signals.
911
+ input_key: Key pointing to input audio files in the manifest
912
+ target_key: Key pointing to target audio files in manifest
913
+ reference_key: Key pointing to reference audio files in manifest
914
+ audio_duration: Optional duration of each item returned by __getitem__.
915
+ If `None`, complete audio will be loaded.
916
+ If set, a random subsegment will be loaded synchronously from
917
+ target and audio, i.e., with the same start and end point.
918
+ random_offset: If `True`, offset will be randomized when loading a subsegment
919
+ from a file.
920
+ max_duration: If audio exceeds this length, do not include in dataset.
921
+ min_duration: If audio is less than this length, do not include in dataset.
922
+ max_utts: Limit number of utterances.
923
+ input_channel_selector: Optional, select subset of channels from each input audio file.
924
+ If `None`, all channels will be loaded.
925
+ target_channel_selector: Optional, select subset of channels from each target audio file.
926
+ If `None`, all channels will be loaded.
927
+ reference_channel_selector: Optional, select subset of channels from each reference audio file.
928
+ If `None`, all channels will be loaded.
929
+ reference_is_synchronized: If True, it is assumed that the reference signal is synchronized
930
+ with the input signal, so the same subsegment will be loaded as for
931
+ input and target. If False, reference signal will be loaded independently
932
+ from input and target.
933
+ reference_duration: Optional, can be used to set a fixed duration of the reference utterance. If `None`,
934
+ complete audio file will be loaded.
935
+ """
936
+
937
+ def __init__(
938
+ self,
939
+ manifest_filepath: str,
940
+ sample_rate: int,
941
+ input_key: str,
942
+ target_key: str,
943
+ reference_key: str,
944
+ audio_duration: Optional[float] = None,
945
+ random_offset: bool = False,
946
+ max_duration: Optional[float] = None,
947
+ min_duration: Optional[float] = None,
948
+ max_utts: Optional[int] = None,
949
+ input_channel_selector: Optional[int] = None,
950
+ target_channel_selector: Optional[int] = None,
951
+ reference_channel_selector: Optional[int] = None,
952
+ reference_is_synchronized: bool = True,
953
+ reference_duration: Optional[float] = None,
954
+ ):
955
+ audio_to_manifest_key = {
956
+ 'input_signal': input_key,
957
+ 'target_signal': target_key,
958
+ 'reference_signal': reference_key,
959
+ }
960
+
961
+ collection = collections.AudioCollection(
962
+ manifest_files=manifest_filepath,
963
+ audio_to_manifest_key=audio_to_manifest_key,
964
+ min_duration=min_duration,
965
+ max_duration=max_duration,
966
+ max_number=max_utts,
967
+ )
968
+
969
+ audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
970
+
971
+ if reference_is_synchronized:
972
+ audio_processor.sync_setup = SignalSetup(
973
+ signals=['input_signal', 'target_signal', 'reference_signal'],
974
+ duration=audio_duration,
975
+ channel_selectors=[input_channel_selector, target_channel_selector, reference_channel_selector],
976
+ )
977
+ else:
978
+ audio_processor.sync_setup = SignalSetup(
979
+ signals=['input_signal', 'target_signal'],
980
+ duration=audio_duration,
981
+ channel_selectors=[input_channel_selector, target_channel_selector],
982
+ )
983
+ audio_processor.async_setup = SignalSetup(
984
+ signals=['reference_signal'],
985
+ duration=[reference_duration],
986
+ channel_selectors=[reference_channel_selector],
987
+ )
988
+
989
+ super().__init__(
990
+ collection=collection, audio_processor=audio_processor, output_type=AudioToTargetWithReferenceExample
991
+ )
992
+
993
+ @property
994
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
995
+ """Returns definitions of module output ports.
996
+
997
+ Returns:
998
+ Ordered dictionary in the following form:
999
+ ```
1000
+ {
1001
+ 'input_signal': batched single- or multi-channel format,
1002
+ 'input_length': batched original length of each input signal
1003
+ 'target_signal': batched single- or multi-channel format,
1004
+ 'target_length': batched original length of each target signal
1005
+ 'reference_signal': single- or multi-channel format,
1006
+ 'reference_length': original length of each reference signal
1007
+ }
1008
+ ```
1009
+ """
1010
+ sc_audio_type = NeuralType(('B', 'T'), AudioSignal())
1011
+ mc_audio_type = NeuralType(('B', 'C', 'T'), AudioSignal())
1012
+
1013
+ return OrderedDict(
1014
+ input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
1015
+ input_length=NeuralType(('B',), LengthsType()),
1016
+ target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
1017
+ target_length=NeuralType(('B',), LengthsType()),
1018
+ reference_signal=sc_audio_type if self.num_channels('reference_signal') == 1 else mc_audio_type,
1019
+ reference_length=NeuralType(('B',), LengthsType()),
1020
+ )
1021
+
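A hedged sketch of using the class above with an asynchronously loaded enrollment reference; all paths and key names are illustrative.
```
dataset = AudioToTargetWithReferenceDataset(
    manifest_filepath='manifest.json',   # hypothetical manifest path
    sample_rate=16000,
    input_key='input_filepath',
    target_key='target_filepath',
    reference_key='reference_filepath',
    audio_duration=4.0,                  # input/target: synchronized 4 s subsegments
    reference_is_synchronized=False,     # reference: full file, loaded independently
)
```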
1022
+
1023
+ AudioToTargetWithEmbeddingExample = namedtuple(
1024
+ typename='AudioToTargetWithEmbeddingExample',
1025
+ field_names='input_signal input_length target_signal target_length embedding_vector embedding_length',
1026
+ )
1027
+
1028
+
1029
+ class AudioToTargetWithEmbeddingDataset(BaseAudioDataset):
1030
+ """A dataset for audio-to-audio tasks where the goal is to use
1031
+ an input signal to recover the corresponding target signal and an
1032
+ additional embedding signal. It is assumed that the embedding
1033
+ is in a form of a vector.
1034
+
1035
+ Each line of the manifest file is expected to have the following format
1036
+ ```
1037
+ {
1038
+ input_key: 'path/to/input.wav',
1039
+ target_key: 'path/to/path_to_target.wav',
1040
+ embedding_key: 'path/to/path_to_reference.npy',
1041
+ 'duration': duration_of_input,
1042
+ }
1043
+ ```
1044
+
1045
+ Keys for input, target and embedding signals can be configured in the constructor.
1046
+
1047
+ Args:
1048
+ manifest_filepath: Path to manifest file in a format described above.
1049
+ sample_rate: Sample rate for loaded audio signals.
1050
+ input_key: Key pointing to input audio files in the manifest
1051
+ target_key: Key pointing to target audio files in manifest
1052
+ embedding_key: Key pointing to embedding files in manifest
1053
+ audio_duration: Optional duration of each item returned by __getitem__.
1054
+ If `None`, complete audio will be loaded.
1055
+ If set, a random subsegment will be loaded synchronously from
1056
+ target and audio, i.e., with the same start and end point.
1057
+ random_offset: If `True`, offset will be randomized when loading a subsegment
1058
+ from a file.
1059
+ max_duration: If audio exceeds this length, do not include in dataset.
1060
+ min_duration: If audio is less than this length, do not include in dataset.
1061
+ max_utts: Limit number of utterances.
1062
+ input_channel_selector: Optional, select subset of channels from each input audio file.
1063
+ If `None`, all channels will be loaded.
1064
+ target_channel_selector: Optional, select subset of channels from each target audio file.
1065
+ If `None`, all channels will be loaded.
1066
+ """
1067
+
1068
+ def __init__(
1069
+ self,
1070
+ manifest_filepath: str,
1071
+ sample_rate: int,
1072
+ input_key: str,
1073
+ target_key: str,
1074
+ embedding_key: str,
1075
+ audio_duration: Optional[float] = None,
1076
+ random_offset: bool = False,
1077
+ max_duration: Optional[float] = None,
1078
+ min_duration: Optional[float] = None,
1079
+ max_utts: Optional[int] = None,
1080
+ input_channel_selector: Optional[int] = None,
1081
+ target_channel_selector: Optional[int] = None,
1082
+ ):
1083
+ audio_to_manifest_key = {
1084
+ 'input_signal': input_key,
1085
+ 'target_signal': target_key,
1086
+ 'embedding_vector': embedding_key,
1087
+ }
1088
+
1089
+ collection = collections.AudioCollection(
1090
+ manifest_files=manifest_filepath,
1091
+ audio_to_manifest_key=audio_to_manifest_key,
1092
+ min_duration=min_duration,
1093
+ max_duration=max_duration,
1094
+ max_number=max_utts,
1095
+ )
1096
+
1097
+ audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,)
1098
+ audio_processor.sync_setup = SignalSetup(
1099
+ signals=['input_signal', 'target_signal'],
1100
+ duration=audio_duration,
1101
+ channel_selectors=[input_channel_selector, target_channel_selector],
1102
+ )
1103
+ audio_processor.embedding_setup = SignalSetup(signals=['embedding_vector'])
1104
+
1105
+ super().__init__(
1106
+ collection=collection, audio_processor=audio_processor, output_type=AudioToTargetWithEmbeddingExample
1107
+ )
1108
+
1109
+ @property
1110
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
1111
+ """Returns definitions of module output ports.
1112
+
1113
+ Returns:
1114
+ Ordered dictionary in the following form:
1115
+ ```
1116
+ {
1117
+ 'input_signal': batched single- or multi-channel format,
1118
+ 'input_length': batched original length of each input signal
1119
+ 'target_signal': batched single- or multi-channel format,
1120
+ 'target_length': batched original length of each target signal
1121
+ 'embedding_vector': batched embedded vector format,
1122
+ 'embedding_length': batched original length of each embedding vector
1123
+ }
1124
+ ```
1125
+ """
1126
+ sc_audio_type = NeuralType(('B', 'T'), AudioSignal())
1127
+ mc_audio_type = NeuralType(('B', 'C', 'T'), AudioSignal())
1128
+
1129
+ return OrderedDict(
1130
+ input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
1131
+ input_length=NeuralType(('B',), LengthsType()),
1132
+ target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
1133
+ target_length=NeuralType(('B',), LengthsType()),
1134
+ embedding_vector=NeuralType(('B', 'D'), EncodedRepresentation()),
1135
+ embedding_length=NeuralType(('B',), LengthsType()),
1136
+ )
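A hedged usage sketch for the embedding-conditioned dataset above; paths and key names are placeholders, and the embedding files are expected to be .npy vectors.
```
dataset = AudioToTargetWithEmbeddingDataset(
    manifest_filepath='manifest.json',   # hypothetical manifest path
    sample_rate=16000,
    input_key='input_filepath',
    target_key='target_filepath',
    embedding_key='embedding_filepath',  # points to .npy files in the manifest
)
item = dataset[0]  # keys: 'input_signal', 'target_signal', 'embedding_vector'
```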
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_audio_dataset.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from nemo.collections.asr.data import audio_to_audio
16
+
17
+
18
+ def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset:
19
+ """Instantiates an audio-to-audio dataset.
20
+
21
+ Args:
22
+ config: Config of AudioToTargetDataset.
23
+
24
+ Returns:
25
+ An instance of AudioToTargetDataset
26
+ """
27
+ dataset = audio_to_audio.AudioToTargetDataset(
28
+ manifest_filepath=config['manifest_filepath'],
29
+ sample_rate=config['sample_rate'],
30
+ input_key=config['input_key'],
31
+ target_key=config['target_key'],
32
+ audio_duration=config.get('audio_duration', None),
33
+ random_offset=config.get('random_offset', False),
34
+ max_duration=config.get('max_duration', None),
35
+ min_duration=config.get('min_duration', None),
36
+ max_utts=config.get('max_utts', 0),
37
+ input_channel_selector=config.get('input_channel_selector', None),
38
+ target_channel_selector=config.get('target_channel_selector', None),
39
+ )
40
+ return dataset
41
+
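A hedged sketch of the config-driven instantiation above; the dict keys mirror those read by the factory, and the values are placeholders.
```
config = {
    'manifest_filepath': 'manifest.json',  # hypothetical manifest path
    'sample_rate': 16000,
    'input_key': 'input_filepath',
    'target_key': 'target_filepath',
    'audio_duration': 4.0,
    'random_offset': True,
}
dataset = get_audio_to_target_dataset(config)
```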
42
+
43
+ def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset:
44
+ """Instantiates an audio-to-audio dataset.
45
+
46
+ Args:
47
+ config: Config of AudioToTargetWithReferenceDataset.
48
+
49
+ Returns:
50
+ An instance of AudioToTargetWithReferenceDataset
51
+ """
52
+ dataset = audio_to_audio.AudioToTargetWithReferenceDataset(
53
+ manifest_filepath=config['manifest_filepath'],
54
+ sample_rate=config['sample_rate'],
55
+ input_key=config['input_key'],
56
+ target_key=config['target_key'],
57
+ reference_key=config['reference_key'],
58
+ audio_duration=config.get('audio_duration', None),
59
+ random_offset=config.get('random_offset', False),
60
+ max_duration=config.get('max_duration', None),
61
+ min_duration=config.get('min_duration', None),
62
+ max_utts=config.get('max_utts', 0),
63
+ input_channel_selector=config.get('input_channel_selector', None),
64
+ target_channel_selector=config.get('target_channel_selector', None),
65
+ reference_channel_selector=config.get('reference_channel_selector', None),
66
+ reference_is_synchronized=config.get('reference_is_synchronized', True),
67
+ reference_duration=config.get('reference_duration', None),
68
+ )
69
+ return dataset
70
+
71
+
72
+ def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset:
73
+ """Instantiates an audio-to-audio dataset.
74
+
75
+ Args:
76
+ config: Config of AudioToTargetWithEmbeddingDataset.
77
+
78
+ Returns:
79
+ An instance of AudioToTargetWithEmbeddingDataset
80
+ """
81
+ dataset = audio_to_audio.AudioToTargetWithEmbeddingDataset(
82
+ manifest_filepath=config['manifest_filepath'],
83
+ sample_rate=config['sample_rate'],
84
+ input_key=config['input_key'],
85
+ target_key=config['target_key'],
86
+ embedding_key=config['embedding_key'],
87
+ audio_duration=config.get('audio_duration', None),
88
+ random_offset=config.get('random_offset', False),
89
+ max_duration=config.get('max_duration', None),
90
+ min_duration=config.get('min_duration', None),
91
+ max_utts=config.get('max_utts', 0),
92
+ input_channel_selector=config.get('input_channel_selector', None),
93
+ target_channel_selector=config.get('target_channel_selector', None),
94
+ )
95
+ return dataset
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_ctm_dataset.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from dataclasses import dataclass
18
+ from pathlib import Path
19
+ from typing import Any, List, Tuple
20
+
21
+ from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
22
+ from nemo.utils import logging
23
+
24
+
25
+ @dataclass
26
+ class FrameCtmUnit:
27
+ """A container class for one CTM unit with start and length countable in frames.
28
+ """
29
+
30
+ label: str
31
+ start_frame: int
32
+ length: int
33
+ probability: float
34
+
35
+ def __repr__(self) -> str:
36
+ return f"{self.label}\t({self.probability:1.3f}): [{self.start_frame:6d}, {self.length:6d}]"
37
+
38
+ @property
39
+ def end_frame(self):
40
+ return self.start_frame + self.length
41
+
42
+ def to_ctm_str(self, time_per_frame: float) -> str:
43
+ """Represents the data as part of the CTM line.
44
+
45
+ The CTM line format is
46
+ <utterance_name> <channel> <start_time> <duration> <label_str> <probability>
47
+ This method prepares the last four entities."""
48
+ return f"{self.start_frame * time_per_frame :.3f} {self.length * time_per_frame :.3f} {self.label} {self.probability :1.3f}"
49
+
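An illustrative call (toy values) showing the last four CTM fields produced by the method above, assuming 10 ms frames.
```
unit = FrameCtmUnit(label='hello', start_frame=120, length=40, probability=0.91)
print(unit.to_ctm_str(time_per_frame=0.01))
# -> "1.200 0.400 hello 0.910"
```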
50
+
51
+ class ASRCTMPredictionWriter(ASRPredictionWriter):
52
+ def __init__(self, dataset, output_file: str, output_ctm_dir: str, time_per_frame: float):
53
+ super().__init__(dataset, output_file)
54
+ self.output_ctm_dir = output_ctm_dir
55
+ self.time_per_frame = time_per_frame
56
+ os.makedirs(self.output_ctm_dir, exist_ok=True)
57
+
58
+ def write_ctm(self, name, filepath, frameCtmUnits):
59
+ with open(filepath, "tw", encoding="utf-8") as f:
60
+ for unit in frameCtmUnits:
61
+ f.write(f"{name} 1 {unit.to_ctm_str(self.time_per_frame)}\n")
62
+
63
+ def write_on_batch_end(
64
+ self,
65
+ trainer,
66
+ pl_module: 'LightningModule',
67
+ prediction: Tuple[int, List[FrameCtmUnit]],
68
+ batch_indices: List[int],
69
+ batch: Any,
70
+ batch_idx: int,
71
+ dataloader_idx: int,
72
+ ):
73
+ for sample_id, units in prediction:
74
+ sample = self.dataset.get_manifest_sample(sample_id)
75
+ with_ctm = True
76
+ if len(units) == 0:
77
+ logging.warning(
78
+ f"""Not producing CTM output for item `{sample.audio_file}`.
79
+ Check if text is empty or if duration is too short: `{sample.text_raw}`, {sample.duration}"""
80
+ )
81
+ with_ctm = False
82
+ item = {}
83
+ item["audio_filepath"] = sample.audio_file
84
+ item["duration"] = sample.duration
85
+ item["text"] = sample.text_raw
86
+ if with_ctm:
87
+ utt_name = Path(sample.audio_file).stem
88
+ ctm_filepath = os.path.join(self.output_ctm_dir, utt_name) + ".ctm"
89
+ self.write_ctm(utt_name, ctm_filepath, units)
90
+ item["ctm_filepath"] = ctm_filepath
91
+ else:
92
+ item["ctm_filepath"] = ""
93
+ self.outf.write(json.dumps(item) + "\n")
94
+ self.samples_num += 1
95
+ return
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_diar_label.py ADDED
@@ -0,0 +1,853 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from collections import OrderedDict
17
+ from statistics import mode
18
+ from typing import Dict, Optional
19
+
20
+ import torch
21
+
22
+ from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat
23
+ from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data
24
+ from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel
25
+ from nemo.core.classes import Dataset
26
+ from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType
27
+
28
+
29
+ def get_scale_mapping_list(uniq_timestamps):
30
+ """
31
+ Call get_argmin_mat function to find the index of the non-base-scale segment that is closest to the
32
+ given base-scale segment. For each scale and each segment, a base-scale segment is assigned.
33
+
34
+ Args:
35
+ uniq_timestamps: (dict)
36
+ The dictionary containing embeddings, timestamps and multiscale weights.
37
+ If uniq_timestamps contains only one scale, single scale diarization is performed.
38
+
39
+ Returns:
40
+ scale_mapping_argmat (torch.tensor):
41
+
42
+ The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
43
+ segment index which has the closest center distance with (n+1)-th segment in the base scale.
44
+
45
+ - Example:
46
+ `scale_mapping_argmat[2][101] = 85`
47
+
48
+ In the above example, the code snippet means that 86-th segment in the 3rd scale (python index is 2) is
49
+ mapped to the 102-th segment in the base scale. Thus, the longer segments bound to have more repeating
50
+ numbers since multiple base scale segments (since the base scale has the shortest length) fall into the
51
+ range of the longer segments. At the same time, each row contains N numbers of indices where N is number
52
+ of segments in the base-scale (i.e., the finest scale).
53
+ """
54
+ timestamps_in_scales = []
55
+ for key, val in uniq_timestamps['scale_dict'].items():
56
+ timestamps_in_scales.append(torch.tensor(val['time_stamps']))
57
+ session_scale_mapping_list = get_argmin_mat(timestamps_in_scales)
58
+ scale_mapping_argmat = [[] for _ in range(len(uniq_timestamps['scale_dict'].keys()))]
59
+ for scale_idx in range(len(session_scale_mapping_list)):
60
+ scale_mapping_argmat[scale_idx] = session_scale_mapping_list[scale_idx]
61
+ scale_mapping_argmat = torch.stack(scale_mapping_argmat)
62
+ return scale_mapping_argmat
63
+
64
+
65
+ def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_spks=None):
66
+ """
67
+ Get RTTM lines containing speaker labels, start time and end time. target_spks contains two targeted
68
+ speaker indices for creating groundtruth label files. Only speakers in target_spks variable will be
69
+ included in the output lists.
70
+
71
+ Args:
72
+ uniq_id (str):
73
+ Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file.
74
+ rttm_lines (list):
75
+ List containing RTTM lines in str format.
76
+ mapping_dict (dict):
77
+ Mapping between the estimated speakers and the speakers in the ground-truth annotation.
78
+ `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode.
79
+ Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation.
80
+ Returns:
81
+ rttm_tup (tuple):
82
+ Tuple containing lists of start time, end time and speaker labels.
83
+
84
+ """
85
+ stt_list, end_list, speaker_list, pairwise_infer_spks = [], [], [], []
86
+ if target_spks:
87
+ inv_map = {v: k for k, v in mapping_dict.items()}
88
+ for spk_idx in target_spks:
89
+ spk_str = f'speaker_{spk_idx}'
90
+ if spk_str in inv_map:
91
+ pairwise_infer_spks.append(inv_map[spk_str])
92
+ for rttm_line in rttm_lines:
93
+ start, end, speaker = convert_rttm_line(rttm_line)
94
+ if target_spks is None or speaker in pairwise_infer_spks:
95
+ end_list.append(end)
96
+ stt_list.append(start)
97
+ speaker_list.append(speaker)
98
+ rttm_tup = (stt_list, end_list, speaker_list)
99
+ return rttm_tup
100
+
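A toy illustration of the parsing above, assuming standard RTTM lines (fields 4, 5 and 8 hold start time, duration and speaker name); all values are invented.
```
rttm_lines = [
    "SPEAKER session0 1 0.00 1.00 <NA> <NA> speaker_0 <NA> <NA>\n",
    "SPEAKER session0 1 1.00 0.50 <NA> <NA> speaker_1 <NA> <NA>\n",
]
stt_list, end_list, speaker_list = extract_seg_info_from_rttm('session0', rttm_lines)
# stt_list == [0.0, 1.0], end_list == [1.0, 1.5], speaker_list == ['speaker_0', 'speaker_1']
```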
101
+
102
+ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, target_spks, min_spks=2):
103
+ """
104
+ Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM.
105
+ The unit-length is the frame shift length of the acoustic feature. The feature-level annotations
106
+ `fr_level_target` will later be converted to base-segment level diarization label.
107
+
108
+ Args:
109
+ rttm_timestamps (list):
110
+ List containing start and end time for each speaker segment label.
111
+ stt_list, end_list and speaker_list are contained.
112
+ frame_per_sec (int):
113
+ Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module.
114
+ target_spks (tuple):
115
+ Speaker indices that are generated from combinations. If there are only one or two speakers,
116
+ only a single target_spks variable is generated.
117
+
118
+ Returns:
119
+ fr_level_target (torch.tensor):
120
+ Tensor containing label for each feature level frame.
121
+ """
122
+ stt_list, end_list, speaker_list = rttm_timestamps
123
+ if len(speaker_list) == 0:
124
+ return None
125
+ else:
126
+ sorted_speakers = sorted(list(set(speaker_list)))
127
+ total_fr_len = int(max(end_list) * (10 ** round_digits))
128
+ spk_num = max(len(sorted_speakers), min_spks)
129
+ speaker_mapping_dict = {rttm_key: x_int for x_int, rttm_key in enumerate(sorted_speakers)}
130
+ fr_level_target = torch.zeros(total_fr_len, spk_num)
131
+
132
+ # If RTTM is not provided, then there is no speaker mapping dict in target_spks.
133
+ # Thus, return a zero-filled tensor as a placeholder.
134
+ for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)):
135
+ stt, end = round(stt, round_digits), round(end, round_digits)
136
+ spk = speaker_mapping_dict[spk_rttm_key]
137
+ stt_fr, end_fr = int(round(stt, 2) * frame_per_sec), int(round(end, round_digits) * frame_per_sec)
138
+ fr_level_target[stt_fr:end_fr, spk] = 1
139
+ return fr_level_target
140
+
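A toy example (invented timestamps) of the frame-level target construction above, with 10 ms frames (frame_per_sec=100) and two non-overlapping speakers.
```
rttm_timestamps = ([0.0, 1.0], [1.0, 1.5], ['speaker_0', 'speaker_1'])  # (starts, ends, speakers)
fr_level_target = assign_frame_level_spk_vector(
    rttm_timestamps, round_digits=2, frame_per_sec=100, target_spks=None
)
# fr_level_target.shape == (150, 2): frames 0-99 mark speaker_0, frames 100-149 mark speaker_1
```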
141
+
142
+ class _AudioMSDDTrainDataset(Dataset):
143
+ """
144
+ Dataset class that loads a json file containing paths to audio files,
145
+ RTTM files and number of speakers. This Dataset class is designed for
146
+ training or fine-tuning speaker embedding extractor and diarization decoder
147
+ at the same time.
148
+
149
+ Example:
150
+ {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
151
+ "rttm_filepath": "/path/to/diar_label_0.rttm"}
152
+ ...
153
+ {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
154
+ "rttm_filepath": "/path/to/diar_label_n.rttm"}
155
+
156
+ Args:
157
+ manifest_filepath (str):
158
+ Path to input manifest json files.
159
+ multiscale_args_dict (dict):
160
+ Dictionary containing the parameters for multiscale segmentation and clustering.
161
+ emb_dir (str):
162
+ Path to a temporary folder where segmentation information for embedding extraction is saved.
163
+ soft_label_thres (float):
164
+ Threshold that determines the label of each segment based on RTTM file information.
165
+ featurizer:
166
+ Featurizer instance for generating features from the raw waveform.
167
+ window_stride (float):
168
+ Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames.
169
+ emb_batch_size (int):
170
+ Number of embedding vectors that are trained with attached computational graphs.
171
+ pairwise_infer (bool):
172
+ This variable should be True if dataloader is created for an inference task.
173
+ random_flip (bool):
174
+ If True, the two labels and input signals are randomly flipped per every epoch while training.
175
+ """
176
+
177
+ @property
178
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
179
+ """Returns definitions of module output ports."""
180
+ output_types = {
181
+ "features": NeuralType(('B', 'T'), AudioSignal()),
182
+ "feature_length": NeuralType(('B'), LengthsType()),
183
+ "ms_seg_timestamps": NeuralType(('B', 'C', 'T', 'D'), LengthsType()),
184
+ "ms_seg_counts": NeuralType(('B', 'C'), LengthsType()),
185
+ "clus_label_index": NeuralType(('B', 'T'), LengthsType()),
186
+ "scale_mapping": NeuralType(('B', 'C', 'T'), LengthsType()),
187
+ "targets": NeuralType(('B', 'T', 'C'), ProbsType()),
188
+ }
189
+
190
+ return output_types
191
+
192
+ def __init__(
193
+ self,
194
+ *,
195
+ manifest_filepath: str,
196
+ multiscale_args_dict: str,
197
+ emb_dir: str,
198
+ soft_label_thres: float,
199
+ featurizer,
200
+ window_stride,
201
+ emb_batch_size,
202
+ pairwise_infer: bool,
203
+ random_flip: bool = True,
204
+ global_rank: int = 0,
205
+ ):
206
+ super().__init__()
207
+ self.collection = DiarizationSpeechLabel(
208
+ manifests_files=manifest_filepath.split(','),
209
+ emb_dict=None,
210
+ clus_label_dict=None,
211
+ pairwise_infer=pairwise_infer,
212
+ )
213
+ self.featurizer = featurizer
214
+ self.multiscale_args_dict = multiscale_args_dict
215
+ self.emb_dir = emb_dir
216
+ self.round_digits = 2
217
+ self.decim = 10 ** self.round_digits
218
+ self.soft_label_thres = soft_label_thres
219
+ self.pairwise_infer = pairwise_infer
220
+ self.max_spks = 2
221
+ self.frame_per_sec = int(1 / window_stride)
222
+ self.emb_batch_size = emb_batch_size
223
+ self.random_flip = random_flip
224
+ self.global_rank = global_rank
225
+ self.manifest_filepath = manifest_filepath
226
+ self.multiscale_timestamp_dict = prepare_split_data(
227
+ self.manifest_filepath, self.emb_dir, self.multiscale_args_dict, self.global_rank,
228
+ )
229
+
230
+ def __len__(self):
231
+ return len(self.collection)
232
+
233
+ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label):
234
+ """
235
+ Assign the generated speaker labels from the base scale (the finest scale) to the longer scales.
236
+ This process is needed to get the cluster labels for each scale. The cluster labels are needed to
237
+ calculate the cluster-average speaker embedding for each scale.
238
+
239
+ Args:
240
+ uniq_id (str):
241
+ Unique sample ID for training.
242
+ base_scale_clus_label (torch.tensor):
243
+ Tensor variable containing the speaker labels for the base-scale segments.
244
+
245
+ Returns:
246
+ per_scale_clus_label (torch.tensor):
247
+ Tensor variable containing the speaker labels for each segment in each scale.
248
+ Note that the total length of the speaker label sequence differs over scale since
249
+ each scale has a different number of segments for the same session.
250
+
251
+ scale_mapping (torch.tensor):
252
+ Matrix containing the segment indices of each scale. scale_mapping is necessary for reshaping the
253
+ multiscale embeddings to form an input matrix for the MSDD model.
254
+ """
255
+ per_scale_clus_label = []
256
+ self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict'])
257
+ uniq_scale_mapping = get_scale_mapping_list(self.multiscale_timestamp_dict[uniq_id])
258
+ for scale_index in range(self.scale_n):
259
+ new_clus_label = []
260
+ scale_seq_len = len(self.multiscale_timestamp_dict[uniq_id]["scale_dict"][scale_index]["time_stamps"])
261
+ for seg_idx in range(scale_seq_len):
262
+ if seg_idx in uniq_scale_mapping[scale_index]:
263
+ seg_clus_label = mode(base_scale_clus_label[uniq_scale_mapping[scale_index] == seg_idx])
264
+ else:
265
+ seg_clus_label = 0 if len(new_clus_label) == 0 else new_clus_label[-1]
266
+ new_clus_label.append(seg_clus_label)
267
+ per_scale_clus_label.extend(new_clus_label)
268
+ per_scale_clus_label = torch.tensor(per_scale_clus_label)
269
+ return per_scale_clus_label, uniq_scale_mapping
270
+
271
+ def get_diar_target_labels(self, uniq_id, sample, fr_level_target):
272
+ """
273
+ Convert frame-level diarization target variable into segment-level target variable. Since the granularity is reduced
274
+ from frame level (10ms) to segment level (100ms~500ms), we need a threshold value, `soft_label_thres`, which determines
275
+ the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable.
276
+
277
+ Args:
278
+ uniq_id (str):
279
+ Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file.
280
+ sample:
281
+ `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath.
282
+ fr_level_target (torch.tensor):
283
+ Tensor containing label for each feature-level frame.
284
+
285
+ Returns:
286
+ seg_target (torch.tensor):
287
+ Tensor containing binary speaker labels for base-scale segments.
288
+ base_clus_label (torch.tensor):
289
+ Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment.
290
+ -1 means that there is no corresponding speaker in the target_spks tuple.
291
+ """
292
+ seg_target_list, base_clus_label = [], []
293
+ self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict'])
294
+ subseg_time_stamp_list = self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"]
295
+ for (seg_stt, seg_end) in subseg_time_stamp_list:
296
+ seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec)
297
+ soft_label_vec_sess = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / (
298
+ seg_end_fr - seg_stt_fr
299
+ )
300
+ label_int_sess = torch.argmax(soft_label_vec_sess)
301
+ soft_label_vec = soft_label_vec_sess.unsqueeze(0)[:, sample.target_spks].squeeze()
302
+ if label_int_sess in sample.target_spks and torch.sum(soft_label_vec_sess) > 0:
303
+ label_int = sample.target_spks.index(label_int_sess)
304
+ else:
305
+ label_int = -1
306
+ label_vec = (soft_label_vec > self.soft_label_thres).float()
307
+ seg_target_list.append(label_vec.detach())
308
+ base_clus_label.append(label_int)
309
+ seg_target = torch.stack(seg_target_list)
310
+ base_clus_label = torch.tensor(base_clus_label)
311
+ return seg_target, base_clus_label
312
+
313
+ def parse_rttm_for_ms_targets(self, sample):
314
+ """
315
+ Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file.
316
+ This function converts (start, end, speaker_id) tuples into base-scale (the finest scale) segment-level
317
+ diarization labels in matrix form.
318
+
319
+ Example of seg_target:
320
+ [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]]
321
+
322
+ Args:
323
+ sample:
324
+ `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath.
325
+ target_spks (tuple):
326
+ Speaker indices that are generated from combinations. If there are only one or two speakers,
327
+ only a single target_spks tuple is generated.
328
+
329
+ Returns:
330
+ clus_label_index (torch.tensor):
331
+ Groundtruth clustering label (cluster index for each segment) from RTTM files for training purpose.
332
+ seg_target (torch.tensor):
333
+ Tensor variable containing hard-labels of speaker activity in each base-scale segment.
334
+ scale_mapping (torch.tensor):
335
+ Matrix containing the segment indices of each scale. scale_mapping is necessary for reshaping the
336
+ multiscale embeddings to form an input matrix for the MSDD model.
337
+
338
+ """
339
+ rttm_lines = open(sample.rttm_file).readlines()
340
+ uniq_id = self.get_uniq_id_with_range(sample)
341
+ rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines)
342
+ fr_level_target = assign_frame_level_spk_vector(
343
+ rttm_timestamps, self.round_digits, self.frame_per_sec, target_spks=sample.target_spks
344
+ )
345
+ seg_target, base_clus_label = self.get_diar_target_labels(uniq_id, sample, fr_level_target)
346
+ clus_label_index, scale_mapping = self.assign_labels_to_longer_segs(uniq_id, base_clus_label)
347
+ return clus_label_index, seg_target, scale_mapping
348
+
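For orientation, a rough sketch (hypothetical values, not the actual helper) of the kind of frame-level target that `assign_frame_level_spk_vector` is expected to produce from the (start, end, speaker) tuples parsed out of an RTTM file; overlapped speech simply sets more than one column to 1.

import torch

frame_per_sec, total_sec, num_spks = 100, 2.0, 2
# Hypothetical (start_sec, end_sec, speaker_index) tuples from an RTTM file.
segments = [(0.0, 1.0, 0), (0.5, 1.5, 1)]

fr_level_target = torch.zeros(int(total_sec * frame_per_sec), num_spks)
for stt, end, spk in segments:
    fr_level_target[int(stt * frame_per_sec): int(end * frame_per_sec), spk] = 1
# Frames 50-99 have both columns set to 1, i.e. overlapped speech.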
349
+ def get_uniq_id_with_range(self, sample, deci=3):
350
+ """
351
+ Generate a unique training sample ID from the unique file ID, offset and duration. The ID with start and end time appended
352
+ is required for identifying the sample, since multiple short audio samples are generated from a single
353
+ audio file. The start and end times of the audio stream are given in milliseconds when `deci=3`.
354
+
355
+ Args:
356
+ sample:
357
+ `DiarizationSpeechLabel` instance from collections.
358
+
359
+ Returns:
360
+ uniq_id (str):
361
+ Unique sample ID which includes start and end time of the audio stream.
362
+ Example: abc1001_3122_6458
363
+
364
+ """
365
+ bare_uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0]
366
+ offset = str(int(round(sample.offset, deci) * pow(10, deci)))
367
+ endtime = str(int(round(sample.offset + sample.duration, deci) * pow(10, deci)))
368
+ uniq_id = f"{bare_uniq_id}_{offset}_{endtime}"
369
+ return uniq_id
370
+
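A quick worked example of the ID format produced above (values chosen so that the float-to-millisecond conversion is exact; `abc1001` is a placeholder file ID):

offset, duration, deci = 3.0, 3.5, 3
start_ms = str(int(round(offset, deci) * pow(10, deci)))           # '3000'
end_ms = str(int(round(offset + duration, deci) * pow(10, deci)))  # '6500'
uniq_id = f"abc1001_{start_ms}_{end_ms}"                           # 'abc1001_3000_6500'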
371
+ def get_ms_seg_timestamps(self, sample):
372
+ """
373
+ Get start and end time of segments in each scale.
374
+
375
+ Args:
376
+ sample:
377
+ `DiarizationSpeechLabel` instance from preprocessing.collections
378
+ Returns:
379
+ ms_seg_timestamps (torch.tensor):
380
+ Tensor containing Multiscale segment timestamps.
381
+ ms_seg_counts (torch.tensor):
382
+ Number of segments for each scale. This information is used for reshaping embedding batch
383
+ during forward propagation.
384
+ """
385
+ uniq_id = self.get_uniq_id_with_range(sample)
386
+ ms_seg_timestamps_list = []
387
+ max_seq_len = len(self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"])
388
+ ms_seg_counts = [0 for _ in range(self.scale_n)]
389
+ for scale_idx in range(self.scale_n):
390
+ scale_ts_list = []
391
+ for k, (seg_stt, seg_end) in enumerate(
392
+ self.multiscale_timestamp_dict[uniq_id]["scale_dict"][scale_idx]["time_stamps"]
393
+ ):
394
+ stt, end = (
395
+ int((seg_stt - sample.offset) * self.frame_per_sec),
396
+ int((seg_end - sample.offset) * self.frame_per_sec),
397
+ )
398
+ scale_ts_list.append(torch.tensor([stt, end]).detach())
399
+ ms_seg_counts[scale_idx] = len(
400
+ self.multiscale_timestamp_dict[uniq_id]["scale_dict"][scale_idx]["time_stamps"]
401
+ )
402
+ scale_ts = torch.stack(scale_ts_list)
403
+ scale_ts_padded = torch.cat([scale_ts, torch.zeros(max_seq_len - len(scale_ts_list), 2)], dim=0)
404
+ ms_seg_timestamps_list.append(scale_ts_padded.detach())
405
+ ms_seg_timestamps = torch.stack(ms_seg_timestamps_list)
406
+ ms_seg_counts = torch.tensor(ms_seg_counts)
407
+ return ms_seg_timestamps, ms_seg_counts
408
+
409
+ def __getitem__(self, index):
410
+ sample = self.collection[index]
411
+ if sample.offset is None:
412
+ sample.offset = 0
413
+ clus_label_index, targets, scale_mapping = self.parse_rttm_for_ms_targets(sample)
414
+ features = self.featurizer.process(sample.audio_file, offset=sample.offset, duration=sample.duration)
415
+ feature_length = torch.tensor(features.shape[0]).long()
416
+ ms_seg_timestamps, ms_seg_counts = self.get_ms_seg_timestamps(sample)
417
+ if self.random_flip:
418
+ torch.manual_seed(index)
419
+ flip = torch.cat([torch.randperm(self.max_spks), torch.tensor(-1).unsqueeze(0)])
420
+ clus_label_index, targets = flip[clus_label_index], targets[:, flip[: self.max_spks]]
421
+ return features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets
422
+
423
+
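A sketch of the `random_flip` step in `__getitem__` above (hypothetical tensors): appending -1 to the permutation lets a single lookup permute the speaker indices while passing the -1 "no matching speaker" label through unchanged, and the target columns are permuted consistently.

import torch

max_spks = 2
torch.manual_seed(0)
flip = torch.cat([torch.randperm(max_spks), torch.tensor(-1).unsqueeze(0)])  # e.g. tensor([1, 0, -1])

clus_label_index = torch.tensor([0, 0, 1, -1])                   # -1 = no matching speaker
targets = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 0.]])

flipped_labels = flip[clus_label_index]        # indices permuted; -1 picks flip[-1], which is -1
flipped_targets = targets[:, flip[:max_spks]]  # speaker columns permuted the same way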
424
+ class _AudioMSDDInferDataset(Dataset):
425
+ """
426
+ Dataset class that loads a json file containing paths to audio files,
427
+ RTTM files and number of speakers. This Dataset class is built for diarization inference and
428
+ evaluation. Speaker embedding sequences, segment timestamps, cluster-average speaker embeddings
429
+ are loaded from memory and fed into the dataloader.
430
+
431
+ Example:
432
+ {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
433
+ "rttm_filepath": "/path/to/diar_label_0.rttm}
434
+ ...
435
+ {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
436
+ "rttm_filepath": "/path/to/diar_label_n.rttm}
437
+
438
+ Args:
439
+ manifest_filepath (str):
440
+ Path to input manifest json files.
441
+ emb_dict (dict):
442
+ Dictionary containing cluster-average embeddings and speaker mapping information.
443
+ emb_seq (dict):
444
+ Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps.
445
+ clus_label_dict (dict):
446
+ Subsegment-level (from base-scale) speaker labels from clustering results.
447
+ soft_label_thres (float):
448
+ A threshold that determines the label of each segment based on RTTM file information.
449
+ featurizer:
450
+ Featurizer instance for generating features from raw waveform.
451
+ seq_eval_mode (bool):
452
+ If True, F1 score will be calculated for each speaker pair during inference mode.
453
+ window_stride (float):
454
+ Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames.
455
+ use_single_scale_clus (bool):
456
+ Use only one scale for clustering instead of using multiple scales of embeddings for clustering.
457
+ pairwise_infer (bool):
458
+ This variable should be True if dataloader is created for an inference task.
459
+ """
460
+
461
+ @property
462
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
463
+ """Returns definitions of module output ports."""
464
+ output_types = OrderedDict(
465
+ {
466
+ "ms_emb_seq": NeuralType(('B', 'T', 'C', 'D'), SpectrogramType()),
467
+ "length": NeuralType(tuple('B'), LengthsType()),
468
+ "ms_avg_embs": NeuralType(('B', 'C', 'D', 'C'), EncodedRepresentation()),
469
+ "targets": NeuralType(('B', 'T', 'C'), ProbsType()),
470
+ }
471
+ )
472
+ return output_types
473
+
474
+ def __init__(
475
+ self,
476
+ *,
477
+ manifest_filepath: str,
478
+ emb_dict: Dict,
479
+ emb_seq: Dict,
480
+ clus_label_dict: Dict,
481
+ soft_label_thres: float,
482
+ seq_eval_mode: bool,
483
+ window_stride: float,
484
+ use_single_scale_clus: bool,
485
+ pairwise_infer: bool,
486
+ ):
487
+ super().__init__()
488
+ self.collection = DiarizationSpeechLabel(
489
+ manifests_files=manifest_filepath.split(','),
490
+ emb_dict=emb_dict,
491
+ clus_label_dict=clus_label_dict,
492
+ seq_eval_mode=seq_eval_mode,
493
+ pairwise_infer=pairwise_infer,
494
+ )
495
+ self.emb_dict = emb_dict
496
+ self.emb_seq = emb_seq
497
+ self.clus_label_dict = clus_label_dict
498
+ self.round_digits = 2
499
+ self.decim = 10 ** self.round_digits
500
+ self.frame_per_sec = int(1 / window_stride)
501
+ self.soft_label_thres = soft_label_thres
502
+ self.pairwise_infer = pairwise_infer
503
+ self.max_spks = 2
504
+ self.use_single_scale_clus = use_single_scale_clus
505
+ self.seq_eval_mode = seq_eval_mode
506
+
507
+ def __len__(self):
508
+ return len(self.collection)
509
+
510
+ def parse_rttm_multiscale(self, sample):
511
+ """
512
+ Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file.
513
+ This function is only used when ``self.seq_eval_mode=True`` and RTTM files are provided. This function converts
514
+ (start, end, speaker_id) tuples into base-scale (the finest scale) segment-level diarization labels in matrix
515
+ form to create the target matrix.
516
+
517
+ Args:
518
+ sample:
519
+ DiarizationSpeechLabel instance containing sample information such as audio filepath and RTTM filepath.
520
+ target_spks (tuple):
521
+ Two Indices of targeted speakers for evaluation.
522
+ Example of target_spks: (2, 3)
523
+ Returns:
524
+ seg_target (torch.tensor):
525
+ Tensor variable containing hard-labels of speaker activity in each base-scale segment.
526
+ """
527
+ if sample.rttm_file is None:
528
+ raise ValueError(f"RTTM file is not provided for this sample {sample}")
529
+ rttm_lines = open(sample.rttm_file).readlines()
530
+ uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0]
531
+ mapping_dict = self.emb_dict[max(self.emb_dict.keys())][uniq_id]['mapping']
532
+ rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict, sample.target_spks)
533
+ fr_level_target = assign_frame_level_spk_vector(
534
+ rttm_timestamps, self.round_digits, self.frame_per_sec, sample.target_spks
535
+ )
536
+ seg_target = self.get_diar_target_labels_from_fr_target(uniq_id, fr_level_target)
537
+ return seg_target
538
+
539
+ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target):
540
+ """
541
+ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level
542
+ speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate
543
+ ratios for each speaker into the `soft_label_vec` variable. Finally, the `soft_label_vec` variable is compared with `soft_label_thres`
544
+ to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that the seg_target variable has
545
+ shape (number of base-scale segments x 2).
546
+
547
+ Example of seg_target:
548
+ [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]]
549
+
550
+ Args:
551
+ uniq_id (str):
552
+ Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file.
553
+ fr_level_target (torch.tensor):
554
+ frame-level binary speaker annotation (1: exist 0: non-exist) generated from RTTM file.
555
+
556
+ Returns:
557
+ seg_target (torch.tensor):
558
+ Tensor variable containing binary hard-labels of speaker activity in each base-scale segment.
559
+
560
+ """
561
+ if fr_level_target is None:
562
+ return None
563
+ else:
564
+ seg_target_list = []
565
+ for (seg_stt, seg_end, label_int) in self.clus_label_dict[uniq_id]:
566
+ seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec)
567
+ soft_label_vec = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / (
568
+ seg_end_fr - seg_stt_fr
569
+ )
570
+ label_vec = (soft_label_vec > self.soft_label_thres).int()
571
+ seg_target_list.append(label_vec)
572
+ seg_target = torch.stack(seg_target_list)
573
+ return seg_target
574
+
575
+ def __getitem__(self, index):
576
+ sample = self.collection[index]
577
+ if sample.offset is None:
578
+ sample.offset = 0
579
+
580
+ uniq_id = os.path.splitext(os.path.basename(sample.audio_file))[0]
581
+ scale_n = len(self.emb_dict.keys())
582
+ _avg_embs = torch.stack([self.emb_dict[scale_index][uniq_id]['avg_embs'] for scale_index in range(scale_n)])
583
+
584
+ if self.pairwise_infer:
585
+ avg_embs = _avg_embs[:, :, self.collection[index].target_spks]
586
+ else:
587
+ avg_embs = _avg_embs
588
+
589
+ if avg_embs.shape[2] > self.max_spks:
590
+ raise ValueError(
591
+ f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {self.max_spks}"
592
+ )
593
+
594
+ feats = []
595
+ for scale_index in range(scale_n):
596
+ repeat_mat = self.emb_seq["session_scale_mapping"][uniq_id][scale_index]
597
+ feats.append(self.emb_seq[scale_index][uniq_id][repeat_mat, :])
598
+ feats_out = torch.stack(feats).permute(1, 0, 2)
599
+ feats_len = feats_out.shape[0]
600
+
601
+ if self.seq_eval_mode:
602
+ targets = self.parse_rttm_multiscale(sample)
603
+ else:
604
+ targets = torch.zeros(feats_len, 2).float()
605
+
606
+ return feats_out, feats_len, targets, avg_embs
607
+
608
+
609
+ def _msdd_train_collate_fn(self, batch):
610
+ """
611
+ Collate batch of variables that are needed for raw waveform to diarization label training.
612
+ The following variables are included in training/validation batch:
613
+
614
+ Args:
615
+ batch (tuple):
616
+ Batch tuple containing the variables for the diarization training.
617
+ Returns:
618
+ features (torch.tensor):
619
+ Raw waveform samples (time series) loaded from the audio_filepath in the input manifest file.
620
+ feature_length (torch.tensor):
621
+ A list of lengths of the raw waveform samples.
622
+ ms_seg_timestamps (torch.tensor):
623
+ Matrix containing the start time and end time (timestamps) for each segment and each scale.
624
+ ms_seg_timestamps is needed for extracting acoustic features from raw waveforms.
625
+ ms_seg_counts (torch.tensor):
626
+ Matrix containing the number of segments for each scale. ms_seg_counts is necessary for reshaping
627
+ the input matrix for the MSDD model.
628
+ clus_label_index (torch.tensor):
629
+ Groundtruth Clustering label (cluster index for each segment) from RTTM files for training purpose.
630
+ clus_label_index is necessary for calculating cluster-average embedding.
631
+ scale_mapping (torch.tensor):
632
+ Matrix containing the segment indices of each scale. scale_mapping is necessary for reshaping the
633
+ multiscale embeddings to form an input matrix for the MSDD model.
634
+ targets (torch.tensor):
635
+ Groundtruth Speaker label for the given input embedding sequence.
636
+ """
637
+ packed_batch = list(zip(*batch))
638
+ features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = packed_batch
639
+ features_list, feature_length_list = [], []
640
+ ms_seg_timestamps_list, ms_seg_counts_list, scale_clus_label_list, scale_mapping_list, targets_list = (
641
+ [],
642
+ [],
643
+ [],
644
+ [],
645
+ [],
646
+ )
647
+
648
+ max_raw_feat_len = max([x.shape[0] for x in features])
649
+ max_target_len = max([x.shape[0] for x in targets])
650
+ max_total_seg_len = max([x.shape[0] for x in clus_label_index])
651
+
652
+ for feat, feat_len, ms_seg_ts, ms_seg_ct, scale_clus, scl_map, tgt in batch:
653
+ seq_len = tgt.shape[0]
654
+ pad_feat = (0, max_raw_feat_len - feat_len)
655
+ pad_tgt = (0, 0, 0, max_target_len - seq_len)
656
+ pad_sm = (0, max_target_len - seq_len)
657
+ pad_ts = (0, 0, 0, max_target_len - seq_len)
658
+ pad_sc = (0, max_total_seg_len - scale_clus.shape[0])
659
+ padded_feat = torch.nn.functional.pad(feat, pad_feat)
660
+ padded_tgt = torch.nn.functional.pad(tgt, pad_tgt)
661
+ padded_sm = torch.nn.functional.pad(scl_map, pad_sm)
662
+ padded_ms_seg_ts = torch.nn.functional.pad(ms_seg_ts, pad_ts)
663
+ padded_scale_clus = torch.nn.functional.pad(scale_clus, pad_sc)
664
+
665
+ features_list.append(padded_feat)
666
+ feature_length_list.append(feat_len.clone().detach())
667
+ ms_seg_timestamps_list.append(padded_ms_seg_ts)
668
+ ms_seg_counts_list.append(ms_seg_ct.clone().detach())
669
+ scale_clus_label_list.append(padded_scale_clus)
670
+ scale_mapping_list.append(padded_sm)
671
+ targets_list.append(padded_tgt)
672
+
673
+ features = torch.stack(features_list)
674
+ feature_length = torch.stack(feature_length_list)
675
+ ms_seg_timestamps = torch.stack(ms_seg_timestamps_list)
676
+ clus_label_index = torch.stack(scale_clus_label_list)
677
+ ms_seg_counts = torch.stack(ms_seg_counts_list)
678
+ scale_mapping = torch.stack(scale_mapping_list)
679
+ targets = torch.stack(targets_list)
680
+ return features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets
681
+
682
+
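A minimal sketch of the padding pattern used in the collate function above: `torch.nn.functional.pad` takes (left, right) pairs starting from the last dimension, so `(0, 0, 0, n)` appends n rows of zeros to a (segments x speakers) target. The tensors below are placeholders.

import torch
import torch.nn.functional as F

tgt_a, tgt_b = torch.ones(3, 2), torch.ones(5, 2)
max_len = max(tgt_a.shape[0], tgt_b.shape[0])

padded_a = F.pad(tgt_a, (0, 0, 0, max_len - tgt_a.shape[0]))  # shape (5, 2)
batch_targets = torch.stack([padded_a, tgt_b])                # shape (2, 5, 2)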
683
+ def _msdd_infer_collate_fn(self, batch):
684
+ """
685
+ Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings.
686
+
687
+ Args:
688
+ batch (tuple):
689
+ Batch tuple containing feats, feats_len, targets and ms_avg_embs.
690
+ Returns:
691
+ feats (torch.tensor):
692
+ Collated speaker embedding with unified length.
693
+ feats_len (torch.tensor):
694
+ The actual length of each embedding sequence without zero padding.
695
+ targets (torch.tensor):
696
+ Groundtruth Speaker label for the given input embedding sequence.
697
+ ms_avg_embs (torch.tensor):
698
+ Cluster-average speaker embedding vectors.
699
+ """
700
+
701
+ packed_batch = list(zip(*batch))
702
+ feats, feats_len, targets, ms_avg_embs = packed_batch
703
+ feats_list, flen_list, targets_list, ms_avg_embs_list = [], [], [], []
704
+ max_audio_len = max(feats_len)
705
+ max_target_len = max([x.shape[0] for x in targets])
706
+
707
+ for feature, feat_len, target, ivector in batch:
708
+ flen_list.append(feat_len)
709
+ ms_avg_embs_list.append(ivector)
710
+ if feat_len < max_audio_len:
711
+ pad_a = (0, 0, 0, 0, 0, max_audio_len - feat_len)
712
+ pad_t = (0, 0, 0, max_target_len - target.shape[0])
713
+ padded_feature = torch.nn.functional.pad(feature, pad_a)
714
+ padded_target = torch.nn.functional.pad(target, pad_t)
715
+ feats_list.append(padded_feature)
716
+ targets_list.append(padded_target)
717
+ else:
718
+ targets_list.append(target.clone().detach())
719
+ feats_list.append(feature.clone().detach())
720
+
721
+ feats = torch.stack(feats_list)
722
+ feats_len = torch.tensor(flen_list)
723
+ targets = torch.stack(targets_list)
724
+ ms_avg_embs = torch.stack(ms_avg_embs_list)
725
+ return feats, feats_len, targets, ms_avg_embs
726
+
727
+
728
+ class AudioToSpeechMSDDTrainDataset(_AudioMSDDTrainDataset):
729
+ """
730
+ Dataset class that loads a json file containing paths to audio files,
731
+ rttm files and number of speakers. This Dataset class is designed for
732
+ training or fine-tuning speaker embedding extractor and diarization decoder
733
+ at the same time.
734
+
735
+ Example:
736
+ {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
737
+ "rttm_filepath": "/path/to/diar_label_0.rttm}
738
+ ...
739
+ {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
740
+ "rttm_filepath": "/path/to/diar_label_n.rttm}
741
+
742
+ Args:
743
+ manifest_filepath (str):
744
+ Path to input manifest json files.
745
+ multiscale_args_dict (dict):
746
+ Dictionary containing the parameters for multiscale segmentation and clustering.
747
+ emb_dir (str):
748
+ Path to a temporary folder where segmentation information for embedding extraction is saved.
749
+ soft_label_thres (float):
750
+ A threshold that determines the label of each segment based on RTTM file information.
751
+ featurizer:
752
+ Featurizer instance for generating features from the raw waveform.
753
+ window_stride (float):
754
+ Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames.
755
+ emb_batch_size (int):
756
+ Number of embedding vectors that are trained with attached computational graphs.
757
+ pairwise_infer (bool):
758
+ This variable should be True if dataloader is created for an inference task.
759
+ """
760
+
761
+ def __init__(
762
+ self,
763
+ *,
764
+ manifest_filepath: str,
765
+ multiscale_args_dict: Dict,
766
+ emb_dir: str,
767
+ soft_label_thres: float,
768
+ featurizer,
769
+ window_stride,
770
+ emb_batch_size,
771
+ pairwise_infer: bool,
772
+ global_rank: int,
773
+ ):
774
+ super().__init__(
775
+ manifest_filepath=manifest_filepath,
776
+ multiscale_args_dict=multiscale_args_dict,
777
+ emb_dir=emb_dir,
778
+ soft_label_thres=soft_label_thres,
779
+ featurizer=featurizer,
780
+ window_stride=window_stride,
781
+ emb_batch_size=emb_batch_size,
782
+ pairwise_infer=pairwise_infer,
783
+ global_rank=global_rank,
784
+ )
785
+
786
+ def msdd_train_collate_fn(self, batch):
787
+ return _msdd_train_collate_fn(self, batch)
788
+
789
+
790
+ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset):
791
+ """
792
+ Dataset class that loads a json file containing paths to audio files,
793
+ rttm files and number of speakers. The created labels are used for diarization inference.
794
+
795
+ Example:
796
+ {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2,
797
+ "rttm_filepath": "/path/to/diar_label_0.rttm}
798
+ ...
799
+ {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2,
800
+ "rttm_filepath": "/path/to/diar_label_n.rttm}
801
+
802
+ Args:
803
+ manifest_filepath (str):
804
+ Path to input manifest json files.
805
+ emb_dict (dict):
806
+ Dictionary containing cluster-average embeddings and speaker mapping information.
807
+ emb_seq (dict):
808
+ Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps.
809
+ clus_label_dict (dict):
810
+ Subsegment-level (from base-scale) speaker labels from clustering results.
811
+ soft_label_thres (float):
812
+ Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps.
813
+ featurizer:
814
+ Featurizer instance for generating features from raw waveform.
815
+ use_single_scale_clus (bool):
816
+ Use only one scale for clustering instead of using multiple scales of embeddings for clustering.
817
+ seq_eval_mode (bool):
818
+ If True, F1 score will be calculated for each speaker pair during inference mode.
819
+ window_stride (float):
820
+ Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames.
821
+ pairwise_infer (bool):
822
+ If True, this Dataset class operates in inference mode. In inference mode, a set of speakers in the input audio
823
+ is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then
824
+ fed into the MSDD to merge the individual results.
825
+ """
826
+
827
+ def __init__(
828
+ self,
829
+ *,
830
+ manifest_filepath: str,
831
+ emb_dict: Dict,
832
+ emb_seq: Dict,
833
+ clus_label_dict: Dict,
834
+ soft_label_thres: float,
835
+ use_single_scale_clus: bool,
836
+ seq_eval_mode: bool,
837
+ window_stride: float,
838
+ pairwise_infer: bool,
839
+ ):
840
+ super().__init__(
841
+ manifest_filepath=manifest_filepath,
842
+ emb_dict=emb_dict,
843
+ emb_seq=emb_seq,
844
+ clus_label_dict=clus_label_dict,
845
+ soft_label_thres=soft_label_thres,
846
+ use_single_scale_clus=use_single_scale_clus,
847
+ window_stride=window_stride,
848
+ seq_eval_mode=seq_eval_mode,
849
+ pairwise_infer=pairwise_infer,
850
+ )
851
+
852
+ def msdd_infer_collate_fn(self, batch):
853
+ return _msdd_infer_collate_fn(self, batch)
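
A hypothetical usage sketch (not part of this file): once an `AudioToSpeechMSDDInferDataset` has been constructed, its collate method is handed to a standard PyTorch DataLoader. `infer_dataset` below stands in for an already-built instance.

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset=infer_dataset,  # assumed AudioToSpeechMSDDInferDataset instance
    batch_size=4,
    shuffle=False,
    collate_fn=infer_dataset.msdd_infer_collate_fn,
)
for feats, feats_len, targets, ms_avg_embs in loader:
    pass  # feed each batch to the MSDD model
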
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label.py ADDED
@@ -0,0 +1,1294 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import io
15
+ import os
16
+ from typing import Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import webdataset as wd
20
+
21
+ from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_sharded_filepaths
22
+ from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
23
+ from nemo.collections.asr.parts.preprocessing.segment import available_formats as valid_sf_formats
24
+ from nemo.collections.common.parts.preprocessing import collections
25
+ from nemo.core.classes import Dataset, IterableDataset
26
+ from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType, RegressionValuesType
27
+ from nemo.utils import logging
28
+
29
+ # List of valid file formats (prioritized by order of importance)
30
+ VALID_FILE_FORMATS = ';'.join(['wav', 'mp3', 'flac'] + [fmt.lower() for fmt in valid_sf_formats.keys()])
31
+
32
+
33
+ def repeat_signal(signal: torch.Tensor, sig_len: int, required_length: int) -> torch.Tensor:
34
+ """repeat signal to make short signal to have required_length
35
+ Args:
36
+ signal (Tensor): input signal
37
+ sig_len (int): length of input signal
38
+ required_length (int): length of generated signal
39
+ Returns:
40
+ signal (Tensor): generated signal of required_length by repeating itself.
41
+ """
43
+ repeat = int(required_length // sig_len)
44
+ rem = int(required_length % sig_len)
45
+ sub: torch.Tensor = torch.tensor([])
46
+ rep_sig: torch.Tensor = torch.cat(repeat * [signal])
47
+ if rem > 0:
48
+ sub = signal[-rem:]
49
+ signal = torch.cat((rep_sig, sub))
50
+ else:
51
+ signal = rep_sig
52
+ return signal
53
+
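A quick usage sketch for `repeat_signal` (hypothetical values): a 4-sample signal is tiled and then topped up with its last samples to reach the requested length.

import torch

sig = torch.arange(4).float()
out = repeat_signal(sig, sig_len=4, required_length=10)
# out -> tensor([0., 1., 2., 3., 0., 1., 2., 3., 2., 3.])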
54
+
55
+ def normalize(signal):
56
+ """normalize signal
57
+ Args:
58
+ signal(FloatTensor): signal to be normalized.
59
+ """
60
+ signal_minusmean = signal - signal.mean()
61
+ return signal_minusmean / signal_minusmean.abs().max()
62
+
63
+
64
+ def count_occurence(manifest_file_id):
65
+ """Count number of wav files in Dict manifest_file_id. Use for _TarredAudioToLabelDataset.
66
+ Args:
67
+ manifest_file_id (Dict): Dict of files and their corresponding id. {'A-sub0' : 1, ..., 'S-sub10':100}
68
+ Returns:
69
+ count (Dict): Dict of wav files {'A' : 2, ..., 'S':10}
70
+ """
71
+ count = dict()
72
+ for i in manifest_file_id:
73
+ audio_filename = i.split("-sub")[0]
74
+ count[audio_filename] = count.get(audio_filename, 0) + 1
75
+ return count
76
+
77
+
78
+ def _speech_collate_fn(batch, pad_id):
79
+ """collate batch of audio sig, audio len, tokens, tokens len
80
+ Args:
81
+ batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
82
+ LongTensor): A tuple of tuples of signal, signal lengths,
83
+ encoded tokens, and encoded tokens length. This collate func
84
+ assumes the signals are 1d torch tensors (i.e. mono audio).
85
+ """
86
+ _, audio_lengths, _, tokens_lengths = zip(*batch)
87
+ max_audio_len = 0
88
+ has_audio = audio_lengths[0] is not None
89
+ if has_audio:
90
+ max_audio_len = max(audio_lengths).item()
91
+ max_tokens_len = max(tokens_lengths).item()
92
+
93
+ audio_signal, tokens = [], []
94
+ for sig, sig_len, tokens_i, tokens_i_len in batch:
95
+ if has_audio:
96
+ sig_len = sig_len.item()
97
+ if sig_len < max_audio_len:
98
+ pad = (0, max_audio_len - sig_len)
99
+ sig = torch.nn.functional.pad(sig, pad)
100
+ audio_signal.append(sig)
101
+ tokens_i_len = tokens_i_len.item()
102
+ if tokens_i_len < max_tokens_len:
103
+ pad = (0, max_tokens_len - tokens_i_len)
104
+ tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
105
+ tokens.append(tokens_i)
106
+
107
+ if has_audio:
108
+ audio_signal = torch.stack(audio_signal)
109
+ audio_lengths = torch.stack(audio_lengths)
110
+ else:
111
+ audio_signal, audio_lengths = None, None
112
+ tokens = torch.stack(tokens)
113
+ tokens_lengths = torch.stack(tokens_lengths)
114
+
115
+ return audio_signal, audio_lengths, tokens, tokens_lengths
116
+
117
+
118
+ def _fixed_seq_collate_fn(self, batch):
119
+ """collate batch of audio sig, audio len, tokens, tokens len
120
+ Args:
121
+ batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
122
+ LongTensor): A tuple of tuples of signal, signal lengths,
123
+ encoded tokens, and encoded tokens length. This collate func
124
+ assumes the signals are 1d torch tensors (i.e. mono audio).
125
+ """
126
+ _, audio_lengths, _, tokens_lengths = zip(*batch)
127
+
128
+ has_audio = audio_lengths[0] is not None
129
+ fixed_length = int(max(audio_lengths))
130
+
131
+ audio_signal, tokens, new_audio_lengths = [], [], []
132
+ for sig, sig_len, tokens_i, _ in batch:
133
+ if has_audio:
134
+ sig_len = sig_len.item()
135
+ chunk_len = sig_len - fixed_length
136
+
137
+ if chunk_len < 0:
138
+ repeat = fixed_length // sig_len
139
+ rem = fixed_length % sig_len
140
+ sub = sig[-rem:] if rem > 0 else torch.tensor([])
141
+ rep_sig = torch.cat(repeat * [sig])
142
+ sig = torch.cat((rep_sig, sub))
143
+ new_audio_lengths.append(torch.tensor(fixed_length))
144
+
145
+ audio_signal.append(sig)
146
+
147
+ tokens.append(tokens_i)
148
+
149
+ if has_audio:
150
+ audio_signal = torch.stack(audio_signal)
151
+ audio_lengths = torch.stack(new_audio_lengths)
152
+ else:
153
+ audio_signal, audio_lengths = None, None
154
+ tokens = torch.stack(tokens)
155
+ tokens_lengths = torch.stack(tokens_lengths)
156
+
157
+ return audio_signal, audio_lengths, tokens, tokens_lengths
158
+
159
+
160
+ def _vad_frame_seq_collate_fn(self, batch):
161
+ """collate batch of audio sig, audio len, tokens, tokens len
162
+ Args:
163
+ batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
164
+ LongTensor): A tuple of tuples of signal, signal lengths,
165
+ encoded tokens, and encoded tokens length. This collate func
166
+ assumes the signals are 1d torch tensors (i.e. mono audio).
167
+ batch size equals to 1.
168
+ """
169
+ slice_length = int(self.featurizer.sample_rate * self.window_length_in_sec)
170
+ _, audio_lengths, _, tokens_lengths = zip(*batch)
171
+ slice_length = int(min(slice_length, max(audio_lengths)))
172
+ shift = int(self.featurizer.sample_rate * self.shift_length_in_sec)
173
+ has_audio = audio_lengths[0] is not None
174
+
175
+ audio_signal, num_slices, tokens, audio_lengths = [], [], [], []
176
+
177
+ append_len_start = slice_length // 2
178
+ append_len_end = slice_length - slice_length // 2
179
+ for sig, sig_len, tokens_i, _ in batch:
180
+ if self.normalize_audio:
181
+ sig = normalize(sig)
182
+ start = torch.zeros(append_len_start)
183
+ end = torch.zeros(append_len_end)
184
+ sig = torch.cat((start, sig, end))
185
+ sig_len += slice_length
186
+
187
+ if has_audio:
188
+ slices = torch.div(sig_len - slice_length, shift, rounding_mode='trunc')
189
+ for slice_id in range(slices):
190
+ start_idx = slice_id * shift
191
+ end_idx = start_idx + slice_length
192
+ signal = sig[start_idx:end_idx]
193
+ audio_signal.append(signal)
194
+
195
+ num_slices.append(slices)
196
+ tokens.extend([tokens_i] * slices)
197
+ audio_lengths.extend([slice_length] * slices)
198
+
199
+ if has_audio:
200
+ audio_signal = torch.stack(audio_signal)
201
+ audio_lengths = torch.tensor(audio_lengths)
202
+ else:
203
+ audio_signal, audio_lengths = None, None
204
+
205
+ tokens = torch.stack(tokens)
206
+ tokens_lengths = torch.tensor(num_slices)
207
+ return audio_signal, audio_lengths, tokens, tokens_lengths
208
+
209
+
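For intuition, a small sketch of how the VAD collate function above turns one zero-padded signal into fixed-length slices; the window and shift values are hypothetical.

import torch

sample_rate = 16000
window_length_in_sec, shift_length_in_sec = 0.63, 0.08
slice_length = int(sample_rate * window_length_in_sec)  # 10080 samples per slice
shift = int(sample_rate * shift_length_in_sec)          # 1280 samples between slice starts

sig_len = torch.tensor(16000) + slice_length            # 1 s of audio after padding by one window
num_slices = torch.div(sig_len - slice_length, shift, rounding_mode='trunc')  # tensor(12)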
210
+ class _AudioLabelDataset(Dataset):
211
+ """
212
+ Dataset that loads tensors via a json file containing paths to audio files,
213
+ labels, and durations and offsets (in seconds). Each new line is a
214
+ different sample containing an audio file and its target label.
215
+ JSON files should be of the following format::
216
+ {"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
217
+ target_label_0, "offset": offset_in_sec_0}
218
+ ...
219
+ {"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
220
+ target_label_n, "offset": offset_in_sec_n}
221
+ Args:
222
+ manifest_filepath (Union[str, List[str]]): Dataset parameter. Path to JSON containing data.
223
+ labels (list): Dataset parameter. List of target classes that can be output by the speaker recognition model.
224
+ featurizer
225
+ min_duration (float): Dataset parameter. All training files which have a duration less than min_duration
226
+ are dropped. Note: Duration is read from the manifest JSON.
227
+ Defaults to 0.1.
228
+ max_duration (float): Dataset parameter.
229
+ All training files which have a duration more than max_duration
230
+ are dropped. Note: Duration is read from the manifest JSON.
231
+ Defaults to None.
232
+ trim (bool): Whether to trim silence from the beginning and end of the audio signal using librosa.effects.trim().
233
+ Defaults to False.
234
+ """
235
+
236
+ @property
237
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
238
+ """Returns definitions of module output ports.
239
+ """
240
+
241
+ output_types = {
242
+ 'audio_signal': NeuralType(
243
+ ('B', 'T'),
244
+ AudioSignal(freq=self._sample_rate)
245
+ if self is not None and hasattr(self, '_sample_rate')
246
+ else AudioSignal(),
247
+ ),
248
+ 'a_sig_length': NeuralType(tuple('B'), LengthsType()),
249
+ }
250
+
251
+ if self.is_regression_task:
252
+ output_types.update(
253
+ {
254
+ 'targets': NeuralType(tuple('B'), RegressionValuesType()),
255
+ 'targets_length': NeuralType(tuple('B'), LengthsType()),
256
+ }
257
+ )
258
+ else:
259
+
260
+ output_types.update(
261
+ {'label': NeuralType(tuple('B'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),}
262
+ )
263
+
264
+ return output_types
265
+
266
+ def __init__(
267
+ self,
268
+ *,
269
+ manifest_filepath: Union[str, List[str]],
270
+ labels: List[str],
271
+ featurizer,
272
+ min_duration: Optional[float] = 0.1,
273
+ max_duration: Optional[float] = None,
274
+ trim: bool = False,
275
+ is_regression_task: bool = False,
276
+ cal_labels_occurrence: Optional[bool] = False,
277
+ ):
278
+ super().__init__()
279
+ if isinstance(manifest_filepath, str):
280
+ manifest_filepath = manifest_filepath.split(',')
281
+ cache_datastore_manifests(manifest_filepaths=manifest_filepath, cache_audio=True)
282
+ self.collection = collections.ASRSpeechLabel(
283
+ manifests_files=manifest_filepath,
284
+ min_duration=min_duration,
285
+ max_duration=max_duration,
286
+ is_regression_task=is_regression_task,
287
+ cal_labels_occurrence=cal_labels_occurrence,
288
+ )
289
+
290
+ self.featurizer = featurizer
291
+ self.trim = trim
292
+ self.is_regression_task = is_regression_task
293
+
294
+ if not is_regression_task:
295
+ self.labels = labels if labels else self.collection.uniq_labels
296
+ self.num_classes = len(self.labels) if self.labels is not None else 1
297
+ self.label2id, self.id2label = {}, {}
298
+ self.id2occurrence, self.labels_occurrence = {}, []
299
+
300
+ for label_id, label in enumerate(self.labels):
301
+ self.label2id[label] = label_id
302
+ self.id2label[label_id] = label
303
+ if cal_labels_occurrence:
304
+ self.id2occurrence[label_id] = self.collection.labels_occurrence[label]
305
+
306
+ if cal_labels_occurrence:
307
+ self.labels_occurrence = [self.id2occurrence[k] for k in sorted(self.id2occurrence)]
308
+
309
+ for idx in range(len(self.labels[:5])):
310
+ logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
311
+
312
+ else:
313
+ self.labels = []
314
+ self.num_classes = 1
315
+
316
+ def __len__(self):
317
+ return len(self.collection)
318
+
319
+ def __getitem__(self, index):
320
+ sample = self.collection[index]
321
+
322
+ offset = sample.offset
323
+
324
+ if offset is None:
325
+ offset = 0
326
+
327
+ features = self.featurizer.process(sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim)
328
+ f, fl = features, torch.tensor(features.shape[0]).long()
329
+
330
+ if not self.is_regression_task:
331
+ t = torch.tensor(self.label2id[sample.label]).long()
332
+ else:
333
+ t = torch.tensor(sample.label).float()
334
+
335
+ tl = torch.tensor(1).long() # For compatibility with collate_fn used later
336
+
337
+ return f, fl, t, tl
338
+
339
+
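A small sketch of writing one manifest line in the format this dataset expects (paths and values are placeholders); each line of the manifest is a single JSON object.

import json

entry = {
    "audio_filepath": "/path/to/audio_wav_0.wav",
    "duration": 1.45,
    "label": "speaker_0",
    "offset": 0.0,
}
with open("train_manifest.json", "w") as f:
    f.write(json.dumps(entry) + "\n")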
340
+ # Ported from https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/data/speech2text/speech_commands.py
341
+ class AudioToClassificationLabelDataset(_AudioLabelDataset):
342
+ """
343
+ Dataset that loads tensors via a json file containing paths to audio
344
+ files, command class, and durations (in seconds). Each new line is a
345
+ different sample. Example below:
346
+ {"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
347
+ target_label_0, "offset": offset_in_sec_0}
348
+ ...
349
+ {"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
350
+ target_label_n, "offset": offset_in_sec_n}
351
+ Args:
352
+ manifest_filepath (Union[str, List[str]]): Path to manifest json as described above. Can
353
+ be comma-separated paths.
354
+ labels (Optional[list]): String containing all the possible labels to map to
355
+ if None then automatically picks from ASRSpeechLabel collection.
356
+ featurizer: Initialized featurizer class that converts paths of
357
+ audio to feature tensors
358
+ max_duration: If audio exceeds this length, do not include in dataset
359
+ min_duration: If audio is less than this length, do not include
360
+ in dataset
361
+ trim: Boolean flag whether to trim the audio
362
+ """
363
+
364
+ def _collate_fn(self, batch):
365
+ return _speech_collate_fn(batch, pad_id=0)
366
+
367
+
368
+ class AudioToSpeechLabelDataset(_AudioLabelDataset):
369
+ """
370
+ Dataset that loads tensors via a json file containing paths to audio
371
+ files, command class, and durations (in seconds). Each new line is a
372
+ different sample. Example below:
373
+ {"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
374
+ target_label_0, "offset": offset_in_sec_0}
375
+ ...
376
+ {"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
377
+ target_label_n, "offset": offset_in_sec_n}
378
+ Args:
379
+ manifest_filepath (Union[str, List[str]]): Path to manifest json as described above. Can
380
+ be comma-separated paths.
381
+ labels (Optional[list]): String containing all the possible labels to map to
382
+ if None then automatically picks from ASRSpeechLabel collection.
383
+ min_duration (float): Dataset parameter.
384
+ All training files which have a duration less than min_duration
385
+ are dropped. Note: Duration is read from the manifest JSON.
386
+ Defaults to 0.1.
387
+ max_duration (float): Dataset parameter.
388
+ All training files which have a duration more than max_duration
389
+ are dropped. Note: Duration is read from the manifest JSON.
390
+ Defaults to None.
391
+ trim (bool): Whether to trim silence from the beginning and end
392
+ of audio signal using librosa.effects.trim().
393
+ Defaults to False.
394
+ window_length_in_sec (float): length of window/slice (in seconds)
395
+ Use this for speaker recognition and VAD tasks.
396
+ shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch
397
+ Use this for VAD task during inference.
398
+ normalize_audio (bool): Whether to normalize audio signal.
399
+ Defaults to False.
400
+ is_regression_task (bool): Whether the dataset is for a regression task instead of classification.
401
+ Defaults to False.
402
+ cal_labels_occurrence (bool): Whether to calculate occurrence of labels
403
+ Defaults to False.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ *,
409
+ manifest_filepath: Union[str, List[str]],
410
+ labels: List[str],
411
+ featurizer,
412
+ min_duration: Optional[float] = 0.1,
413
+ max_duration: Optional[float] = None,
414
+ trim: bool = False,
415
+ window_length_in_sec: Optional[float] = 8,
416
+ shift_length_in_sec: Optional[float] = 1,
417
+ normalize_audio: bool = False,
418
+ is_regression_task: bool = False,
419
+ cal_labels_occurrence: Optional[bool] = False,
420
+ ):
421
+ self.window_length_in_sec = window_length_in_sec
422
+ self.shift_length_in_sec = shift_length_in_sec
423
+ self.normalize_audio = normalize_audio
424
+
425
+ logging.debug("Window/slice length considered for collate func is {}".format(self.window_length_in_sec))
426
+ logging.debug("Shift length considered for collate func is {}".format(self.shift_length_in_sec))
427
+
428
+ super().__init__(
429
+ manifest_filepath=manifest_filepath,
430
+ labels=labels,
431
+ featurizer=featurizer,
432
+ min_duration=min_duration,
433
+ max_duration=max_duration,
434
+ trim=trim,
435
+ is_regression_task=is_regression_task,
436
+ cal_labels_occurrence=cal_labels_occurrence,
437
+ )
438
+
439
+ def fixed_seq_collate_fn(self, batch):
440
+ return _fixed_seq_collate_fn(self, batch)
441
+
442
+ def vad_frame_seq_collate_fn(self, batch):
443
+ return _vad_frame_seq_collate_fn(self, batch)
444
+
445
+
446
+ class _TarredAudioLabelDataset(IterableDataset):
447
+ """
448
+ A similar Dataset to the AudioLabelDataSet, but which loads tarred audio files.
449
+
450
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToSpeechLabelDataset),
451
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
452
+ contain the information for one audio file, including at least the label and name of the audio
453
+ file within the tarball.
454
+
455
+ Valid formats for the audio_tar_filepaths argument include:
456
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
457
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
458
+
459
+ Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference.
460
+ This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements.
461
+ Supported opening braces - { <=> (, [, < and the special tag _OP_.
462
+ Supported closing braces - } <=> ), ], > and the special tag _CL_.
463
+ For SLURM based tasks, we suggest the use of the special tags for ease of use.
464
+
465
+ See the documentation for more information about accepted data and input formats.
466
+
467
+ If using multiple processes the number of shards should be divisible by the number of workers to ensure an
468
+ even split among workers. If it is not divisible, logging will give a warning but training will proceed.
469
+ In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
470
+ is applied. We currently do not check for this, but your program may hang if the shards are uneven!
471
+
472
+ Notice that a few arguments are different from the AudioLabelDataSet; for example, shuffle (bool) has been
473
+ replaced by shuffle_n (int).
474
+
475
+ Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
476
+ after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
477
+
478
+ Args:
479
+ audio_tar_filepaths: Either a list of audio tarball filepaths, or a
480
+ string (can be brace-expandable).
481
+ manifest_filepath (str): Path to the manifest.
482
+ labels (list): Dataset parameter.
483
+ List of target classes that can be output by the speaker recognition model.
484
+ featurizer
485
+ shuffle_n (int): How many samples to look ahead and load to be shuffled.
486
+ See WebDataset documentation for more details.
487
+ Defaults to 0.
488
+ min_duration (float): Dataset parameter.
489
+ All training files which have a duration less than min_duration
490
+ are dropped. Note: Duration is read from the manifest JSON.
491
+ Defaults to 0.1.
492
+ max_duration (float): Dataset parameter.
493
+ All training files which have a duration more than max_duration
494
+ are dropped. Note: Duration is read from the manifest JSON.
495
+ Defaults to None.
496
+ trim (bool): Whether to trim silence from the beginning and end
497
+ of audio signal using librosa.effects.trim().
498
+ Defaults to False.
499
+ window_length_in_sec (float): length of slice/window (in seconds) # Pass this only for speaker recognition and VAD task
500
+ shift_length_in_sec (float): amount of shift of the window for generating frames for the VAD task in a batch. # Pass this only for VAD task during inference.
501
+ normalize_audio (bool): Whether to normalize audio signal. Defaults to False.
502
+ shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
503
+ - `scatter`: The default shard strategy applied by WebDataset, where each node gets
504
+ a unique set of shards, which are permanently pre-allocated and never changed at runtime.
505
+ - `replicate`: Optional shard strategy, where each node gets all of the set of shards
506
+ available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
507
+ The benefit of replication is that it allows each node to sample data points from the entire
508
+ dataset independently of other nodes, and reduces dependence on the value of `shuffle_n`.
509
+
510
+ .. warning::
511
+ Replicated strategy allows every node to sample the entire set of available tarfiles,
512
+ and therefore more than one node may sample the same tarfile, and even sample the same
513
+ data points! As such, there is no assured guarantee that all samples in the dataset will be
514
+ sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
515
+ occasions (when the number of shards is not divisible with ``world_size``), will not sample
516
+ the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
517
+ or test datasets.
518
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
519
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
520
+ is_regression_task (bool): Whether it is a regression task. Defaults to False.
521
+ """
522
+
523
+ def __init__(
524
+ self,
525
+ *,
526
+ audio_tar_filepaths: Union[str, List[str]],
527
+ manifest_filepath: Union[str, List[str]],
528
+ labels: List[str],
529
+ featurizer,
530
+ shuffle_n: int = 0,
531
+ min_duration: Optional[float] = 0.1,
532
+ max_duration: Optional[float] = None,
533
+ trim: bool = False,
534
+ shard_strategy: str = "scatter",
535
+ global_rank: int = 0,
536
+ world_size: int = 0,
537
+ is_regression_task: bool = False,
538
+ ):
539
+ cache_datastore_manifests(manifest_filepaths=manifest_filepath)
540
+ self.collection = collections.ASRSpeechLabel(
541
+ manifests_files=manifest_filepath,
542
+ min_duration=min_duration,
543
+ max_duration=max_duration,
544
+ index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID
545
+ )
546
+
547
+ self.file_occurence = count_occurence(self.collection.mapping)
548
+
549
+ self.featurizer = featurizer
550
+ self.trim = trim
551
+
552
+ self.labels = labels if labels else self.collection.uniq_labels
553
+ self.num_classes = len(self.labels)
554
+
555
+ self.label2id, self.id2label = {}, {}
556
+ for label_id, label in enumerate(self.labels):
557
+ self.label2id[label] = label_id
558
+ self.id2label[label_id] = label
559
+
560
+ for idx in range(len(self.labels[:5])):
561
+ logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
562
+
563
+ audio_tar_filepaths = expand_sharded_filepaths(
564
+ sharded_filepaths=audio_tar_filepaths,
565
+ shard_strategy=shard_strategy,
566
+ world_size=world_size,
567
+ global_rank=global_rank,
568
+ )
569
+ # Put together WebDataset
570
+ self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
571
+
572
+ if shuffle_n > 0:
573
+ self._dataset = self._dataset.shuffle(shuffle_n)
574
+ else:
575
+ logging.info("WebDataset will not shuffle files within the tar files.")
576
+
577
+ self._dataset = (
578
+ self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
579
+ .to_tuple('audio', 'key')
580
+ .pipe(self._filter)
581
+ .map(f=self._build_sample)
582
+ )
583
+
584
+ def _filter(self, iterator):
585
+ """This function is used to remove samples that have been filtered out by ASRSpeechLabel already.
586
+ Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample
587
+ that was filtered out (e.g. for duration).
588
+ Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard,
589
+ which may make your code hang as one process will finish before the other.
590
+ """
591
+
592
+ class TarredAudioFilter:
593
+ def __init__(self, collection, file_occurence):
594
+ self.iterator = iterator
595
+ self.collection = collection
596
+ self.file_occurence = file_occurence
597
+ self._iterable = self._internal_generator()
598
+
599
+ def __iter__(self):
600
+ self._iterable = self._internal_generator()
601
+ return self
602
+
603
+ def __next__(self):
604
+ try:
605
+ values = next(self._iterable)
606
+ except StopIteration:
607
+ # reset generator
608
+ self._iterable = self._internal_generator()
609
+ values = next(self._iterable)
610
+
611
+ return values
612
+
613
+ def _internal_generator(self):
614
+ """
615
+ WebDataset requires an Iterator, but we require an iterable that yields 1-or-more
616
+ values per value inside self.iterator.
617
+
618
+ Therefore wrap the iterator with a generator function that will yield 1-or-more
619
+ values per sample in the iterator.
620
+ """
621
+ for _, tup in enumerate(self.iterator):
622
+ audio_bytes, audio_filename = tup
623
+
624
+ file_id, _ = os.path.splitext(os.path.basename(audio_filename))
625
+ if audio_filename in self.file_occurence:
626
+ for j in range(0, self.file_occurence[file_id]):
627
+ if j == 0:
628
+ audio_filename = file_id
629
+ else:
630
+ audio_filename = file_id + "-sub" + str(j)
631
+ yield audio_bytes, audio_filename
632
+
633
+ return TarredAudioFilter(self.collection, self.file_occurence)
634
+
635
+ def _build_sample(self, tup):
636
+ """Builds the training sample by combining the data from the WebDataset with the manifest info.
637
+ """
638
+ audio_bytes, audio_filename = tup
639
+ # Grab manifest entry from self.collection
640
+ file_id, _ = os.path.splitext(os.path.basename(audio_filename))
641
+
642
+ manifest_idx = self.collection.mapping[file_id]
643
+ manifest_entry = self.collection[manifest_idx]
644
+
645
+ offset = manifest_entry.offset
646
+ if offset is None:
647
+ offset = 0
648
+
649
+ # Convert audio bytes to IO stream for processing (for SoundFile to read)
650
+ audio_filestream = io.BytesIO(audio_bytes)
651
+ features = self.featurizer.process(
652
+ audio_filestream, offset=offset, duration=manifest_entry.duration, trim=self.trim,
653
+ )
654
+
655
+ audio_filestream.close()
656
+
657
+ # Audio features
658
+ f, fl = features, torch.tensor(features.shape[0]).long()
659
+
660
+ t = self.label2id[manifest_entry.label]
661
+ tl = 1 # For compatibility with collate_fn used later
662
+
663
+ return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
664
+
665
+ def __iter__(self):
666
+ return self._dataset.__iter__()
667
+
668
+ def __len__(self):
669
+ return len(self.collection)
670
+
671
+
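As the class docstring above notes, `{x..y}` brace expansion can clash with shell escaping (for example inside SLURM scripts), so the `_OP_`/`_CL_` tags can be used instead. A hypothetical illustration:

# Two equivalent ways of pointing audio_tar_filepaths at 100 shards:
audio_tar_filepaths = "path/to/audio_{1..100}.tar"        # plain brace expansion
audio_tar_filepaths = "path/to/audio__OP_1..100_CL_.tar"  # shell-safe special tags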
672
+ class TarredAudioToClassificationLabelDataset(_TarredAudioLabelDataset):
673
+ """
674
+ A similar Dataset to the AudioToClassificationLabelDataset, but which loads tarred audio files.
675
+
676
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToClassificationLabelDataset),
677
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
678
+ contain the information for one audio file, including at least the transcript and name of the audio
679
+ file within the tarball.
680
+
681
+ Valid formats for the audio_tar_filepaths argument include:
682
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
683
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
684
+
685
+ See the WebDataset documentation for more information about accepted data and input formats.
686
+
687
+ If using multiple processes the number of shards should be divisible by the number of workers to ensure an
688
+ even split among workers. If it is not divisible, logging will give a warning but training will proceed.
689
+ In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
690
+ is applied. We currently do not check for this, but your program may hang if the shards are uneven!
691
+
692
+ Notice that a few arguments are different from the AudioToClassificationLabelDataset; for example, shuffle (bool) has been
693
+ replaced by shuffle_n (int).
694
+
695
+ Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
696
+ after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
697
+
698
+ Args:
699
+ audio_tar_filepaths: Either a list of audio tarball filepaths, or a
700
+ string (can be brace-expandable).
701
+ manifest_filepath (str): Path to the manifest.
702
+ labels (list): Dataset parameter.
703
+ List of target classes that can be output by the speaker recognition model.
704
+ featurizer
705
+ shuffle_n (int): How many samples to look ahead and load to be shuffled.
706
+ See WebDataset documentation for more details.
707
+ Defaults to 0.
708
+ min_duration (float): Dataset parameter.
709
+ All training files which have a duration less than min_duration
710
+ are dropped. Note: Duration is read from the manifest JSON.
711
+ Defaults to 0.1.
712
+ max_duration (float): Dataset parameter.
713
+ All training files which have a duration more than max_duration
714
+ are dropped. Note: Duration is read from the manifest JSON.
715
+ Defaults to None.
716
+ trim (bool): Whether to trim silence from the beginning and end
717
+ of audio signal using librosa.effects.trim().
718
+ Defaults to False.
719
+ shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
720
+ - `scatter`: The default shard strategy applied by WebDataset, where each node gets
721
+ a unique set of shards, which are permanently pre-allocated and never changed at runtime.
722
+ - `replicate`: Optional shard strategy, where each node gets all of the set of shards
723
+ available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
724
+ The benefit of replication is that it allows each node to sample data points from the entire
725
+ dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
726
+
727
+ .. warning::
728
+ Replicated strategy allows every node to sample the entire set of available tarfiles,
729
+ and therefore more than one node may sample the same tarfile, and even sample the same
730
+ data points! As such, there is no assured guarantee that all samples in the dataset will be
731
+ sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
732
+ occasions (when the number of shards is not divisible with ``world_size``), will not sample
733
+ the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
734
+ or test datasets.
735
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
736
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
737
+ is_regression_task (bool): Whether it is a regression task. Defaults to False.
738
+ """
739
+
740
+ def _collate_fn(self, batch):
741
+ return _speech_collate_fn(batch, pad_id=0)
742
+
743
+
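As a rough usage sketch (the tar pattern, manifest path, label set, and 16 kHz featurizer below are placeholder assumptions, not values from this repository), the tarred classification dataset is typically paired with its own `_collate_fn`:

```python
# Hypothetical example: build a DataLoader around TarredAudioToClassificationLabelDataset.
from torch.utils.data import DataLoader

from nemo.collections.asr.data.audio_to_label import TarredAudioToClassificationLabelDataset
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer

featurizer = WaveformFeaturizer(sample_rate=16000)  # placeholder sample rate

dataset = TarredAudioToClassificationLabelDataset(
    audio_tar_filepaths='path/to/audio_{1..8}.tar',  # brace-expandable shard pattern
    manifest_filepath='path/to/manifest.json',       # placeholder manifest path
    labels=['speech', 'noise'],                      # placeholder label set
    featurizer=featurizer,
    shuffle_n=128,
    global_rank=0,
    world_size=1,
)

# The dataset is iterable, so no sampler is used; its collate_fn pads labels with 0.
loader = DataLoader(dataset, batch_size=16, collate_fn=dataset._collate_fn)
```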
744
+ class TarredAudioToSpeechLabelDataset(_TarredAudioLabelDataset):
745
+ """
746
+ A similar Dataset to the AudioToSpeechLabelDataset, but which loads tarred audio files.
747
+
748
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToSpeechLabelDataset),
749
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
750
+ contain the information for one audio file, including at least the transcript and name of the audio
751
+ file within the tarball.
752
+
753
+ Valid formats for the audio_tar_filepaths argument include:
754
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
755
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
756
+
757
+ See the WebDataset documentation for more information about accepted data and input formats.
758
+
759
+ If using multiple processes, the number of shards should be divisible by the number of workers to ensure an
760
+ even split among workers. If it is not divisible, logging will give a warning but training will proceed.
761
+ In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
762
+ is applied. We currently do not check for this, but your program may hang if the shards are uneven!
763
+
764
+ Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been
765
+ replaced by shuffle_n (int).
766
+
767
+ Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
768
+ after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
769
+
770
+ Args:
771
+ audio_tar_filepaths: Either a list of audio tarball filepaths, or a
772
+ string (can be brace-expandable).
773
+ manifest_filepath (str): Path to the manifest.
774
+ labels (list): Dataset parameter.
775
+ List of target classes that can be output by the speaker recognition model.
776
+ featurizer
777
+ shuffle_n (int): How many samples to look ahead and load to be shuffled.
778
+ See WebDataset documentation for more details.
779
+ Defaults to 0.
780
+ min_duration (float): Dataset parameter.
781
+ All training files which have a duration less than min_duration
782
+ are dropped. Note: Duration is read from the manifest JSON.
783
+ Defaults to 0.1.
784
+ max_duration (float): Dataset parameter.
785
+ All training files which have a duration more than max_duration
786
+ are dropped. Note: Duration is read from the manifest JSON.
787
+ Defaults to None.
788
+ trim (bool): Whether to trim silence from the beginning and end
789
+ of audio signal using librosa.effects.trim().
790
+ Defaults to False.
791
+ window_length_in_sec (float): Length of the window/slice in seconds. Pass this only for the speaker recognition and VAD tasks.
792
+ shift_length_in_sec (float): Amount of shift of the window used to generate frames in a batch for the VAD task. Pass this only for the VAD task during inference.
793
+ normalize_audio (bool): Whether to normalize audio signal. Defaults to False.
794
+ shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
795
+ - `scatter`: The default shard strategy applied by WebDataset, where each node gets
796
+ a unique set of shards, which are permanently pre-allocated and never changed at runtime.
797
+ - `replicate`: Optional shard strategy, where each node gets all of the set of shards
798
+ available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
799
+ The benefit of replication is that it allows each node to sample data points from the entire
800
+ dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
801
+
802
+ .. warning::
803
+ Replicated strategy allows every node to sample the entire set of available tarfiles,
804
+ and therefore more than one node may sample the same tarfile, and even sample the same
805
+ data points! As such, there is no assured guarantee that all samples in the dataset will be
806
+ sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
807
+ occasions (when the number of shards is not divisible with ``world_size``), will not sample
808
+ the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
809
+ or test datasets.
810
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
811
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
812
+ """
813
+
814
+ def __init__(
815
+ self,
816
+ *,
817
+ audio_tar_filepaths: Union[str, List[str]],
818
+ manifest_filepath: Union[str, List[str]],
819
+ labels: List[str],
820
+ featurizer,
821
+ shuffle_n: int = 0,
822
+ min_duration: Optional[float] = 0.1,
823
+ max_duration: Optional[float] = None,
824
+ trim: bool = False,
825
+ window_length_in_sec: Optional[float] = 8,
826
+ shift_length_in_sec: Optional[float] = 1,
827
+ normalize_audio: bool = False,
828
+ shard_strategy: str = "scatter",
829
+ global_rank: int = 0,
830
+ world_size: int = 0,
831
+ ):
832
+ logging.info("Window/slice length considered for collate func is {}".format(window_length_in_sec))
833
+ logging.info("Shift length considered for collate func is {}".format(shift_length_in_sec))
834
+ self.window_length_in_sec = window_length_in_sec
835
+ self.shift_length_in_sec = shift_length_in_sec
836
+ self.normalize_audio = normalize_audio
837
+
838
+ super().__init__(
839
+ audio_tar_filepaths=audio_tar_filepaths,
840
+ manifest_filepath=manifest_filepath,
841
+ labels=labels,
842
+ featurizer=featurizer,
843
+ shuffle_n=shuffle_n,
844
+ min_duration=min_duration,
845
+ max_duration=max_duration,
846
+ trim=trim,
847
+ shard_strategy=shard_strategy,
848
+ global_rank=global_rank,
849
+ world_size=world_size,
850
+ )
851
+
852
+ def fixed_seq_collate_fn(self, batch):
853
+ return _fixed_seq_collate_fn(self, batch)
854
+
855
+ def sliced_seq_collate_fn(self, batch):
856
+ raise NotImplementedError
857
+
858
+ def vad_frame_seq_collate_fn(self, batch):
859
+ return _vad_frame_seq_collate_fn(self, batch)
860
+
861
+
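A hedged sketch of how the two provided collate functions might be wired up; the tar list, manifest, speaker labels, and window sizes below are illustrative assumptions:

```python
# Hypothetical example: same dataset, different collate functions per task.
from torch.utils.data import DataLoader

from nemo.collections.asr.data.audio_to_label import TarredAudioToSpeechLabelDataset
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer

dataset = TarredAudioToSpeechLabelDataset(
    audio_tar_filepaths=['audio_1.tar', 'audio_2.tar'],  # placeholder shard list
    manifest_filepath='speaker_manifest.json',           # placeholder manifest
    labels=['spk_0', 'spk_1'],                           # placeholder speaker labels
    featurizer=WaveformFeaturizer(sample_rate=16000),
    window_length_in_sec=3.0,   # slice length used by the collate functions
    shift_length_in_sec=1.5,
)

# Speaker recognition: fixed-length slices per utterance.
train_loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.fixed_seq_collate_fn)

# Frame-level VAD inference uses the frame-based collate function instead.
vad_loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.vad_frame_seq_collate_fn)
```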
862
+ class AudioToMultiLabelDataset(Dataset):
863
+ """
864
+ Dataset that loads a json file containing paths to audio files, durations (in seconds), and a sequence of labels.
865
+ Each new line is a different sample. Example below:
866
+ {"audio_filepath": "/path/to/audio_wav_0.wav", "duration": time_in_sec_0, "label": \
867
+ "0 1 1 0 1", "offset": offset_in_sec_0}
868
+ ...
869
+ {"audio_filepath": "/path/to/audio_wav_n.wav", "duration": time_in_sec_n, "label": \
870
+ "0 1 0 0 1", "offset": offset_in_sec_n}
871
+ Args:
872
+ manifest_filepath (Union[str, List[str]]): Path to manifest json as described above. Can
873
+ be comma-separated paths.
874
+ labels (Optional[list]): List containing all the possible labels to map to;
875
+ if None then automatically picks from ASRSpeechLabel collection.
876
+ min_duration (float): Dataset parameter.
877
+ All training files which have a duration less than min_duration
878
+ are dropped. Note: Duration is read from the manifest JSON.
879
+ Defaults to 0.1.
880
+ max_duration (float): Dataset parameter.
881
+ All training files which have a duration more than max_duration
882
+ are dropped. Note: Duration is read from the manifest JSON.
883
+ Defaults to None.
884
+ trim (bool): Whether to trim silence from the beginning and end
885
+ of audio signal using librosa.effects.trim().
886
+ Defaults to False.
887
+ window_length_in_sec (float): length of window/slice (in seconds)
888
+ Use this for speaker recognition and VAD tasks.
889
+ shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch
890
+ Use this for VAD task during inference.
891
+ normalize_audio (bool): Whether to normalize audio signal.
892
+ Defaults to False.
893
+ is_regression_task (bool): Whether the dataset is for a regression task instead of classification.
894
+ Defaults to False.
895
+ cal_labels_occurrence (bool): Whether to calculate occurrence of labels
896
+ Defaults to False.
897
+ delimiter (Optional[str]): Delimiter to use when splitting the label string, default to None.
898
+ normalize_audio_db (Optional[float]): normalize audio signal to a target db, default to None.
899
+ """
900
+
901
+ @property
902
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
903
+ """Returns definitions of module output ports.
904
+ """
905
+
906
+ output_types = {
907
+ 'audio_signal': NeuralType(
908
+ ('B', 'T'),
909
+ AudioSignal(freq=self._sample_rate)
910
+ if self is not None and hasattr(self, '_sample_rate')
911
+ else AudioSignal(),
912
+ ),
913
+ 'a_sig_length': NeuralType(tuple('B'), LengthsType()),
914
+ }
915
+
916
+ if self.is_regression_task:
917
+ output_types.update(
918
+ {
919
+ 'targets': NeuralType(tuple('B, T'), RegressionValuesType()),
920
+ 'targets_length': NeuralType(tuple('B'), LengthsType()),
921
+ }
922
+ )
923
+ else:
924
+ output_types.update(
925
+ {'label': NeuralType(('B', 'T'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),}
926
+ )
927
+
928
+ return output_types
929
+
930
+ def __init__(
931
+ self,
932
+ *,
933
+ manifest_filepath: Union[str, List[str]],
934
+ sample_rate: int,
935
+ labels: Optional[List[str]] = None,
936
+ int_values: bool = False,
937
+ augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
938
+ min_duration: Optional[float] = 0.1,
939
+ max_duration: Optional[float] = None,
940
+ trim_silence: bool = False,
941
+ is_regression_task: bool = False,
942
+ cal_labels_occurrence: Optional[bool] = False,
943
+ delimiter: Optional[str] = None,
944
+ normalize_audio_db: Optional[float] = None,
945
+ ):
946
+ super().__init__()
947
+ if isinstance(manifest_filepath, str):
948
+ manifest_filepath = manifest_filepath.split(',')
949
+
950
+ self.delimiter = delimiter
951
+ self.normalize_audio_db = normalize_audio_db
952
+
953
+ self.collection = collections.ASRSpeechLabel(
954
+ manifests_files=manifest_filepath,
955
+ min_duration=min_duration,
956
+ max_duration=max_duration,
957
+ is_regression_task=is_regression_task,
958
+ cal_labels_occurrence=cal_labels_occurrence,
959
+ delimiter=delimiter,
960
+ )
961
+
962
+ self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
963
+ self.trim = trim_silence
964
+ self.is_regression_task = is_regression_task
965
+ self.id2occurrence = {}
966
+ self.labels_occurrence = [] if cal_labels_occurrence else None
967
+
968
+ if not is_regression_task:
969
+ self.labels = labels if labels else self._get_label_set()
970
+ self.num_classes = len(self.labels) if self.labels is not None else 1
971
+ self.label2id, self.id2label = {}, {}
972
+ for label_id, label in enumerate(self.labels):
973
+ self.label2id[label] = label_id
974
+ self.id2label[label_id] = label
975
+ if cal_labels_occurrence:
976
+ self.id2occurrence[label_id] = self.collection.labels_occurrence[label]
977
+ self.labels_occurrence.append(self.id2occurrence[label_id])
978
+
979
+ for idx in range(len(self.labels[:5])):
980
+ logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
981
+ else:
982
+ self.labels = []
983
+ self.num_classes = 1
984
+
985
+ def _get_label_set(self):
986
+ labels = []
987
+ for sample in self.collection:
988
+ label_str = sample.label
989
+ if label_str:
990
+ label_str_list = label_str.split(self.delimiter) if self.delimiter else label_str.split()
991
+ labels.extend(label_str_list)
992
+ return sorted(set(labels))
993
+
994
+ def _label_str_to_tensor(self, label_str: str):
995
+ labels = label_str.split(self.delimiter) if self.delimiter else label_str.split()
996
+
997
+ if self.is_regression_task:
998
+ labels = [float(s) for s in labels]
999
+ labels = torch.tensor(labels).float()
1000
+ else:
1001
+ labels = [self.label2id[s] for s in labels]
1002
+ labels = torch.tensor(labels).long()
1003
+ return labels
1004
+
1005
+ def __len__(self):
1006
+ return len(self.collection)
1007
+
1008
+ def __getitem__(self, index):
1009
+ sample = self.collection[index]
1010
+
1011
+ offset = sample.offset
1012
+
1013
+ if offset is None:
1014
+ offset = 0
1015
+
1016
+ features = self.featurizer.process(
1017
+ sample.audio_file,
1018
+ offset=offset,
1019
+ duration=sample.duration,
1020
+ trim=self.trim,
1021
+ normalize_db=self.normalize_audio_db,
1022
+ )
1023
+
1024
+ f, fl = features, torch.tensor(features.size(0)).long()
1025
+
1026
+ t = self._label_str_to_tensor(sample.label)
1027
+
1028
+ tl = torch.tensor(t.size(0)).long()
1029
+
1030
+ return f, fl, t, tl
1031
+
1032
+ def _collate_fn(self, batch):
1033
+ return _speech_collate_fn(batch, pad_id=0)
1034
+
1035
+
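To make the manifest format above concrete, here is a minimal sketch that writes two multi-label entries and points the dataset at them; the file names, durations, and label vocabulary are invented for illustration:

```python
# Hypothetical example: write a tiny multi-label manifest and load it.
import json

from nemo.collections.asr.data.audio_to_label import AudioToMultiLabelDataset

entries = [
    {"audio_filepath": "clips/a.wav", "duration": 2.5, "label": "0 1 1 0 1", "offset": 0.0},
    {"audio_filepath": "clips/b.wav", "duration": 3.1, "label": "0 1 0 0 1", "offset": 0.0},
]
with open("multilabel_manifest.json", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")

dataset = AudioToMultiLabelDataset(
    manifest_filepath="multilabel_manifest.json",
    sample_rate=16000,   # placeholder sample rate
    labels=["0", "1"],
)
# dataset[0] would decode clips/a.wav (if it existed) and return
# (audio, audio_len, targets, targets_len), with targets = tensor([0, 1, 1, 0, 1]).
```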
1036
+ class TarredAudioToMultiLabelDataset(IterableDataset):
1037
+ """
1038
+ A similar Dataset to the AudioToMultiLabelDataset, but which loads tarred audio files.
1039
+
1040
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToSpeechLabelDataset),
1041
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
1042
+ contain the information for one audio file, including at least the transcript and name of the audio
1043
+ file within the tarball.
1044
+
1045
+ Valid formats for the audio_tar_filepaths argument include:
1046
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
1047
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
1048
+
1049
+ See the WebDataset documentation for more information about accepted data and input formats.
1050
+
1051
+ If using multiple processes, the number of shards should be divisible by the number of workers to ensure an
1052
+ even split among workers. If it is not divisible, logging will give a warning but training will proceed.
1053
+ In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
1054
+ is applied. We currently do not check for this, but your program may hang if the shards are uneven!
1055
+
1056
+ Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been
1057
+ replaced by shuffle_n (int).
1058
+
1059
+ Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
1060
+ after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
1061
+
1062
+ Args:
1063
+ audio_tar_filepaths: Either a list of audio tarball filepaths, or a
1064
+ string (can be brace-expandable).
1065
+ manifest_filepath (str): Path to the manifest.
1066
+ labels (list): Dataset parameter.
1067
+ List of target classes that can be output by the speaker recognition model.
1068
+ shuffle_n (int): How many samples to look ahead and load to be shuffled.
1069
+ See WebDataset documentation for more details.
1070
+ Defaults to 0.
1071
+ min_duration (float): Dataset parameter.
1072
+ All training files which have a duration less than min_duration
1073
+ are dropped. Note: Duration is read from the manifest JSON.
1074
+ Defaults to 0.1.
1075
+ max_duration (float): Dataset parameter.
1076
+ All training files which have a duration more than max_duration
1077
+ are dropped. Note: Duration is read from the manifest JSON.
1078
+ Defaults to None.
1079
+ trim (bool): Whether to trim silence from the beginning and end
1080
+ of audio signal using librosa.effects.trim().
1081
+ Defaults to False.
1082
+ window_length_in_sec (float): Length of the window/slice in seconds. Pass this only for the speaker recognition and VAD tasks.
1083
+ shift_length_in_sec (float): Amount of shift of the window used to generate frames in a batch for the VAD task. Pass this only for the VAD task during inference.
1084
+ normalize_audio (bool): Whether to normalize audio signal. Defaults to False.
1085
+ shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
1086
+ - `scatter`: The default shard strategy applied by WebDataset, where each node gets
1087
+ a unique set of shards, which are permanently pre-allocated and never changed at runtime.
1088
+ - `replicate`: Optional shard strategy, where each node gets all of the set of shards
1089
+ available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
1090
+ The benefit of replication is that it allows each node to sample data points from the entire
1091
+ dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
1092
+
1093
+ .. warning::
1094
+ Replicated strategy allows every node to sample the entire set of available tarfiles,
1095
+ and therefore more than one node may sample the same tarfile, and even sample the same
1096
+ data points! As such, there is no assured guarantee that all samples in the dataset will be
1097
+ sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
1098
+ occasions (when the number of shards is not divisible with ``world_size``), will not sample
1099
+ the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
1100
+ or test datasets.
1101
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
1102
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
1103
+ delimiter (Optional[str]): Delimiter to use when splitting the label string, default to None.
1104
+ normalize_audio_db (Optional[float]): normalize audio signal to a target db, default to None.
1105
+ """
1106
+
1107
+ def __init__(
1108
+ self,
1109
+ *,
1110
+ audio_tar_filepaths: Union[str, List[str]],
1111
+ manifest_filepath: Union[str, List[str]],
1112
+ sample_rate: int,
1113
+ labels: Optional[List[str]] = None,
1114
+ shuffle_n: int = 0,
1115
+ int_values: bool = False,
1116
+ augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
1117
+ min_duration: Optional[float] = 0.1,
1118
+ max_duration: Optional[float] = None,
1119
+ trim_silence: bool = False,
1120
+ is_regression_task: bool = False,
1121
+ shard_strategy: str = "scatter",
1122
+ global_rank: int = 0,
1123
+ world_size: int = 0,
1124
+ delimiter: Optional[str] = None,
1125
+ normalize_audio_db: Optional[float] = None,
1126
+ ):
1127
+ super().__init__()
1128
+ if isinstance(manifest_filepath, str):
1129
+ manifest_filepath = manifest_filepath.split(',')
1130
+
1131
+ self.trim = trim_silence
1132
+ self.is_regression_task = is_regression_task
1133
+ self.delimiter = delimiter
1134
+ self.normalize_audio_db = normalize_audio_db
1135
+
1136
+ self.collection = collections.ASRSpeechLabel(
1137
+ manifests_files=manifest_filepath,
1138
+ min_duration=min_duration,
1139
+ max_duration=max_duration,
1140
+ is_regression_task=is_regression_task,
1141
+ index_by_file_id=True,
1142
+ )
1143
+ self.file_occurence = count_occurence(self.collection.mapping)
1144
+
1145
+ self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
1146
+
1147
+ if not is_regression_task:
1148
+ self.labels = labels if labels else self._get_label_set()
1149
+ self.num_classes = len(self.labels) if self.labels is not None else 1
1150
+ self.label2id, self.id2label = {}, {}
1151
+ for label_id, label in enumerate(self.labels):
1152
+ self.label2id[label] = label_id
1153
+ self.id2label[label_id] = label
1154
+ for idx in range(len(self.labels[:5])):
1155
+ logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
1156
+ else:
1157
+ self.labels = []
1158
+ self.num_classes = 1
1159
+
1160
+ audio_tar_filepaths = expand_sharded_filepaths(
1161
+ sharded_filepaths=audio_tar_filepaths,
1162
+ shard_strategy=shard_strategy,
1163
+ world_size=world_size,
1164
+ global_rank=global_rank,
1165
+ )
1166
+ # Put together WebDataset
1167
+ self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
1168
+
1169
+ if shuffle_n > 0:
1170
+ self._dataset = self._dataset.shuffle(shuffle_n)
1171
+ else:
1172
+ logging.info("WebDataset will not shuffle files within the tar files.")
1173
+
1174
+ self._dataset = (
1175
+ self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
1176
+ .to_tuple('audio', 'key')
1177
+ .pipe(self._filter)
1178
+ .map(f=self._build_sample)
1179
+ )
1180
+
1181
+ def _get_label_set(self):
1182
+ labels = []
1183
+ for sample in self.collection:
1184
+ label_str = sample.label
1185
+ if label_str:
1186
+ label_str_list = label_str.split(self.delimiter) if self.delimiter else label_str.split()
1187
+ labels.extend(label_str_list)
1188
+ return sorted(set(labels))
1189
+
1190
+ def _label_str_to_tensor(self, label_str: str):
1191
+ labels = label_str.split(self.delimiter) if self.delimiter else label_str.split()
1192
+
1193
+ if self.is_regression_task:
1194
+ labels = [float(s) for s in labels]
1195
+ labels = torch.tensor(labels).float()
1196
+ else:
1197
+ labels = [self.label2id[s] for s in labels]
1198
+ labels = torch.tensor(labels).long()
1199
+ return labels
1200
+
1201
+ def _filter(self, iterator):
1202
+ """This function is used to remove samples that have been filtered out by ASRSpeechLabel already.
1203
+ Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample
1204
+ that was filtered out (e.g. for duration).
1205
+ Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard,
1206
+ which may make your code hang as one process will finish before the other.
1207
+ """
1208
+
1209
+ class TarredAudioFilter:
1210
+ def __init__(self, collection, file_occurence):
1211
+ self.iterator = iterator
1212
+ self.collection = collection
1213
+ self.file_occurence = file_occurence
1214
+ self._iterable = self._internal_generator()
1215
+
1216
+ def __iter__(self):
1217
+ self._iterable = self._internal_generator()
1218
+ return self
1219
+
1220
+ def __next__(self):
1221
+ try:
1222
+ values = next(self._iterable)
1223
+ except StopIteration:
1224
+ # reset generator
1225
+ self._iterable = self._internal_generator()
1226
+ values = next(self._iterable)
1227
+
1228
+ return values
1229
+
1230
+ def _internal_generator(self):
1231
+ """
1232
+ WebDataset requires an Iterator, but we require an iterable that yields 1-or-more
1233
+ values per value inside self.iterator.
1234
+
1235
+ Therefore wrap the iterator with a generator function that will yield 1-or-more
1236
+ values per sample in the iterator.
1237
+ """
1238
+ for _, tup in enumerate(self.iterator):
1239
+ audio_bytes, audio_filename = tup
1240
+
1241
+ file_id, _ = os.path.splitext(os.path.basename(audio_filename))
1242
+ if audio_filename in self.file_occurence:
1243
+ for j in range(0, self.file_occurence[file_id]):
1244
+ if j == 0:
1245
+ audio_filename = file_id
1246
+ else:
1247
+ audio_filename = file_id + "-sub" + str(j)
1248
+ yield audio_bytes, audio_filename
1249
+
1250
+ return TarredAudioFilter(self.collection, self.file_occurence)
1251
+
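The generator above expands each repeated file into distinct keys; a stripped-down illustration with a made-up occurrence map and a hypothetical helper name:

```python
# Standalone sketch of the '-sub' key expansion used by the filter above.
file_occurence = {'abc': 2, 'xyz': 1}  # hypothetical file_id -> number of manifest entries

def expanded_keys(file_id):
    return [file_id if j == 0 else f"{file_id}-sub{j}" for j in range(file_occurence[file_id])]

print(expanded_keys('abc'))  # ['abc', 'abc-sub1'] -> two manifest entries for one audio file
print(expanded_keys('xyz'))  # ['xyz']
```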
1252
+ def _build_sample(self, tup):
1253
+ """Builds the training sample by combining the data from the WebDataset with the manifest info.
1254
+ """
1255
+ audio_bytes, audio_filename = tup
1256
+ # Grab manifest entry from self.collection
1257
+ file_id, _ = os.path.splitext(os.path.basename(audio_filename))
1258
+
1259
+ manifest_idx = self.collection.mapping[file_id]
1260
+ manifest_entry = self.collection[manifest_idx]
1261
+
1262
+ offset = manifest_entry.offset
1263
+ if offset is None:
1264
+ offset = 0
1265
+
1266
+ # Convert audio bytes to IO stream for processing (for SoundFile to read)
1267
+ audio_filestream = io.BytesIO(audio_bytes)
1268
+ features = self.featurizer.process(
1269
+ audio_filestream,
1270
+ offset=offset,
1271
+ duration=manifest_entry.duration,
1272
+ trim=self.trim,
1273
+ normalize_db=self.normalize_audio_db,
1274
+ )
1275
+
1276
+ audio_filestream.close()
1277
+
1278
+ # Audio features
1279
+ f, fl = features, torch.tensor(features.shape[0]).long()
1280
+
1281
+ t = self._label_str_to_tensor(manifest_entry.label)
1282
+
1283
+ tl = torch.tensor(t.size(0)).long()
1284
+
1285
+ return f, fl, t, tl
1286
+
1287
+ def __iter__(self):
1288
+ return self._dataset.__iter__()
1289
+
1290
+ def __len__(self):
1291
+ return len(self.collection)
1292
+
1293
+ def _collate_fn(self, batch):
1294
+ return _speech_collate_fn(batch, pad_id=0)
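For reference, the label-string handling in `_label_str_to_tensor` reduces to the following standalone sketch; the two-entry label vocabulary is assumed for illustration:

```python
# Simplified mirror of _label_str_to_tensor for classification vs. regression labels.
import torch

label2id = {"0": 0, "1": 1}  # assumed label vocabulary

def label_str_to_tensor(label_str, is_regression_task=False, delimiter=None):
    tokens = label_str.split(delimiter) if delimiter else label_str.split()
    if is_regression_task:
        return torch.tensor([float(s) for s in tokens]).float()
    return torch.tensor([label2id[s] for s in tokens]).long()

print(label_str_to_tensor("0 1 1 0 1"))                         # tensor([0, 1, 1, 0, 1])
print(label_str_to_tensor("0.2 0.8", is_regression_task=True))  # tensor([0.2000, 0.8000])
```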
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_label_dataset.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import copy
15
+
16
+ from omegaconf import DictConfig
17
+
18
+ from nemo.collections.asr.data import audio_to_label
19
+ from nemo.collections.asr.data.audio_to_text_dataset import convert_to_config_list, get_chain_dataset
20
+ from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
21
+ from nemo.collections.common.data.dataset import ConcatDataset
22
+
23
+
24
+ def get_classification_label_dataset(featurizer, config: dict) -> audio_to_label.AudioToClassificationLabelDataset:
25
+ """
26
+ Instantiates a Classification AudioLabelDataset.
27
+
28
+ Args:
29
+ config: Config of the AudioToClassificationLabelDataset.
30
+
31
+ Returns:
32
+ An instance of AudioToClassificationLabelDataset.
33
+ """
34
+ dataset = audio_to_label.AudioToClassificationLabelDataset(
35
+ manifest_filepath=config['manifest_filepath'],
36
+ labels=config['labels'],
37
+ featurizer=featurizer,
38
+ max_duration=config.get('max_duration', None),
39
+ min_duration=config.get('min_duration', None),
40
+ trim=config.get('trim_silence', False),
41
+ is_regression_task=config.get('is_regression_task', False),
42
+ cal_labels_occurrence=config.get('cal_labels_occurrence', False),
43
+ )
44
+ return dataset
45
+
46
+
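A minimal sketch of the config dict this factory expects; the manifest path and label set are placeholders, and only keys that the function reads above are shown:

```python
# Hypothetical config for get_classification_label_dataset.
from nemo.collections.asr.data.audio_to_label_dataset import get_classification_label_dataset
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer

config = {
    'manifest_filepath': 'train_manifest.json',  # placeholder path
    'labels': ['yes', 'no'],                     # placeholder label set
    'max_duration': 20.0,
    'min_duration': 0.1,
    'trim_silence': False,
}
dataset = get_classification_label_dataset(
    featurizer=WaveformFeaturizer(sample_rate=16000), config=config
)  # assumes train_manifest.json exists on disk
```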
47
+ def get_speech_label_dataset(featurizer, config: dict) -> audio_to_label.AudioToSpeechLabelDataset:
48
+ """
49
+ Instantiates a Speech Label (e.g. VAD, speaker recognition) AudioLabelDataset.
50
+
51
+ Args:
52
+ config: Config of the AudioToSpeechLabelDataSet.
53
+
54
+ Returns:
55
+ An instance of AudioToSpeechLabelDataset.
56
+ """
57
+ dataset = audio_to_label.AudioToSpeechLabelDataset(
58
+ manifest_filepath=config['manifest_filepath'],
59
+ labels=config['labels'],
60
+ featurizer=featurizer,
61
+ max_duration=config.get('max_duration', None),
62
+ min_duration=config.get('min_duration', None),
63
+ trim=config.get('trim_silence', False),
64
+ window_length_in_sec=config.get('window_length_in_sec', 0.31),
65
+ shift_length_in_sec=config.get('shift_length_in_sec', 0.01),
66
+ normalize_audio=config.get('normalize_audio', False),
67
+ cal_labels_occurrence=config.get('cal_labels_occurrence', False),
68
+ )
69
+ return dataset
70
+
71
+
72
+ def get_tarred_classification_label_dataset(
73
+ featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int
74
+ ) -> audio_to_label.TarredAudioToClassificationLabelDataset:
75
+ """
76
+ Instantiates a Classification TarredAudioLabelDataset.
77
+
78
+ Args:
79
+ config: Config of the TarredAudioToClassificationLabelDataset.
80
+ shuffle_n: How many samples to look ahead and load to be shuffled.
81
+ See WebDataset documentation for more details.
82
+ global_rank: Global rank of this device.
83
+ world_size: Global world size in the training method.
84
+
85
+ Returns:
86
+ An instance of TarredAudioToClassificationLabelDataset.
87
+ """
88
+ tarred_audio_filepaths = config['tarred_audio_filepaths']
89
+ manifest_filepaths = config['manifest_filepath']
90
+ datasets = []
91
+ tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
92
+ manifest_filepaths = convert_to_config_list(manifest_filepaths)
93
+
94
+ bucketing_weights = config.get('bucketing_weights', None) # For upsampling buckets
95
+ if bucketing_weights:
96
+ for idx, weight in enumerate(bucketing_weights):
97
+ if not isinstance(weight, int) or weight <= 0:
98
+ raise ValueError(f"bucket weights must be positive integers")
99
+
100
+ if len(manifest_filepaths) != len(tarred_audio_filepaths):
101
+ raise ValueError(
102
+ f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
103
+ )
104
+
105
+ for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
106
+ zip(tarred_audio_filepaths, manifest_filepaths)
107
+ ):
108
+ if len(tarred_audio_filepath) == 1:
109
+ tarred_audio_filepath = tarred_audio_filepath[0]
110
+ dataset = audio_to_label.TarredAudioToClassificationLabelDataset(
111
+ audio_tar_filepaths=tarred_audio_filepath,
112
+ manifest_filepath=manifest_filepath,
113
+ labels=config['labels'],
114
+ featurizer=featurizer,
115
+ shuffle_n=shuffle_n,
116
+ max_duration=config.get('max_duration', None),
117
+ min_duration=config.get('min_duration', None),
118
+ trim=config.get('trim_silence', False),
119
+ shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
120
+ global_rank=global_rank,
121
+ world_size=world_size,
122
+ is_regression_task=config.get('is_regression_task', False),
123
+ )
124
+
125
+ if bucketing_weights:
126
+ datasets.extend([dataset] * bucketing_weights[dataset_idx])
127
+ else:
128
+ datasets.append(dataset)
129
+
130
+ return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
131
+
132
+
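The bucketing path above duplicates a bucket's dataset according to its weight; a hedged sketch of the corresponding config shape, with paths, labels, and weights invented for illustration:

```python
# Hypothetical bucketed config: one tar pattern and one manifest per bucket,
# with bucketing_weights upsampling the second bucket twice.
config = {
    'tarred_audio_filepaths': [['bucket1/audio_{0..3}.tar'], ['bucket2/audio_{0..3}.tar']],
    'manifest_filepath': [['bucket1/manifest.json'], ['bucket2/manifest.json']],
    'labels': ['class_a', 'class_b'],
    'bucketing_weights': [1, 2],
    'tarred_shard_strategy': 'scatter',
}
# dataset = get_tarred_classification_label_dataset(
#     featurizer=featurizer, config=config, shuffle_n=128, global_rank=0, world_size=1,
# )  # commented out: requires the tar shards and manifests to actually exist
```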
133
+ def get_concat_tarred_speech_label_dataset(
134
+ featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int,
135
+ ):
136
+ tarred_audio_filepaths = config['tarred_audio_filepaths']
137
+ manifest_filepaths = config['manifest_filepath']
138
+ datasets = []
139
+ for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
140
+ zip(tarred_audio_filepaths, manifest_filepaths)
141
+ ):
142
+ conf = copy.deepcopy(config)
143
+ conf['manifest_filepath'] = manifest_filepath
144
+ conf['tarred_audio_filepaths'] = tarred_audio_filepath
145
+ dataset = get_tarred_speech_label_dataset(
146
+ config=conf, featurizer=featurizer, shuffle_n=shuffle_n, global_rank=global_rank, world_size=world_size,
147
+ )
148
+ datasets.append(dataset)
149
+
150
+ dataset = ConcatDataset(
151
+ datasets,
152
+ sampling_technique=config.get('concat_sampling_technique', 'temperature'),
153
+ sampling_temperature=config.get('concat_sampling_temperature', 5),
154
+ sampling_probabilities=config.get('concat_sampling_probabilities', None),
155
+ global_rank=global_rank,
156
+ world_size=world_size,
157
+ shuffle=config['shuffle'],
158
+ )
159
+ return dataset
160
+
161
+
162
+ def get_tarred_speech_label_dataset(
163
+ featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int,
164
+ ) -> audio_to_label.TarredAudioToSpeechLabelDataset:
165
+ """
166
+ Instantiates a Speech Label (e.g. VAD, speaker recognition) TarredAudioLabelDataset.
167
+
168
+ Args:
169
+ config: Config of the TarredAudioToSpeechLabelDataset.
170
+ shuffle_n: How many samples to look ahead and load to be shuffled.
171
+ See WebDataset documentation for more details.
172
+ global_rank: Global rank of this device.
173
+ world_size: Global world size in the training method.
174
+
175
+ Returns:
176
+ An instance of TarredAudioToSpeechLabelDataset.
177
+ """
178
+ tarred_audio_filepaths = config['tarred_audio_filepaths']
179
+ manifest_filepaths = config['manifest_filepath']
180
+ datasets = []
181
+ tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
182
+ manifest_filepaths = convert_to_config_list(manifest_filepaths)
183
+
184
+ bucketing_weights = config.get('bucketing_weights', None) # For upsampling buckets
185
+ if bucketing_weights:
186
+ for idx, weight in enumerate(bucketing_weights):
187
+ if not isinstance(weight, int) or weight <= 0:
188
+ raise ValueError(f"bucket weights must be positive integers")
189
+
190
+ if len(manifest_filepaths) != len(tarred_audio_filepaths):
191
+ raise ValueError(
192
+ f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
193
+ )
194
+
195
+ for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
196
+ zip(tarred_audio_filepaths, manifest_filepaths)
197
+ ):
198
+ if len(tarred_audio_filepath) == 1:
199
+ tarred_audio_filepath = tarred_audio_filepath[0]
200
+ dataset = audio_to_label.TarredAudioToSpeechLabelDataset(
201
+ audio_tar_filepaths=tarred_audio_filepath,
202
+ manifest_filepath=manifest_filepath,
203
+ labels=config['labels'],
204
+ featurizer=featurizer,
205
+ shuffle_n=shuffle_n,
206
+ max_duration=config.get('max_duration', None),
207
+ min_duration=config.get('min_duration', None),
208
+ trim=config.get('trim_silence', False),
209
+ window_length_in_sec=config.get('window_length_in_sec', 8),
210
+ shift_length_in_sec=config.get('shift_length_in_sec', 0.075),
211
+ normalize_audio=config.get('normalize_audio', False),
212
+ shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
213
+ global_rank=global_rank,
214
+ world_size=world_size,
215
+ )
216
+
217
+ if bucketing_weights:
218
+ datasets.extend([dataset] * bucketing_weights[dataset_idx])
219
+ else:
220
+ datasets.append(dataset)
221
+
222
+ return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
223
+
224
+
225
+ def get_audio_multi_label_dataset(cfg: DictConfig) -> audio_to_label.AudioToMultiLabelDataset:
226
+ if "augmentor" in cfg:
227
+ augmentor = process_augmentations(cfg.augmentor)
228
+ else:
229
+ augmentor = None
230
+
231
+ dataset = audio_to_label.AudioToMultiLabelDataset(
232
+ manifest_filepath=cfg.get("manifest_filepath"),
233
+ sample_rate=cfg.get("sample_rate"),
234
+ labels=cfg.get("labels", None),
235
+ int_values=cfg.get("int_values", False),
236
+ augmentor=augmentor,
237
+ min_duration=cfg.get("min_duration", None),
238
+ max_duration=cfg.get("max_duration", None),
239
+ trim_silence=cfg.get("trim_silence", False),
240
+ is_regression_task=cfg.get("is_regression_task", False),
241
+ cal_labels_occurrence=cfg.get("cal_labels_occurrence", False),
242
+ delimiter=cfg.get("delimiter", None),
243
+ normalize_audio_db=cfg.get("normalize_audio_db", None),
244
+ )
245
+ return dataset
246
+
247
+
248
+ def get_tarred_audio_multi_label_dataset(
249
+ cfg: DictConfig, shuffle_n: int, global_rank: int, world_size: int
250
+ ) -> audio_to_label.TarredAudioToMultiLabelDataset:
251
+
252
+ if "augmentor" in cfg:
253
+ augmentor = process_augmentations(cfg.augmentor)
254
+ else:
255
+ augmentor = None
256
+
257
+ tarred_audio_filepaths = cfg['tarred_audio_filepaths']
258
+ manifest_filepaths = cfg['manifest_filepath']
259
+ datasets = []
260
+ tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
261
+ manifest_filepaths = convert_to_config_list(manifest_filepaths)
262
+
263
+ bucketing_weights = cfg.get('bucketing_weights', None) # For upsampling buckets
264
+ if bucketing_weights:
265
+ for idx, weight in enumerate(bucketing_weights):
266
+ if not isinstance(weight, int) or weight <= 0:
267
+ raise ValueError(f"bucket weights must be positive integers")
268
+
269
+ if len(manifest_filepaths) != len(tarred_audio_filepaths):
270
+ raise ValueError(
271
+ f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
272
+ )
273
+
274
+ for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
275
+ zip(tarred_audio_filepaths, manifest_filepaths)
276
+ ):
277
+ if len(tarred_audio_filepath) == 1:
278
+ tarred_audio_filepath = tarred_audio_filepath[0]
279
+
280
+ dataset = audio_to_label.TarredAudioToMultiLabelDataset(
281
+ audio_tar_filepaths=tarred_audio_filepath,
282
+ manifest_filepath=manifest_filepath,
283
+ sample_rate=cfg["sample_rate"],
284
+ labels=cfg['labels'],
285
+ shuffle_n=shuffle_n,
286
+ int_values=cfg.get("int_values", False),
287
+ augmentor=augmentor,
288
+ min_duration=cfg.get('min_duration', None),
289
+ max_duration=cfg.get('max_duration', None),
290
+ trim_silence=cfg.get('trim_silence', False),
291
+ is_regression_task=cfg.get('is_regression_task', False),
292
+ delimiter=cfg.get("delimiter", None),
293
+ shard_strategy=cfg.get('tarred_shard_strategy', 'scatter'),
294
+ global_rank=global_rank,
295
+ world_size=world_size,
296
+ normalize_audio_db=cfg.get("normalize_audio_db", None),
297
+ )
298
+
299
+ if bucketing_weights:
300
+ datasets.extend([dataset] * bucketing_weights[dataset_idx])
301
+ else:
302
+ datasets.append(dataset)
303
+
304
+ return get_chain_dataset(datasets=datasets, ds_config=cfg, rank=global_rank)
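A hedged sketch of the OmegaConf config consumed by `get_audio_multi_label_dataset`; the manifest path and label list are placeholders:

```python
# Hypothetical OmegaConf config for get_audio_multi_label_dataset.
from omegaconf import OmegaConf

from nemo.collections.asr.data.audio_to_label_dataset import get_audio_multi_label_dataset

cfg = OmegaConf.create(
    {
        'manifest_filepath': 'multilabel_manifest.json',  # placeholder path
        'sample_rate': 16000,
        'labels': ['0', '1'],
        'min_duration': 0.1,
        'trim_silence': False,
    }
)
dataset = get_audio_multi_label_dataset(cfg)  # assumes the manifest exists on disk
```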
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text.py ADDED
@@ -0,0 +1,1366 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import io
15
+ import json
16
+ import math
17
+ import multiprocessing
18
+ import os
19
+ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
20
+
21
+ import braceexpand
22
+ import numpy as np
23
+ import torch
24
+ import webdataset as wd
25
+ from torch.utils.data import ChainDataset
26
+ from tqdm import tqdm
27
+
28
+ from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
29
+ from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
30
+ from nemo.collections.common import tokenizers
31
+ from nemo.collections.common.parts.preprocessing import collections, parsers
32
+ from nemo.core.classes import Dataset, IterableDataset
33
+ from nemo.core.neural_types import *
34
+ from nemo.utils import logging
35
+ from nemo.utils.data_utils import (
36
+ DataStoreObject,
37
+ datastore_object_get,
38
+ datastore_path_to_webdataset_url,
39
+ is_datastore_cache_shared,
40
+ is_datastore_path,
41
+ is_tarred_path,
42
+ )
43
+ from nemo.utils.get_rank import is_global_rank_zero
44
+
45
+ __all__ = [
46
+ 'AudioToCharDataset',
47
+ 'AudioToBPEDataset',
48
+ 'TarredAudioToCharDataset',
49
+ 'TarredAudioToBPEDataset',
50
+ ]
51
+
52
+
53
+ def _speech_collate_fn(batch, pad_id):
54
+ """collate batch of audio sig, audio len, tokens, tokens len
55
+ Args:
56
+ batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
57
+ LongTensor): A tuple of tuples of signal, signal lengths,
58
+ encoded tokens, and encoded tokens length. This collate func
59
+ assumes the signals are 1d torch tensors (i.e. mono audio).
60
+ """
61
+ packed_batch = list(zip(*batch))
62
+ if len(packed_batch) == 5:
63
+ _, audio_lengths, _, tokens_lengths, sample_ids = packed_batch
64
+ elif len(packed_batch) == 4:
65
+ sample_ids = None
66
+ _, audio_lengths, _, tokens_lengths = packed_batch
67
+ else:
68
+ raise ValueError("Expects 4 or 5 tensors in the batch!")
69
+ max_audio_len = 0
70
+ has_audio = audio_lengths[0] is not None
71
+ if has_audio:
72
+ max_audio_len = max(audio_lengths).item()
73
+ max_tokens_len = max(tokens_lengths).item()
74
+
75
+ audio_signal, tokens = [], []
76
+ for b in batch:
77
+ if len(b) == 5:
78
+ sig, sig_len, tokens_i, tokens_i_len, _ = b
79
+ else:
80
+ sig, sig_len, tokens_i, tokens_i_len = b
81
+ if has_audio:
82
+ sig_len = sig_len.item()
83
+ if sig_len < max_audio_len:
84
+ pad = (0, max_audio_len - sig_len)
85
+ sig = torch.nn.functional.pad(sig, pad)
86
+ audio_signal.append(sig)
87
+ tokens_i_len = tokens_i_len.item()
88
+ if tokens_i_len < max_tokens_len:
89
+ pad = (0, max_tokens_len - tokens_i_len)
90
+ tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
91
+ tokens.append(tokens_i)
92
+
93
+ if has_audio:
94
+ audio_signal = torch.stack(audio_signal)
95
+ audio_lengths = torch.stack(audio_lengths)
96
+ else:
97
+ audio_signal, audio_lengths = None, None
98
+ tokens = torch.stack(tokens)
99
+ tokens_lengths = torch.stack(tokens_lengths)
100
+ if sample_ids is None:
101
+ return audio_signal, audio_lengths, tokens, tokens_lengths
102
+ else:
103
+ sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
104
+ return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids
105
+
106
+
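A tiny, self-contained demonstration of the padding this collate function performs; the tensors are synthetic and `pad_id` is assumed to be 0:

```python
# Synthetic batch: two mono signals and token sequences of different lengths.
import torch

from nemo.collections.asr.data.audio_to_text import _speech_collate_fn

batch = [
    (torch.randn(16000), torch.tensor(16000), torch.tensor([3, 7, 9]), torch.tensor(3)),
    (torch.randn(12000), torch.tensor(12000), torch.tensor([5, 2]), torch.tensor(2)),
]
audio, audio_len, tokens, tokens_len = _speech_collate_fn(batch, pad_id=0)

print(audio.shape)          # torch.Size([2, 16000]) -- shorter signal zero-padded
print(tokens.shape)         # torch.Size([2, 3])     -- shorter token row padded with pad_id
print(tokens_len.tolist())  # [3, 2]                 -- original lengths are preserved
```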
107
+ class ASRManifestProcessor:
108
+ """
109
+ Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds).
110
+ Each new line is a different sample. Example below:
111
+ {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
112
+ ...
113
+ {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
114
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
115
+ Args:
116
+ manifest_filepath: Path to manifest json as described above. Can be comma-separated paths.
117
+ parser: Str for a language specific preprocessor or a callable.
118
+ max_duration: If audio exceeds this length, do not include in dataset.
119
+ min_duration: If audio is less than this length, do not include in dataset.
120
+ max_utts: Limit number of utterances.
121
+ bos_id: Id of beginning of sequence symbol to append if not None.
122
+ eos_id: Id of end of sequence symbol to append if not None.
123
+ pad_id: Id of pad symbol. Defaults to 0.
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ manifest_filepath: str,
129
+ parser: Union[str, Callable],
130
+ max_duration: Optional[float] = None,
131
+ min_duration: Optional[float] = None,
132
+ max_utts: int = 0,
133
+ bos_id: Optional[int] = None,
134
+ eos_id: Optional[int] = None,
135
+ pad_id: int = 0,
136
+ index_by_file_id: bool = False,
137
+ ):
138
+ self.parser = parser
139
+
140
+ self.collection = collections.ASRAudioText(
141
+ manifests_files=manifest_filepath,
142
+ parser=parser,
143
+ min_duration=min_duration,
144
+ max_duration=max_duration,
145
+ max_number=max_utts,
146
+ index_by_file_id=index_by_file_id,
147
+ )
148
+
149
+ self.eos_id = eos_id
150
+ self.bos_id = bos_id
151
+ self.pad_id = pad_id
152
+
153
+ def process_text_by_id(self, index: int) -> Tuple[List[int], int]:
154
+ sample = self.collection[index]
155
+ return self.process_text_by_sample(sample)
156
+
157
+ def process_text_by_file_id(self, file_id: str) -> Tuple[List[int], int]:
158
+ manifest_idx = self.collection.mapping[file_id][0]
159
+ sample = self.collection[manifest_idx]
160
+ return self.process_text_by_sample(sample)
161
+
162
+ def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -> Tuple[List[int], int]:
163
+ t, tl = sample.text_tokens, len(sample.text_tokens)
164
+
165
+ if self.bos_id is not None:
166
+ t = [self.bos_id] + t
167
+ tl += 1
168
+ if self.eos_id is not None:
169
+ t = t + [self.eos_id]
170
+ tl += 1
171
+
172
+ return t, tl
173
+
174
+
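The BOS/EOS handling in `process_text_by_sample` amounts to the following standalone illustration; the token ids and special-symbol ids below are invented:

```python
# Standalone illustration of optional BOS/EOS insertion around a token sequence.
tokens = [17, 4, 23]   # hypothetical token ids produced by the parser
bos_id, eos_id = 1, 2  # hypothetical special-symbol ids

t, tl = list(tokens), len(tokens)
if bos_id is not None:
    t, tl = [bos_id] + t, tl + 1
if eos_id is not None:
    t, tl = t + [eos_id], tl + 1

print(t, tl)  # [1, 17, 4, 23, 2] 5
```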
175
+ def expand_sharded_filepaths(sharded_filepaths, shard_strategy: str, world_size: int, global_rank: int):
176
+ valid_shard_strategies = ['scatter', 'replicate']
177
+ if shard_strategy not in valid_shard_strategies:
178
+ raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")
179
+
180
+ if isinstance(sharded_filepaths, str):
181
+ # Replace '(' and '[' with '{'
182
+ brace_keys_open = ['(', '[', '<', '_OP_']
183
+ for bkey in brace_keys_open:
184
+ if bkey in sharded_filepaths:
185
+ sharded_filepaths = sharded_filepaths.replace(bkey, "{")
186
+
187
+ # Replace ')' and ']' with '}'
188
+ brace_keys_close = [')', ']', '>', '_CL_']
189
+ for bkey in brace_keys_close:
190
+ if bkey in sharded_filepaths:
191
+ sharded_filepaths = sharded_filepaths.replace(bkey, "}")
192
+
193
+ if isinstance(sharded_filepaths, str):
194
+ # Brace expand, set escape=False for Windows compatibility
195
+ sharded_filepaths = list(braceexpand.braceexpand(sharded_filepaths, escape=False))
196
+
197
+ # Expand store paths into WebDataset URLs
198
+ sharded_filepaths = [
199
+ datastore_path_to_webdataset_url(p) if is_datastore_path(p) and is_tarred_path(p) else p
200
+ for p in sharded_filepaths
201
+ ]
202
+
203
+ # Check for distributed and partition shards accordingly
204
+ if world_size > 1:
205
+ if shard_strategy == 'scatter':
206
+ logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
207
+
208
+ if len(sharded_filepaths) % world_size != 0:
209
+ logging.warning(
210
+ f"Number of shards in tarred dataset ({len(sharded_filepaths)}) is not divisible "
211
+ f"by number of distributed workers ({world_size})."
212
+ )
213
+
214
+ begin_idx = (len(sharded_filepaths) // world_size) * global_rank
215
+ end_idx = begin_idx + len(sharded_filepaths) // world_size
216
+ sharded_filepaths = sharded_filepaths[begin_idx:end_idx]
217
+ logging.info(
218
+ "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
219
+ )
220
+
221
+ elif shard_strategy == 'replicate':
222
+ logging.info("All tarred dataset shards will be replicated across all nodes.")
223
+ else:
224
+ raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")
225
+
226
+ return sharded_filepaths
227
+
228
+
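To make the scatter strategy concrete, a short sketch with eight hypothetical shards split across four workers, mirroring the partitioning arithmetic above:

```python
# Brace-expand a shard pattern, then take the contiguous slice owned by this rank.
import braceexpand

shards = list(braceexpand.braceexpand('audio_{0..7}.tar'))  # hypothetical shard names
world_size, global_rank = 4, 1

per_rank = len(shards) // world_size
begin_idx = per_rank * global_rank
end_idx = begin_idx + per_rank

print(shards[begin_idx:end_idx])  # ['audio_2.tar', 'audio_3.tar'] for rank 1 of 4
```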
229
+ def cache_datastore_manifests(
230
+ manifest_filepaths: Union[str, List[str]],
231
+ cache_audio: bool = False,
232
+ shared_cache: Optional[bool] = None,
233
+ num_workers: Optional[int] = None,
234
+ max_num_workers: int = 20,
235
+ ):
236
+ """Cache manifests and audio from an object store.
237
+ It is assumed that remote manifests are using relative paths.
238
+
239
+ Args:
240
+ manifest_filepaths: list of paths to manifest files (list of strings or a string with `,` as separator)
241
+ cache_audio: If True, audio from manifest will also be cached
242
+ shared_cache: Optional, True if cache is shared across all nodes
243
+ num_workers: Optional, number of workers to be used for download
244
+ max_num_workers: max number of workers to be used for download, used when setting num_workers automatically
245
+ """
246
+ if isinstance(manifest_filepaths, str):
247
+ manifest_filepaths = manifest_filepaths.split(',')
248
+
249
+ num_datastore_manifests = sum([is_datastore_path(f) for f in manifest_filepaths])
250
+
251
+ if num_datastore_manifests > 0:
252
+ # Local utility function
253
+ def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
254
+ """Cache manifests and audio data from object store.
255
+ """
256
+ # Determine the number of workers to use
257
+ if num_workers is None:
258
+ num_workers = os.cpu_count() - 1
259
+ num_workers = min(num_workers, max_num_workers)
260
+
261
+ # Process each manifest file
262
+ for manifest_file in manifest_filepaths:
263
+ # If manifest is on a data store, then cache it.
264
+ # Otherwise, nothing to do.
265
+ if is_datastore_path(manifest_file):
266
+ logging.info('Cache manifest file: %s', manifest_file)
267
+ cached_manifest_file = DataStoreObject(manifest_file).get()
268
+ logging.info('Cached at: %s', str(cached_manifest_file))
269
+
270
+ if cache_audio:
271
+ # Each audio file from manifest will be cached.
272
+ logging.info('Cache audio from manifest file: %s', manifest_file)
273
+ # Assumes that manifest is using relative paths
274
+ manifest_dir = os.path.dirname(manifest_file)
275
+ # Prepare all store objects
276
+ audio_objects = []
277
+ with open(cached_manifest_file, 'r') as f:
278
+ for line in f:
279
+ item = json.loads(line)
280
+ store_path = os.path.join(manifest_dir, item['audio_filepath'])
281
+ audio_objects.append(DataStoreObject(store_path=store_path))
282
+
283
+ if num_workers is not None and num_workers > 1:
284
+ logging.debug('Using multiprocessing with num_workers: %d.', num_workers)
285
+ with multiprocessing.Pool(processes=num_workers) as p:
286
+ result = list(
287
+ tqdm(p.imap(datastore_object_get, audio_objects), total=len(audio_objects))
288
+ )
289
+ else:
290
+ logging.debug('Using a single process.')
291
+ result = []
292
+ for audio_object in tqdm(audio_objects):
293
+ result.append(audio_object.get() is not None)
294
+
295
+ if not all(result):
296
+ raise RuntimeError('Some files not downloaded successfully')
297
+ logging.info('Caching complete')
298
+
299
+ else:
300
+ # Nothing to do here
301
+ logging.debug('Manifest is not on a data store: %s', manifest_file)
302
+
303
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
304
+ logging.debug('Distributed environment is available and initialized.')
305
+
306
+ # Handle distributed environment
307
+ if shared_cache is None:
308
+ shared_cache = is_datastore_cache_shared()
309
+
310
+ if shared_cache:
311
+ logging.debug('Cache is shared among nodes, cache data on global rank zero.')
312
+ is_rank_zero = is_global_rank_zero()
313
+ else:
314
+ logging.debug('Cache is not shared among nodes, cache data on local rank zero.')
315
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
316
+ is_rank_zero = local_rank == 0
317
+
318
+ if is_rank_zero:
319
+ logging.info('Cache data from %s rank 0', 'global' if shared_cache else 'local')
320
+ cache_data(
321
+ manifest_filepaths=manifest_filepaths,
322
+ cache_audio=cache_audio,
323
+ num_workers=num_workers,
324
+ max_num_workers=max_num_workers,
325
+ )
326
+ logging.debug('Reached barrier')
327
+ torch.distributed.barrier()
328
+
329
+ elif is_global_rank_zero():
330
+ # Handle non-distributed environment, e.g., if running on a single GPU
331
+ logging.warning(
332
+ 'Torch distributed is not initialized and caching may be prone to data race conditions. '
333
+ 'Now caching data from global rank 0. If there are other ranks and they pass this '
334
+ 'before rank 0, errors might result.'
335
+ )
336
+ cache_data(
337
+ manifest_filepaths=manifest_filepaths,
338
+ cache_audio=cache_audio,
339
+ num_workers=num_workers,
340
+ max_num_workers=max_num_workers,
341
+ )
342
+ else:
343
+ raise RuntimeError(
344
+ 'Torch distributed is not initialized and caching on nodes other than global rank zero is disabled '
345
+ 'to avoid race condition between different ranks. To ensure distributed environment is '
346
+ 'initialized, please update data config to use `defer_setup = True`.'
347
+ )
348
+
349
+
350
+ """Optionally expand / shard the list of manifests
351
+ This is made to use the same notation as the sharded audio files
352
+
353
+ Args:
354
+ manifest_filepaths: list of manifest files (the sharded notation)
355
+ shard_strategy: scatter or replicate (scatter by default)
356
+ shard_manifests: bool, if False, no sharding / manifest filepath expansion will be attempted
357
+ global_rank: int, the rank of this worker
358
+ world_size: int, total number of workers
359
+ """
360
+
361
+
362
+ def shard_manifests_if_needed(
363
+ manifest_filepaths: Union[str, List[str]],
364
+ shard_strategy: str,
365
+ shard_manifests: bool,
366
+ global_rank: int,
367
+ world_size: int,
368
+ ):
369
+ if shard_manifests:
370
+ if not torch.distributed.is_available():
371
+ logging.warning("Not running in torch.distributed mode. Manifest sharding not available")
372
+ return manifest_filepaths
373
+
374
+ if not torch.distributed.is_initialized():
375
+ logging.warning(
376
+ 'Manifest sharding was requested but torch.distributed is not initialized. '
377
+ 'Did you intend to set the defer_setup flag?'
378
+ )
379
+ return manifest_filepaths
380
+
381
+ manifest_filepaths = expand_sharded_filepaths(
382
+ sharded_filepaths=manifest_filepaths,
383
+ shard_strategy=shard_strategy,
384
+ world_size=world_size,
385
+ global_rank=global_rank,
386
+ )
387
+
388
+ return manifest_filepaths
389
+
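Editor's note: the expanded manifest list ultimately feeds the same 'scatter' logic used for audio shards, where each rank keeps a disjoint slice of the shard list. A simplified stand-in for expand_sharded_filepaths, for illustration only:

def scatter_shards(shard_paths, global_rank, world_size):
    # Evenly partition the expanded shard list; each rank keeps its own slice.
    shards_per_rank = len(shard_paths) // world_size
    begin = global_rank * shards_per_rank
    return shard_paths[begin:begin + shards_per_rank]

paths = [f"manifest_{i}.json" for i in range(8)]
print(scatter_shards(paths, global_rank=1, world_size=4))  # ['manifest_2.json', 'manifest_3.json']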
390
+
391
+ class _AudioTextDataset(Dataset):
392
+ """
393
+ Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds).
394
+ Each new line is a different sample. Example below:
395
+ {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
396
+ ...
397
+ {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
398
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
399
+ Args:
400
+ manifest_filepath: Path to manifest json as described above. Can be comma-separated paths.
401
+ parser: Str for a language specific preprocessor or a callable.
402
+ sample_rate (int): Sample rate to resample loaded audio to
403
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
404
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor object used to augment loaded
405
+ audio
406
+ max_duration: If audio exceeds this length, do not include in dataset
407
+ min_duration: If audio is less than this length, do not include in dataset
408
+ max_utts: Limit number of utterances
409
+ trim: whether or not to trim silence. Defaults to False
410
+ bos_id: Id of beginning of sequence symbol to append if not None
411
+ eos_id: Id of end of sequence symbol to append if not None
412
+ pad_id: Id of pad symbol. Defaults to 0
413
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
414
+ channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
415
+ """
416
+
417
+ @property
418
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
419
+ """Returns definitions of module output ports.
420
+ """
421
+ return {
422
+ 'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
423
+ 'a_sig_length': NeuralType(tuple('B'), LengthsType()),
424
+ 'transcripts': NeuralType(('B', 'T'), LabelsType()),
425
+ 'transcript_length': NeuralType(tuple('B'), LengthsType()),
426
+ 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
427
+ }
428
+
429
+ def __init__(
430
+ self,
431
+ manifest_filepath: str,
432
+ parser: Union[str, Callable],
433
+ sample_rate: int,
434
+ int_values: bool = False,
435
+ augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
436
+ max_duration: Optional[int] = None,
437
+ min_duration: Optional[int] = None,
438
+ max_utts: int = 0,
439
+ trim: bool = False,
440
+ bos_id: Optional[int] = None,
441
+ eos_id: Optional[int] = None,
442
+ pad_id: int = 0,
443
+ return_sample_id: bool = False,
444
+ channel_selector: Optional[ChannelSelectorType] = None,
445
+ ):
446
+ if isinstance(manifest_filepath, str):
447
+ manifest_filepath = manifest_filepath.split(",")
448
+
449
+ # If necessary, cache manifests and audio from object store
450
+ cache_datastore_manifests(manifest_filepaths=manifest_filepath, cache_audio=True)
451
+
452
+ self.manifest_processor = ASRManifestProcessor(
453
+ manifest_filepath=manifest_filepath,
454
+ parser=parser,
455
+ max_duration=max_duration,
456
+ min_duration=min_duration,
457
+ max_utts=max_utts,
458
+ bos_id=bos_id,
459
+ eos_id=eos_id,
460
+ pad_id=pad_id,
461
+ )
462
+ self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
463
+ self.trim = trim
464
+ self.return_sample_id = return_sample_id
465
+ self.channel_selector = channel_selector
466
+
467
+ def get_manifest_sample(self, sample_id):
468
+ return self.manifest_processor.collection[sample_id]
469
+
470
+ def __getitem__(self, index):
471
+ sample = self.manifest_processor.collection[index]
472
+ offset = sample.offset
473
+
474
+ if offset is None:
475
+ offset = 0
476
+
477
+ features = self.featurizer.process(
478
+ sample.audio_file,
479
+ offset=offset,
480
+ duration=sample.duration,
481
+ trim=self.trim,
482
+ orig_sr=sample.orig_sr,
483
+ channel_selector=self.channel_selector,
484
+ )
485
+ f, fl = features, torch.tensor(features.shape[0]).long()
486
+
487
+ t, tl = self.manifest_processor.process_text_by_sample(sample=sample)
488
+
489
+ if self.return_sample_id:
490
+ output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
491
+ else:
492
+ output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
493
+
494
+ return output
495
+
496
+ def __len__(self):
497
+ return len(self.manifest_processor.collection)
498
+
499
+ def _collate_fn(self, batch):
500
+ return _speech_collate_fn(batch, pad_id=self.manifest_processor.pad_id)
501
+
502
+
503
+ class AudioToCharDataset(_AudioTextDataset):
504
+ """
505
+ Dataset that loads tensors via a json file containing paths to audio
506
+ files, transcripts, and durations (in seconds). Each new line is a
507
+ different sample. Example below:
508
+ {"audio_filepath": "/path/to/audio.wav", "text_filepath":
509
+ "/path/to/audio.txt", "duration": 23.147}
510
+ ...
511
+ {"audio_filepath": "/path/to/audio.wav", "text": "the
512
+ transcription", "offset": 301.75, "duration": 0.82, "utt":
513
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
514
+
515
+ Args:
516
+ manifest_filepath: Path to manifest json as described above. Can
517
+ be comma-separated paths.
518
+ labels: String containing all the possible characters to map to
519
+ sample_rate (int): Sample rate to resample loaded audio to
520
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
521
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
522
+ object used to augment loaded audio
523
+ max_duration: If audio exceeds this length, do not include in dataset
524
+ min_duration: If audio is less than this length, do not include
525
+ in dataset
526
+ max_utts: Limit number of utterances
527
+ blank_index: blank character index, default = -1
528
+ unk_index: unk_character index, default = -1
529
+ normalize: whether to normalize transcript text. Defaults to True.
530
+ bos_id: Id of beginning of sequence symbol to append if not None
531
+ eos_id: Id of end of sequence symbol to append if not None
532
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
533
+ channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
534
+ """
535
+
536
+ @property
537
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
538
+ """Returns definitions of module output ports.
539
+ """
540
+ return {
541
+ 'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
542
+ 'a_sig_length': NeuralType(tuple('B'), LengthsType()),
543
+ 'transcripts': NeuralType(('B', 'T'), LabelsType()),
544
+ 'transcript_length': NeuralType(tuple('B'), LengthsType()),
545
+ 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
546
+ }
547
+
548
+ def __init__(
549
+ self,
550
+ manifest_filepath: str,
551
+ labels: Union[str, List[str]],
552
+ sample_rate: int,
553
+ int_values: bool = False,
554
+ augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
555
+ max_duration: Optional[float] = None,
556
+ min_duration: Optional[float] = None,
557
+ max_utts: int = 0,
558
+ blank_index: int = -1,
559
+ unk_index: int = -1,
560
+ normalize: bool = True,
561
+ trim: bool = False,
562
+ bos_id: Optional[int] = None,
563
+ eos_id: Optional[int] = None,
564
+ pad_id: int = 0,
565
+ parser: Union[str, Callable] = 'en',
566
+ return_sample_id: bool = False,
567
+ channel_selector: Optional[ChannelSelectorType] = None,
568
+ ):
569
+ self.labels = labels
570
+
571
+ parser = parsers.make_parser(
572
+ labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
573
+ )
574
+
575
+ super().__init__(
576
+ manifest_filepath=manifest_filepath,
577
+ parser=parser,
578
+ sample_rate=sample_rate,
579
+ int_values=int_values,
580
+ augmentor=augmentor,
581
+ max_duration=max_duration,
582
+ min_duration=min_duration,
583
+ max_utts=max_utts,
584
+ trim=trim,
585
+ bos_id=bos_id,
586
+ eos_id=eos_id,
587
+ pad_id=pad_id,
588
+ return_sample_id=return_sample_id,
589
+ channel_selector=channel_selector,
590
+ )
591
+
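Editor's note: a hedged usage sketch for the character-level dataset defined above; the manifest path and label set are illustrative, and batching uses the dataset's own _collate_fn:

import torch

labels = list(" abcdefghijklmnopqrstuvwxyz'")  # example English character set
dataset = AudioToCharDataset(
    manifest_filepath="train_manifest.json",   # illustrative path
    labels=labels,
    sample_rate=16000,
    max_duration=20.0,
    trim=True,
)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=8, collate_fn=dataset._collate_fn, num_workers=4
)
audio, audio_len, tokens, token_len = next(iter(loader))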
592
+
593
+ class AudioToBPEDataset(_AudioTextDataset):
594
+ """
595
+ Dataset that loads tensors via a json file containing paths to audio
596
+ files, transcripts, and durations (in seconds). Each new line is a
597
+ different sample. Example below:
598
+ {"audio_filepath": "/path/to/audio.wav", "text_filepath":
599
+ "/path/to/audio.txt", "duration": 23.147}
600
+ ...
601
+ {"audio_filepath": "/path/to/audio.wav", "text": "the
602
+ transcription", "offset": 301.75, "duration": 0.82, "utt":
603
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
604
+
605
+ In practice, the dataset and manifest used for character encoding and byte pair encoding
606
+ are exactly the same. The only difference lies in how the dataset tokenizes the text in
607
+ the manifest.
608
+
609
+ Args:
610
+ manifest_filepath: Path to manifest json as described above. Can
611
+ be comma-separated paths.
612
+ tokenizer: A subclass of the Tokenizer wrapper found in the common collection,
613
+ nemo.collections.common.tokenizers.TokenizerSpec. ASR Models support a subset of
614
+ all available tokenizers.
615
+ sample_rate (int): Sample rate to resample loaded audio to
616
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
617
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
618
+ object used to augment loaded audio
619
+ max_duration: If audio exceeds this length, do not include in dataset
620
+ min_duration: If audio is less than this length, do not include
621
+ in dataset
622
+ max_utts: Limit number of utterances
623
+ trim: Whether to trim silence segments
624
+ use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
625
+ tokens to beginning and ending of speech respectively.
626
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
627
+ channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
628
+ """
629
+
630
+ @property
631
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
632
+ """Returns definitions of module output ports.
633
+ """
634
+ return {
635
+ 'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
636
+ 'a_sig_length': NeuralType(tuple('B'), LengthsType()),
637
+ 'transcripts': NeuralType(('B', 'T'), LabelsType()),
638
+ 'transcript_length': NeuralType(tuple('B'), LengthsType()),
639
+ 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
640
+ }
641
+
642
+ def __init__(
643
+ self,
644
+ manifest_filepath: str,
645
+ tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
646
+ sample_rate: int,
647
+ int_values: bool = False,
648
+ augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
649
+ max_duration: Optional[int] = None,
650
+ min_duration: Optional[int] = None,
651
+ max_utts: int = 0,
652
+ trim: bool = False,
653
+ use_start_end_token: bool = True,
654
+ return_sample_id: bool = False,
655
+ channel_selector: Optional[ChannelSelectorType] = None,
656
+ ):
657
+ if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
658
+ bos_id = tokenizer.bos_id
659
+ else:
660
+ bos_id = None
661
+
662
+ if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0:
663
+ eos_id = tokenizer.eos_id
664
+ else:
665
+ eos_id = None
666
+
667
+ if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0:
668
+ pad_id = tokenizer.pad_id
669
+ else:
670
+ pad_id = 0
671
+
672
+ class TokenizerWrapper:
673
+ def __init__(self, tokenizer):
674
+ if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer):
675
+ self.is_aggregate = True
676
+ else:
677
+ self.is_aggregate = False
678
+ self._tokenizer = tokenizer
679
+
680
+ def __call__(self, *args):
681
+ if isinstance(args[0], List) and self.is_aggregate:
682
+ t = []
683
+ for span in args[0]:
684
+ t.extend(self._tokenizer.text_to_ids(span['str'], span['lang']))
685
+ return t
686
+
687
+ t = self._tokenizer.text_to_ids(*args)
688
+ return t
689
+
690
+ super().__init__(
691
+ manifest_filepath=manifest_filepath,
692
+ parser=TokenizerWrapper(tokenizer),
693
+ sample_rate=sample_rate,
694
+ int_values=int_values,
695
+ augmentor=augmentor,
696
+ max_duration=max_duration,
697
+ min_duration=min_duration,
698
+ max_utts=max_utts,
699
+ bos_id=bos_id,
700
+ eos_id=eos_id,
701
+ pad_id=pad_id,
702
+ trim=trim,
703
+ return_sample_id=return_sample_id,
704
+ channel_selector=channel_selector,
705
+ )
706
+
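Editor's note: the BPE variant is used the same way, except the text is tokenized through a TokenizerSpec instance. A sketch assuming a SentencePiece model file is available; the model path is illustrative:

from nemo.collections.common.tokenizers import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer(model_path="tokenizer.model")  # illustrative path
bpe_dataset = AudioToBPEDataset(
    manifest_filepath="train_manifest.json",
    tokenizer=tokenizer,
    sample_rate=16000,
    use_start_end_token=False,  # do not prepend/append BOS/EOS tokens
)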
707
+
708
+ class _TarredAudioToTextDataset(IterableDataset):
709
+ """
710
+ A similar Dataset to the AudioToCharDataset/AudioToBPEDataset, but which loads tarred audio files.
711
+
712
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset),
713
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
714
+ contain the information for one audio file, including at least the transcript and name of the audio
715
+ file within the tarball.
716
+
717
+ Valid formats for the audio_tar_filepaths argument include:
718
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
719
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
720
+
721
+ Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference.
722
+ This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements.
723
+ Supported opening braces - { <=> (, [, < and the special tag _OP_.
724
+ Supported closing braces - } <=> ), ], > and the special tag _CL_.
725
+ For SLURM based tasks, we suggest the use of the special tags for ease of use.
726
+
727
+ See the WebDataset documentation for more information about accepted data and input formats.
728
+
729
+ If using multiple workers the number of shards should be divisible by world_size to ensure an
730
+ even split among workers. If it is not divisible, logging will give a warning but training will proceed.
731
+ In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
732
+ is applied. We currently do not check for this, but your program may hang if the shards are uneven!
733
+
734
+ Notice that a few arguments are different from the AudioToCharDataset; for example, shuffle (bool) has been
735
+ replaced by shuffle_n (int).
736
+
737
+ Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
738
+ after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
739
+
740
+ Args:
741
+ audio_tar_filepaths: Either a list of audio tarball filepaths, or a
742
+ string (can be brace-expandable).
743
+ manifest_filepath (str): Path to the manifest.
744
+ parser (callable): A callable which is used to pre-process the text output.
745
+ sample_rate (int): Sample rate to resample loaded audio to
746
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
747
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
748
+ object used to augment loaded audio
749
+ shuffle_n (int): How many samples to look ahead and load to be shuffled.
750
+ See WebDataset documentation for more details.
751
+ Defaults to 0.
752
+ min_duration (float): Dataset parameter.
753
+ All training files which have a duration less than min_duration
754
+ are dropped. Note: Duration is read from the manifest JSON.
755
+ Defaults to 0.1.
756
+ max_duration (float): Dataset parameter.
757
+ All training files which have a duration more than max_duration
758
+ are dropped. Note: Duration is read from the manifest JSON.
759
+ Defaults to None.
760
+ blank_index (int): Blank character index, defaults to -1.
761
+ unk_index (int): Unknown character index, defaults to -1.
762
+ normalize (bool): Dataset parameter.
763
+ Whether to use automatic text cleaning.
764
+ It is highly recommended to manually clean text for best results.
765
+ Defaults to True.
766
+ trim (bool): Whether to trim silence from the beginning and end
767
+ of audio signal using librosa.effects.trim().
768
+ Defaults to False.
769
+ bos_id (id): Dataset parameter.
770
+ Beginning of string symbol id used for seq2seq models.
771
+ Defaults to None.
772
+ eos_id (id): Dataset parameter.
773
+ End of string symbol id used for seq2seq models.
774
+ Defaults to None.
775
+ pad_id (id): Token used to pad when collating samples in batches.
776
+ If this is None, pads using 0s.
777
+ Defaults to None.
778
+ shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
779
+ - `scatter`: The default shard strategy applied by WebDataset, where each node gets
780
+ a unique set of shards, which are permanently pre-allocated and never changed at runtime.
781
+ - `replicate`: Optional shard strategy, where each node gets all of the set of shards
782
+ available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
783
+ The benefit of replication is that it allows each node to sample data points from the entire
784
+ dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
785
+
786
+ .. warning::
787
+ Replicated strategy allows every node to sample the entire set of available tarfiles,
788
+ and therefore more than one node may sample the same tarfile, and even sample the same
789
+ data points! As such, there is no assured guarantee that all samples in the dataset will be
790
+ sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
791
+ occasions (when the number of shards is not divisible with ``world_size``), will not sample
792
+ the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
793
+ or test datasets.
794
+ shard_manifests (bool): Whether or not to shard / expand the manifest filepaths. Defaults to False.
795
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
796
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
797
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
798
+ """
799
+
800
+ def __init__(
801
+ self,
802
+ audio_tar_filepaths: Union[str, List[str]],
803
+ manifest_filepath: str,
804
+ parser: Callable,
805
+ sample_rate: int,
806
+ int_values: bool = False,
807
+ augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
808
+ shuffle_n: int = 0,
809
+ min_duration: Optional[float] = None,
810
+ max_duration: Optional[float] = None,
811
+ trim: bool = False,
812
+ bos_id: Optional[int] = None,
813
+ eos_id: Optional[int] = None,
814
+ pad_id: int = 0,
815
+ shard_strategy: str = "scatter",
816
+ shard_manifests: bool = False,
817
+ global_rank: int = 0,
818
+ world_size: int = 0,
819
+ return_sample_id: bool = False,
820
+ ):
821
+ self.shard_manifests = shard_manifests
822
+
823
+ # Shard manifests if necessary and possible and then expand the paths
824
+ manifest_filepath = shard_manifests_if_needed(
825
+ shard_manifests=shard_manifests,
826
+ shard_strategy=shard_strategy,
827
+ manifest_filepaths=manifest_filepath,
828
+ world_size=world_size,
829
+ global_rank=global_rank,
830
+ )
831
+
832
+ # If necessary, cache manifests from object store
833
+ cache_datastore_manifests(manifest_filepaths=manifest_filepath)
834
+
835
+ self.manifest_processor = ASRManifestProcessor(
836
+ manifest_filepath=manifest_filepath,
837
+ parser=parser,
838
+ max_duration=max_duration,
839
+ min_duration=min_duration,
840
+ max_utts=0,
841
+ bos_id=bos_id,
842
+ eos_id=eos_id,
843
+ pad_id=pad_id,
844
+ index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID
845
+ )
846
+
847
+ self.len = self._compute_len()
848
+
849
+ self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
850
+ self.trim = trim
851
+ self.eos_id = eos_id
852
+ self.bos_id = bos_id
853
+ self.pad_id = pad_id
854
+ self.return_sample_id = return_sample_id
855
+
856
+ audio_tar_filepaths = expand_sharded_filepaths(
857
+ sharded_filepaths=audio_tar_filepaths,
858
+ shard_strategy=shard_strategy,
859
+ world_size=world_size,
860
+ global_rank=global_rank,
861
+ )
862
+
863
+ # Put together WebDataset
864
+ self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
865
+
866
+ if shuffle_n > 0:
867
+ self._dataset = self._dataset.shuffle(shuffle_n)
868
+ else:
869
+ logging.info("WebDataset will not shuffle files within the tar files.")
870
+
871
+ self._dataset = (
872
+ self._dataset.rename(audio='wav;ogg;flac', key='__key__')
873
+ .to_tuple('audio', 'key')
874
+ .pipe(self._filter)
875
+ .pipe(self._loop_offsets)
876
+ .map(f=self._build_sample)
877
+ )
878
+
879
+ def _filter(self, iterator):
880
+ """This function is used to remove samples that have been filtered out by ASRAudioText already.
881
+ Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample
882
+ that was filtered out (e.g. for duration).
883
+ Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard,
884
+ which may make your code hang as one process will finish before the other.
885
+ """
886
+
887
+ class TarredAudioFilter:
888
+ def __init__(self, collection):
889
+ self.iterator = iterator
890
+ self.collection = collection
891
+
892
+ def __iter__(self):
893
+ return self
894
+
895
+ def __next__(self):
896
+ while True:
897
+ audio_bytes, audio_filename = next(self.iterator)
898
+ file_id, _ = os.path.splitext(os.path.basename(audio_filename))
899
+ if file_id in self.collection.mapping:
900
+ return audio_bytes, audio_filename
901
+
902
+ return TarredAudioFilter(self.manifest_processor.collection)
903
+
904
+ def _loop_offsets(self, iterator):
905
+ """This function is used to iterate through utterances with different offsets for each file.
906
+ """
907
+
908
+ class TarredAudioLoopOffsets:
909
+ def __init__(self, collection):
910
+ self.iterator = iterator
911
+ self.collection = collection
912
+ self.current_fn = None
913
+ self.current_bytes = None
914
+ self.offset_id = 0
915
+
916
+ def __iter__(self):
917
+ return self
918
+
919
+ def __next__(self):
920
+ if self.current_fn is None:
921
+ self.current_bytes, self.current_fn = next(self.iterator)
922
+ self.offset_id = 0
923
+ else:
924
+ offset_list = self.collection.mapping[self.current_fn]
925
+ if len(offset_list) == self.offset_id + 1:
926
+ self.current_bytes, self.current_fn = next(self.iterator)
927
+ self.offset_id = 0
928
+ else:
929
+ self.offset_id += 1
930
+
931
+ return self.current_bytes, self.current_fn, self.offset_id
932
+
933
+ return TarredAudioLoopOffsets(self.manifest_processor.collection)
934
+
935
+ def _collate_fn(self, batch):
936
+ return _speech_collate_fn(batch, self.pad_id)
937
+
938
+ def _build_sample(self, tup):
939
+ """Builds the training sample by combining the data from the WebDataset with the manifest info.
940
+ """
941
+ audio_bytes, audio_filename, offset_id = tup
942
+
943
+ # Grab manifest entry from self.manifest_preprocessor.collection
944
+ file_id, _ = os.path.splitext(os.path.basename(audio_filename))
945
+ manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id]
946
+ manifest_entry = self.manifest_processor.collection[manifest_idx]
947
+
948
+ offset = manifest_entry.offset
949
+ if offset is None:
950
+ offset = 0
951
+
952
+ # Convert audio bytes to IO stream for processing (for SoundFile to read)
953
+ audio_filestream = io.BytesIO(audio_bytes)
954
+ features = self.featurizer.process(
955
+ audio_filestream,
956
+ offset=offset,
957
+ duration=manifest_entry.duration,
958
+ trim=self.trim,
959
+ orig_sr=manifest_entry.orig_sr,
960
+ )
961
+ audio_filestream.close()
962
+
963
+ # Audio features
964
+ f, fl = features, torch.tensor(features.shape[0]).long()
965
+
966
+ # Text features
967
+ t, tl = manifest_entry.text_tokens, len(manifest_entry.text_tokens)
968
+
969
+ self.manifest_processor.process_text_by_sample(sample=manifest_entry)
970
+
971
+ if self.bos_id is not None:
972
+ t = [self.bos_id] + t
973
+ tl += 1
974
+ if self.eos_id is not None:
975
+ t = t + [self.eos_id]
976
+ tl += 1
977
+
978
+ if self.return_sample_id:
979
+ return f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), manifest_idx
980
+ else:
981
+ return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
982
+
983
+ def get_manifest_sample(self, sample_id):
984
+ return self.manifest_processor.collection[sample_id]
985
+
986
+ def __iter__(self):
987
+ return self._dataset.__iter__()
988
+
989
+ def _compute_len(self):
990
+ if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized():
991
+ my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda()
992
+ torch.distributed.all_reduce(my_len)
993
+ my_len = my_len.int()
994
+ logging.info(f'Sharded manifests: Total length: {my_len}')
995
+ else:
996
+ my_len = len(self.manifest_processor.collection)
997
+
998
+ return my_len
999
+
1000
+ def __len__(self):
1001
+ return self.len
1002
+
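Editor's note: the SLURM-friendly brace tags described in the class docstring ("(", "[", "<", "_OP_" for "{" and ")", "]", ">", "_CL_" for "}") can be normalized back to plain brace expansion with simple string replacement. A sketch of that mapping; the real handling lives in expand_sharded_filepaths:

def normalize_brace_tags(path: str) -> str:
    # Map each supported opening/closing token back to a standard brace.
    for opening in ("(", "[", "<", "_OP_"):
        path = path.replace(opening, "{")
    for closing in (")", "]", ">", "_CL_"):
        path = path.replace(closing, "}")
    return path

print(normalize_brace_tags("audio__OP_1..100_CL_.tar"))  # audio_{1..100}.tar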
1003
+
1004
+ class TarredAudioToCharDataset(_TarredAudioToTextDataset):
1005
+ """
1006
+ A similar Dataset to the AudioToCharDataset, but which loads tarred audio files.
1007
+
1008
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset),
1009
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
1010
+ contain the information for one audio file, including at least the transcript and name of the audio
1011
+ file within the tarball.
1012
+
1013
+ Valid formats for the audio_tar_filepaths argument include:
1014
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
1015
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
1016
+
1017
+ See the WebDataset documentation for more information about accepted data and input formats.
1018
+
1019
+ If using multiple workers the number of shards should be divisible by world_size to ensure an
1020
+ even split among workers. If it is not divisible, logging will give a warning but training will proceed.
1021
+ In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
1022
+ is applied. We currently do not check for this, but your program may hang if the shards are uneven!
1023
+
1024
+ Notice that a few arguments are different from the AudioToCharDataset; for example, shuffle (bool) has been
1025
+ replaced by shuffle_n (int).
1026
+
1027
+ Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
1028
+ after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
1029
+
1030
+ Args:
1031
+ audio_tar_filepaths: Either a list of audio tarball filepaths, or a
1032
+ string (can be brace-expandable).
1033
+ manifest_filepath (str): Path to the manifest.
1034
+ labels (list): List of characters that can be output by the ASR model.
1035
+ For Jasper, this is the 28 character set {a-z '}. The CTC blank
1036
+ symbol is automatically added later for models using ctc.
1037
+ sample_rate (int): Sample rate to resample loaded audio to
1038
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
1039
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
1040
+ object used to augment loaded audio
1041
+ shuffle_n (int): How many samples to look ahead and load to be shuffled.
1042
+ See WebDataset documentation for more details.
1043
+ Defaults to 0.
1044
+ min_duration (float): Dataset parameter.
1045
+ All training files which have a duration less than min_duration
1046
+ are dropped. Note: Duration is read from the manifest JSON.
1047
+ Defaults to 0.1.
1048
+ max_duration (float): Dataset parameter.
1049
+ All training files which have a duration more than max_duration
1050
+ are dropped. Note: Duration is read from the manifest JSON.
1051
+ Defaults to None.
1052
+ blank_index (int): Blank character index, defaults to -1.
1053
+ unk_index (int): Unknown character index, defaults to -1.
1054
+ normalize (bool): Dataset parameter.
1055
+ Whether to use automatic text cleaning.
1056
+ It is highly recommended to manually clean text for best results.
1057
+ Defaults to True.
1058
+ trim (bool): Whether to trim silence from the beginning and end
1059
+ of audio signal using librosa.effects.trim().
1060
+ Defaults to False.
1061
+ bos_id (id): Dataset parameter.
1062
+ Beginning of string symbol id used for seq2seq models.
1063
+ Defaults to None.
1064
+ eos_id (id): Dataset parameter.
1065
+ End of string symbol id used for seq2seq models.
1066
+ Defaults to None.
1067
+ pad_id (id): Token used to pad when collating samples in batches.
1068
+ If this is None, pads using 0s.
1069
+ Defaults to None.
1070
+ shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
1071
+
1072
+ - `scatter`: The default shard strategy applied by WebDataset, where each node gets
1073
+ a unique set of shards, which are permanently pre-allocated and never changed at runtime.
1074
+ - `replicate`: Optional shard strategy, where each node gets all of the set of shards
1075
+ available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
1076
+ The benefit of replication is that it allows each node to sample data points from the entire
1077
+ dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
1078
+
1079
+ .. warning::
1080
+
1081
+ Replicated strategy allows every node to sample the entire set of available tarfiles,
1082
+ and therefore more than one node may sample the same tarfile, and even sample the same
1083
+ data points! As such, there is no assured guarantee that all samples in the dataset will be
1084
+ sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
1085
+ occasions (when the number of shards is not divisible with ``world_size``), will not sample
1086
+ the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
1087
+ or test datasets.
1088
+
1089
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
1090
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
1091
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
1092
+ """
1093
+
1094
+ def __init__(
1095
+ self,
1096
+ audio_tar_filepaths: Union[str, List[str]],
1097
+ manifest_filepath: str,
1098
+ labels: List[str],
1099
+ sample_rate: int,
1100
+ int_values: bool = False,
1101
+ augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
1102
+ shuffle_n: int = 0,
1103
+ min_duration: Optional[float] = None,
1104
+ max_duration: Optional[float] = None,
1105
+ blank_index: int = -1,
1106
+ unk_index: int = -1,
1107
+ normalize: bool = True,
1108
+ trim: bool = False,
1109
+ bos_id: Optional[int] = None,
1110
+ eos_id: Optional[int] = None,
1111
+ parser: Optional[str] = 'en',
1112
+ pad_id: int = 0,
1113
+ shard_strategy: str = "scatter",
1114
+ shard_manifests: bool = False,
1115
+ global_rank: int = 0,
1116
+ world_size: int = 0,
1117
+ return_sample_id: bool = False,
1118
+ ):
1119
+ self.labels = labels
1120
+
1121
+ parser = parsers.make_parser(
1122
+ labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
1123
+ )
1124
+
1125
+ super().__init__(
1126
+ audio_tar_filepaths=audio_tar_filepaths,
1127
+ manifest_filepath=manifest_filepath,
1128
+ parser=parser,
1129
+ sample_rate=sample_rate,
1130
+ int_values=int_values,
1131
+ augmentor=augmentor,
1132
+ shuffle_n=shuffle_n,
1133
+ min_duration=min_duration,
1134
+ max_duration=max_duration,
1135
+ trim=trim,
1136
+ bos_id=bos_id,
1137
+ eos_id=eos_id,
1138
+ pad_id=pad_id,
1139
+ shard_strategy=shard_strategy,
1140
+ shard_manifests=shard_manifests,
1141
+ global_rank=global_rank,
1142
+ world_size=world_size,
1143
+ return_sample_id=return_sample_id,
1144
+ )
1145
+
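Editor's note: a hedged usage sketch of the tarred character dataset with a brace-expanded shard pattern; the paths and label set are illustrative, and the rank/world-size values would normally come from the trainer:

labels = list(" abcdefghijklmnopqrstuvwxyz'")
tarred_dataset = TarredAudioToCharDataset(
    audio_tar_filepaths="audio_{0..127}.tar",   # brace-expandable shard pattern
    manifest_filepath="tarred_manifest.json",
    labels=labels,
    sample_rate=16000,
    shuffle_n=2048,
    shard_strategy="scatter",
    global_rank=0,   # normally trainer.global_rank
    world_size=1,    # normally trainer.world_size
)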
1146
+
1147
+ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
1148
+ """
1149
+ A similar Dataset to the AudioToBPEDataset, but which loads tarred audio files.
1150
+
1151
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToBPEDataset),
1152
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
1153
+ contain the information for one audio file, including at least the transcript and name of the audio
1154
+ file within the tarball.
1155
+
1156
+ Valid formats for the audio_tar_filepaths argument include:
1157
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
1158
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
1159
+
1160
+ See the WebDataset documentation for more information about accepted data and input formats.
1161
+
1162
+ If using multiple workers the number of shards should be divisible by world_size to ensure an
1163
+ even split among workers. If it is not divisible, logging will give a warning but training will proceed.
1164
+ In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
1165
+ is applied. We currently do not check for this, but your program may hang if the shards are uneven!
1166
+
1167
+ Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been
1168
+ replaced by shuffle_n (int).
1169
+
1170
+ Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
1171
+ after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.
1172
+
1173
+ Args:
1174
+ audio_tar_filepaths: Either a list of audio tarball filepaths, or a
1175
+ string (can be brace-expandable).
1176
+ manifest_filepath (str): Path to the manifest.
1177
+ tokenizer (TokenizerSpec): Either a Word Piece Encoding tokenizer (BERT),
1178
+ or a Sentence Piece Encoding tokenizer (BPE). The CTC blank
1179
+ symbol is automatically added later for models using ctc.
1180
+ sample_rate (int): Sample rate to resample loaded audio to
1181
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
1182
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
1183
+ object used to augment loaded audio
1184
+ shuffle_n (int): How many samples to look ahead and load to be shuffled.
1185
+ See WebDataset documentation for more details.
1186
+ Defaults to 0.
1187
+ min_duration (float): Dataset parameter.
1188
+ All training files which have a duration less than min_duration
1189
+ are dropped. Note: Duration is read from the manifest JSON.
1190
+ Defaults to 0.1.
1191
+ max_duration (float): Dataset parameter.
1192
+ All training files which have a duration more than max_duration
1193
+ are dropped. Note: Duration is read from the manifest JSON.
1194
+ Defaults to None.
1195
+ trim (bool): Whether to trim silence from the beginning and end
1196
+ of audio signal using librosa.effects.trim().
1197
+ Defaults to False.
1198
+ use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
1199
+ tokens to beginning and ending of speech respectively.
1200
+ pad_id (id): Token used to pad when collating samples in batches.
1201
+ If this is None, pads using 0s.
1202
+ Defaults to None.
1203
+ shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
1204
+
1205
+ - `scatter`: The default shard strategy applied by WebDataset, where each node gets
1206
+ a unique set of shards, which are permanently pre-allocated and never changed at runtime.
1207
+ - `replicate`: Optional shard strategy, where each node gets all of the set of shards
1208
+ available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
1209
+ The benefit of replication is that it allows each node to sample data points from the entire
1210
+ dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.
1211
+
1212
+ .. warning::
1213
+
1214
+ Replicated strategy allows every node to sample the entire set of available tarfiles,
1215
+ and therefore more than one node may sample the same tarfile, and even sample the same
1216
+ data points! As such, there is no assured guarantee that all samples in the dataset will be
1217
+ sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific
1218
+ occasions (when the number of shards is not divisible with ``world_size``), will not sample
1219
+ the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
1220
+ or test datasets.
1221
+
1222
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
1223
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
1224
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
1225
+ """
1226
+
1227
+ def __init__(
1228
+ self,
1229
+ audio_tar_filepaths: Union[str, List[str]],
1230
+ manifest_filepath: str,
1231
+ tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
1232
+ sample_rate: int,
1233
+ int_values: bool = False,
1234
+ augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
1235
+ shuffle_n: int = 0,
1236
+ min_duration: Optional[float] = None,
1237
+ max_duration: Optional[float] = None,
1238
+ trim: bool = False,
1239
+ use_start_end_token: bool = True,
1240
+ shard_strategy: str = "scatter",
1241
+ shard_manifests: bool = False,
1242
+ global_rank: int = 0,
1243
+ world_size: int = 0,
1244
+ return_sample_id: bool = False,
1245
+ ):
1246
+ if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
1247
+ bos_id = tokenizer.bos_id
1248
+ else:
1249
+ bos_id = None
1250
+
1251
+ if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0:
1252
+ eos_id = tokenizer.eos_id
1253
+ else:
1254
+ eos_id = None
1255
+
1256
+ if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0:
1257
+ pad_id = tokenizer.pad_id
1258
+ else:
1259
+ pad_id = 0
1260
+
1261
+ class TokenizerWrapper:
1262
+ def __init__(self, tokenizer):
1263
+ if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer):
1264
+ self.is_aggregate = True
1265
+ else:
1266
+ self.is_aggregate = False
1267
+ self._tokenizer = tokenizer
1268
+
1269
+ def __call__(self, *args):
1270
+ if isinstance(args[0], List) and self.is_aggregate:
1271
+ t = []
1272
+ for span in args[0]:
1273
+ t.extend(self._tokenizer.text_to_ids(span['str'], span['lang']))
1274
+ return t
1275
+
1276
+ t = self._tokenizer.text_to_ids(*args)
1277
+ return t
1278
+
1279
+ super().__init__(
1280
+ audio_tar_filepaths=audio_tar_filepaths,
1281
+ manifest_filepath=manifest_filepath,
1282
+ parser=TokenizerWrapper(tokenizer),
1283
+ sample_rate=sample_rate,
1284
+ int_values=int_values,
1285
+ augmentor=augmentor,
1286
+ shuffle_n=shuffle_n,
1287
+ min_duration=min_duration,
1288
+ max_duration=max_duration,
1289
+ trim=trim,
1290
+ bos_id=bos_id,
1291
+ eos_id=eos_id,
1292
+ pad_id=pad_id,
1293
+ shard_strategy=shard_strategy,
1294
+ shard_manifests=shard_manifests,
1295
+ global_rank=global_rank,
1296
+ world_size=world_size,
1297
+ return_sample_id=return_sample_id,
1298
+ )
1299
+
1300
+
1301
+ class BucketingDataset(IterableDataset):
1302
+ """
1303
+ A Dataset which wraps another IterableDataset and adapts it for bucketing.
1304
+ Args:
1305
+ dataset (IterableDataset): The IterableDataset to be wrapped
1306
+ bucketing_batch_size (int): Number of samples to build a batch
1307
+ """
1308
+
1309
+ def __init__(
1310
+ self, dataset: IterableDataset, bucketing_batch_size: int,
1311
+ ):
1312
+ self.wrapped_dataset = dataset
1313
+ self.bucketing_batch_size = bucketing_batch_size
1314
+ super().__init__()
1315
+
1316
+ def _collate_fn(self, batch):
1317
+ return _speech_collate_fn(batch[0], self.wrapped_dataset.pad_id)
1318
+
1319
+ def __iter__(self):
1320
+ return BucketingIterator(
1321
+ wrapped_ds=self.wrapped_dataset._dataset, bucketing_batch_size=self.bucketing_batch_size
1322
+ ).__iter__()
1323
+
1324
+ def __len__(self):
1325
+ return int(math.ceil(len(self.wrapped_dataset) / float(self.bucketing_batch_size)))
1326
+
1327
+
1328
+ class BucketingIterator:
1329
+ def __init__(self, wrapped_ds, bucketing_batch_size):
1330
+ self.wrapped_ds = wrapped_ds
1331
+ self.wrapped_iter = None
1332
+ self.bucketing_batch_size = bucketing_batch_size
1333
+
1334
+ def __iter__(self):
1335
+ self.wrapped_iter = iter(self.wrapped_ds)
1336
+ return self
1337
+
1338
+ def __next__(self):
1339
+ batches = []
1340
+ for idx in range(self.bucketing_batch_size):
1341
+ try:
1342
+ sample = next(self.wrapped_iter)
1343
+ except StopIteration:
1344
+ break
1345
+ batches.append(sample)
1346
+ if len(batches) == 0:
1347
+ raise StopIteration
1348
+ return batches
1349
+
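Editor's note: the bucketing iterator simply groups consecutive samples into lists of size bucketing_batch_size, with a smaller final group when the dataset runs out. The same grouping logic on a plain Python iterable, for illustration:

def group(iterable, n):
    it = iter(iterable)
    while True:
        batch = []
        for _ in range(n):
            try:
                batch.append(next(it))
            except StopIteration:
                break
        if not batch:
            return
        yield batch

print(list(group(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]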
1350
+
1351
+ class RandomizedChainDataset(ChainDataset):
1352
+ def __init__(self, datasets: Iterable[Dataset], rnd_seed=0) -> None:
1353
+ super(RandomizedChainDataset, self).__init__(list(datasets))
1354
+ self.rnd_gen = np.random.RandomState(rnd_seed)
1355
+
1356
+ def __iter__(self):
1357
+ shuffled_order = self.rnd_gen.permutation(len(self.datasets))
1358
+ for dataset_idx in shuffled_order:
1359
+ d = self.datasets[dataset_idx]
1360
+ assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
1361
+ for idx, x in enumerate(d):
1362
+ yield x
1363
+ # in case d is an infinite dataset, we want to break the loop
1364
+ # so that the other datasets get a chance to yield too
1365
+ if idx >= len(d) - 1:
1366
+ break
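Editor's note: RandomizedChainDataset visits the wrapped datasets in a seeded random order and yields each one to completion before moving on. A toy run with two small in-memory IterableDatasets; the _Range class below is illustrative:

from torch.utils.data import IterableDataset

class _Range(IterableDataset):
    def __init__(self, start, stop):
        self.start, self.stop = start, stop
    def __iter__(self):
        return iter(range(self.start, self.stop))
    def __len__(self):
        return self.stop - self.start

chained = RandomizedChainDataset([_Range(0, 3), _Range(10, 13)], rnd_seed=42)
print(list(chained))  # all six values, one dataset after the other, in a seeded random order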
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dali.py ADDED
@@ -0,0 +1,772 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import operator
17
+ import os.path
18
+ import time
19
+ from collections.abc import Iterator
20
+ from typing import Callable, List, Optional, Union
21
+
22
+ import torch
23
+ from omegaconf import DictConfig
24
+
25
+ from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor, expand_sharded_filepaths
26
+ from nemo.collections.common.parts.preprocessing import parsers
27
+ from nemo.utils import logging, model_utils
28
+
29
+ try:
30
+ import nvidia.dali as dali
31
+ from nvidia.dali.pipeline import Pipeline
32
+ from nvidia.dali.plugin.pytorch import DALIGenericIterator as DALIPytorchIterator
33
+ from nvidia.dali.plugin.pytorch import LastBatchPolicy as LastBatchPolicy
34
+
35
+ HAVE_DALI = True
36
+ except (ImportError, ModuleNotFoundError):
37
+ HAVE_DALI = False
38
+
39
+ __all__ = [
40
+ 'AudioToCharDALIDataset',
41
+ 'AudioToBPEDALIDataset',
42
+ ]
43
+
44
+ """
45
+ Below minimum version is required to access the "read_idxs" argument in
46
+ dali.fn.readers.nemo_asr
47
+ """
48
+ __DALI_MINIMUM_VERSION__ = "1.11"
49
+
50
+ DALI_INSTALLATION_MESSAGE = (
51
+ "Could not import `nvidia.dali`.\n"
52
+ "Please install DALI by following the steps provided here - \n"
53
+ "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
54
+ )
55
+
56
+
57
+ def is_dali_supported(min_version: str, verbose: bool = False) -> bool:
58
+ """
59
+ Checks if DALI in installed, and version is >= min_verion.
60
+
61
+ Args:
62
+ min_version: A semver str that is the minimum requirement.
63
+ verbose: Whether to log the installation instructions if DALI is not found.
64
+
65
+ Returns:
66
+ bool - whether DALI could be imported or not.
67
+ """
68
+ module_available, _ = model_utils.check_lib_version(
69
+ 'nvidia.dali', checked_version=min_version, operator=operator.ge
70
+ )
71
+
72
+ # If DALI is not installed
73
+ if module_available is None:
74
+ if verbose:
75
+ logging.info(DALI_INSTALLATION_MESSAGE)
76
+
77
+ return False
78
+
79
+ return module_available
80
+
81
+
82
+ class DALIOutputs(object):
83
+ def __init__(self, out_dict):
84
+ self._has_processed_signal = 'processed_signal' in out_dict and 'processed_signal_len' in out_dict
85
+ if not self._has_processed_signal:
86
+ assert 'audio' in out_dict and 'audio_len' in out_dict
87
+ assert 'transcript' in out_dict and 'transcript_len' in out_dict
88
+ if self._has_processed_signal:
89
+ self._outs = (
90
+ out_dict['processed_signal'],
91
+ out_dict['processed_signal_len'].reshape(-1),
92
+ out_dict['transcript'],
93
+ out_dict['transcript_len'].reshape(-1),
94
+ )
95
+ else:
96
+ self._outs = (
97
+ out_dict['audio'],
98
+ out_dict['audio_len'].reshape(-1),
99
+ out_dict['transcript'],
100
+ out_dict['transcript_len'].reshape(-1),
101
+ )
102
+
103
+ @property
104
+ def has_processed_signal(self):
105
+ return self._has_processed_signal
106
+
107
+ def __getitem__(self, key):
108
+ return self._outs[key]
109
+
110
+ def __len__(self):
111
+ return len(self._outs)
112
+
113
+
114
+ class _AudioTextDALIDataset(Iterator):
115
+ """
116
+ NVIDIA DALI pipeline that loads tensors via one or more manifest files where each line containing a sample descriptor in JSON,
117
+ including audio files, transcripts, and durations (in seconds).
118
+ Here's an example:
119
+ {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
120
+ ...
121
+ {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
122
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
123
+
124
+ Args:
125
+ manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths.
126
+ device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'.
127
+ batch_size (int): Number of samples in a batch.
128
+ parser (str, callable): A str for an inbuilt parser, or a callable with signature f(str) -> List[int].
129
+ sample_rate (int): Sample rate to resample loaded audio to.
130
+ num_threads (int): Number of CPU processing threads to be created by the DALI pipeline.
131
+ max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files.
132
+ min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files.
133
+ bos_id (int): Id of beginning of sequence symbol to append if not None
134
+ eos_id (int): Id of end of sequence symbol to append if not None
135
+ pad_id (int): Id used to pad the input. Defaults to 0 if not provided.
136
+ trim (bool): If True, it will extract the nonsilent region of the loaded audio signal.
137
+ shuffle (bool): If set to True, the dataset will shuffled after loading.
138
+ drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size.
139
+ If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
140
+ device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
141
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
142
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
143
+ preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
144
+ return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ manifest_filepath: str,
150
+ device: str,
151
+ batch_size: int,
152
+ parser: Union[str, Callable],
153
+ audio_tar_filepaths: Optional[Union[str, List[str]]] = None,
154
+ audio_tar_index_filepaths: Optional[Union[str, List[str]]] = None,
155
+ sample_rate: int = 16000,
156
+ num_threads: int = 4,
157
+ max_duration: float = 0.0,
158
+ min_duration: float = 0.0,
159
+ bos_id: Optional[int] = None,
160
+ eos_id: Optional[int] = None,
161
+ pad_id: int = 0,
162
+ trim: bool = False,
163
+ shuffle: bool = False,
164
+ drop_last: bool = False,
165
+ shard_strategy: str = "scatter",
166
+ device_id: int = 0,
167
+ global_rank: int = 0,
168
+ world_size: int = 1,
169
+ preprocessor_cfg: DictConfig = None,
170
+ return_sample_id: bool = False,
171
+ ):
172
+ self.drop_last = drop_last # used by lr_scheduler
173
+ if return_sample_id:
174
+ raise ValueError(
175
+ "Currently DALI data layers don't support returning the sample_id and return_sample_id can not be enabled."
176
+ )
177
+ self.return_sample_id = return_sample_id
178
+
179
+ if not HAVE_DALI:
180
+ raise ModuleNotFoundError(
181
+ f"{self} requires NVIDIA DALI to be installed. "
182
+ f"See: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html#id1"
183
+ )
184
+
185
+ if device not in ('cpu', 'gpu'):
186
+ raise ValueError(
187
+ f"{self} received an unexpected device argument {device}. Supported values are: 'cpu', 'gpu'"
188
+ )
189
+
190
+ device_id = device_id if device == 'gpu' else None
191
+
192
+ self.batch_size = batch_size # Used by NeMo
193
+
194
+ self.device = device
195
+ self.device_id = device_id
196
+
197
+ if world_size > 1:
198
+ self.shard_id = global_rank
199
+ self.num_shards = world_size
200
+ else:
201
+ self.shard_id = None
202
+ self.num_shards = None
203
+
204
+ self.eos_id = eos_id
205
+ self.bos_id = bos_id
206
+ self.sample_rate = sample_rate
207
+
208
+ self.pipe = Pipeline(
209
+ batch_size=batch_size,
210
+ num_threads=num_threads,
211
+ device_id=self.device_id,
212
+ exec_async=True,
213
+ exec_pipelined=True,
214
+ )
215
+
216
+ has_preprocessor = preprocessor_cfg is not None
217
+ if has_preprocessor:
218
+ if preprocessor_cfg._target_ == "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor":
219
+ feature_type = "mel_spectrogram"
220
+ elif preprocessor_cfg._target_ == "nemo.collections.asr.modules.AudioToMFCCPreprocessor":
221
+ feature_type = "mfcc"
222
+ else:
223
+ raise ValueError(
224
+ f"{self} received an unexpected preprocessor configuration: {preprocessor_cfg._target_}."
225
+ f" Supported preprocessors are: AudioToMelSpectrogramPreprocessor, AudioToMFCCPreprocessor"
226
+ )
227
+
228
+ # Default values taken from AudioToMelSpectrogramPreprocessor
229
+ params = preprocessor_cfg
230
+ self.dither = params['dither'] if 'dither' in params else 0.0
231
+ self.preemph = params['preemph'] if 'preemph' in params else 0.97
232
+ self.window_size_sec = params['window_size'] if 'window_size' in params else 0.02
233
+ self.window_stride_sec = params['window_stride'] if 'window_stride' in params else 0.01
234
+ self.sample_rate = params['sample_rate'] if 'sample_rate' in params else sample_rate
235
+ self.window_size = int(self.window_size_sec * self.sample_rate)
236
+ self.window_stride = int(self.window_stride_sec * self.sample_rate)
237
+
238
+ normalize = params['normalize'] if 'normalize' in params else 'per_feature'
239
+ if normalize == 'per_feature': # Each freq channel independently
240
+ self.normalization_axes = (1,)
241
+ elif normalize == 'all_features':
242
+ self.normalization_axes = (0, 1)
243
+ else:
244
+ raise ValueError(
245
+ f"{self} received {normalize} for the normalize parameter."
246
+ f" It must be either 'per_feature' or 'all_features'."
247
+ )
248
+
249
+ self.window = None
250
+ window_name = params['window'] if 'window' in params else 'hann'
251
+ torch_windows = {
252
+ 'hann': torch.hann_window,
253
+ 'hamming': torch.hamming_window,
254
+ 'blackman': torch.blackman_window,
255
+ 'bartlett': torch.bartlett_window,
256
+ 'none': None,
257
+ }
258
+
259
+ if window_name == 'ones':
260
+ window_tensor = torch.ones(self.window_size)
261
+ else:
262
+ try:
263
+ window_fn = torch_windows.get(window_name, None)
264
+ except:
265
+ raise ValueError(
266
+ f"{self} received '{window_name}' for the window parameter."
267
+ f" It must be one of: ('hann', 'ones', 'hamming', 'blackman', 'bartlett', None)."
268
+ f" None is equivalent to 'hann'."
269
+ )
270
+ window_tensor = window_fn(self.window_size, periodic=False) if window_fn else None
271
+ self.window = window_tensor.numpy().tolist() if window_tensor is not None else None
272
+
273
+ self.n_fft = params['n_fft'] if 'n_fft' in params else 2 ** math.ceil(math.log2(self.window_size))
274
+ self.n_mels = params['n_mels'] if 'n_mels' in params else 64
275
+ self.n_mfcc = params['n_mfcc'] if 'n_mfcc' in params else 64
276
+
277
+ features = params['features'] if 'features' in params else 0
278
+ if features > 0:
279
+ if feature_type == 'mel_spectrogram':
280
+ self.n_mels = features
281
+ elif feature_type == 'mfcc':
282
+ self.n_mfcc = features
283
+
284
+ # TODO Implement frame splicing
285
+ if 'frame_splicing' in params:
286
+ assert params['frame_splicing'] == 1, "Frame splicing is not implemented"
287
+
288
+ self.freq_low = params['lowfreq'] if 'lowfreq' in params else 0.0
289
+ self.freq_high = params['highfreq'] if 'highfreq' in params else self.sample_rate / 2.0
290
+ self.log_features = params['log'] if 'log' in params else True
291
+
292
+ # We want to avoid taking the log of zero
293
+ # There are two options: either adding or clamping to a small value
294
+
295
+ self.log_zero_guard_type = params['log_zero_guard_type'] if 'log_zero_guard_type' in params else 'add'
296
+ if self.log_zero_guard_type not in ["add", "clamp"]:
297
+ raise ValueError(
298
+ f"{self} received {self.log_zero_guard_type} for the "
299
+ f"log_zero_guard_type parameter. It must be either 'add' or "
300
+ f"'clamp'."
301
+ )
302
+
303
+ self.log_zero_guard_value = (
304
+ params['log_zero_guard_value'] if 'log_zero_guard_value' in params else 2 ** -24
305
+ )
306
+ if isinstance(self.log_zero_guard_value, str):
307
+ if self.log_zero_guard_value == "tiny":
308
+ self.log_zero_guard_value = torch.finfo(torch.float32).tiny
309
+ elif self.log_zero_guard_value == "eps":
310
+ self.log_zero_guard_value = torch.finfo(torch.float32).eps
311
+ else:
312
+ raise ValueError(
313
+ f"{self} received {self.log_zero_guard_value} for the log_zero_guard_type parameter."
314
+ f"It must be either a number, 'tiny', or 'eps'"
315
+ )
316
+
317
+ self.mag_power = params['mag_power'] if 'mag_power' in params else 2
318
+ if self.mag_power != 1.0 and self.mag_power != 2.0:
319
+ raise ValueError(
320
+ f"{self} received {self.mag_power} for the mag_power parameter." f" It must be either 1.0 or 2.0."
321
+ )
322
+
323
+ self.pad_to = max(params['pad_to'], 1) if 'pad_to' in params else 16
324
+ self.pad_value = params['pad_value'] if 'pad_value' in params else 0.0
325
+
326
+ with self.pipe:
327
+ if audio_tar_filepaths is None and audio_tar_index_filepaths is None:
328
+ audio, indices = dali.fn.readers.nemo_asr(
329
+ name="Reader",
330
+ manifest_filepaths=manifest_filepath.split(','),
331
+ dtype=dali.types.FLOAT,
332
+ downmix=True,
333
+ sample_rate=float(self.sample_rate),
334
+ min_duration=min_duration,
335
+ max_duration=max_duration,
336
+ read_sample_rate=False,
337
+ read_text=False,
338
+ read_idxs=True,
339
+ random_shuffle=shuffle,
340
+ shard_id=self.shard_id,
341
+ num_shards=self.num_shards,
342
+ pad_last_batch=True,
343
+ )
344
+
345
+ self.is_tarred_dataset = False
346
+
347
+ elif audio_tar_filepaths is not None and audio_tar_index_filepaths is not None:
348
+ audio_tar_filepaths = expand_sharded_filepaths(
349
+ audio_tar_filepaths, shard_strategy=shard_strategy, world_size=world_size, global_rank=global_rank
350
+ )
351
+ audio_tar_index_filepaths = expand_sharded_filepaths(
352
+ audio_tar_index_filepaths,
353
+ shard_strategy=shard_strategy,
354
+ world_size=world_size,
355
+ global_rank=global_rank,
356
+ )
357
+
358
+ if len(audio_tar_filepaths) != len(audio_tar_index_filepaths) and len(audio_tar_index_filepaths) != 0:
359
+ raise ValueError(
360
+ f"Number of filepaths provided for `audio_tar_filepaths` must match "
361
+ f"`audio_tar_index_filepaths`. Got {len(audio_tar_filepaths)} audio_tar_filepaths and "
362
+ f"{len(audio_tar_index_filepaths)} audio_tar_index_filepaths."
363
+ )
364
+
365
+ tar_file = dali.fn.readers.webdataset(
366
+ paths=audio_tar_filepaths,
367
+ index_paths=audio_tar_index_filepaths,
368
+ name="Reader",
369
+ ext=["wav"],
370
+ missing_component_behavior="error",
371
+ random_shuffle=shuffle,
372
+ shard_id=self.shard_id,
373
+ num_shards=self.num_shards,
374
+ pad_last_batch=True,
375
+ )
376
+ audio, _ = dali.fn.decoders.audio(
377
+ tar_file, dtype=dali.types.FLOAT, downmix=True, sample_rate=float(self.sample_rate),
378
+ )
379
+ indices = dali.fn.get_property(tar_file, key="source_info")
380
+ indices = dali.fn.pad(indices)
381
+
382
+ self.is_tarred_dataset = True
383
+
384
+ else:
385
+ raise RuntimeError(
386
+ "When using DALI datasets, either `audio_tar_filepaths` "
387
+ "and `audio_tar_index_filepaths` should either both be None (sequential dataset)"
388
+ "or provided (tarred dataset)."
389
+ )
390
+
391
+ # Extract nonsilent region, if necessary
392
+ if trim:
393
+ # Need to extract non-silent region before moving to the GPU
394
+ roi_start, roi_len = dali.fn.nonsilent_region(audio, cutoff_db=-60)
395
+ audio = audio.gpu() if self.device == 'gpu' else audio
396
+ audio = dali.fn.slice(
397
+ audio, roi_start, roi_len, normalized_anchor=False, normalized_shape=False, axes=[0]
398
+ )
399
+ else:
400
+ audio = audio.gpu() if self.device == 'gpu' else audio
401
+
402
+ if not has_preprocessor:
403
+ # No preprocessing, the output is the audio signal
404
+ audio_len = dali.fn.shapes(dali.fn.reshape(audio, shape=[-1]))
405
+ audio = dali.fn.pad(audio)
406
+ self.pipe.set_outputs(audio, audio_len, indices)
407
+ else:
408
+ # Additive gaussian noise (dither)
409
+ if self.dither > 0.0:
410
+ gaussian_noise = dali.fn.random.normal(audio)
411
+ audio = audio + self.dither * gaussian_noise
412
+
413
+ # Preemphasis filter
414
+ if self.preemph > 0.0:
415
+ audio = dali.fn.preemphasis_filter(audio, preemph_coeff=self.preemph, border='zero')
416
+
417
+ # Power spectrogram
418
+ spec = dali.fn.spectrogram(
419
+ audio,
420
+ nfft=self.n_fft,
421
+ window_length=self.window_size,
422
+ window_step=self.window_stride,
423
+ window_fn=self.window,
424
+ )
425
+
426
+ if feature_type == 'mel_spectrogram' or feature_type == 'mfcc':
427
+ # Spectrogram to Mel Spectrogram
428
+ spec = dali.fn.mel_filter_bank(
429
+ spec,
430
+ sample_rate=self.sample_rate,
431
+ nfilter=self.n_mels,
432
+ normalize=True,
433
+ freq_low=self.freq_low,
434
+ freq_high=self.freq_high,
435
+ )
436
+ # Mel Spectrogram to MFCC
437
+ if feature_type == 'mfcc':
438
+ spec = dali.fn.mfcc(spec, n_mfcc=self.n_mfcc)
439
+
440
+ # Logarithm
441
+ if self.log_zero_guard_type == 'add':
442
+ spec = spec + self.log_zero_guard_value
443
+
444
+ spec = dali.fn.to_decibels(
445
+ spec, multiplier=math.log(10), reference=1.0, cutoff_db=math.log(self.log_zero_guard_value)
446
+ )
447
+
448
+ # Normalization
449
+ spec = dali.fn.normalize(spec, axes=self.normalization_axes, epsilon=1e-5 ** 2, ddof=1)
450
+
451
+ # Extracting the length of the spectrogram
452
+ spec_len = dali.fn.slice(dali.fn.shapes(spec), 1, 1, axes=(0,))
453
+
454
+ # Pads feature dimension to be a multiple of `pad_to` and the temporal dimension to be as big as the largest sample (shape -1)
455
+ spec = dali.fn.pad(spec, fill_value=self.pad_value, axes=(0, 1), align=(self.pad_to, 1), shape=(1, -1))
456
+ self.pipe.set_outputs(spec, spec_len, indices)
457
+
458
+ x = time.time()
459
+ # Building DALI pipeline
460
+ self.pipe.build()
461
+ y = time.time()
462
+
463
+ logging.info(f"Time for pipe.build() : {(y - x)} seconds")
464
+
465
+ if has_preprocessor:
466
+ output_names = ['processed_signal', 'processed_signal_len', 'manifest_indices']
467
+ else:
468
+ output_names = ['audio', 'audio_len', 'manifest_indices']
469
+
470
+ x = time.time()
471
+ last_batch_policy = LastBatchPolicy.DROP if drop_last else LastBatchPolicy.PARTIAL
472
+ self._iter = DALIPytorchIterator(
473
+ [self.pipe],
474
+ output_map=output_names,
475
+ reader_name="Reader",
476
+ last_batch_policy=last_batch_policy,
477
+ dynamic_shape=True,
478
+ auto_reset=True,
479
+ )
480
+ y = time.time()
481
+ logging.info(f"Time for DALIPytorchIterator to initialize : {(y - x)} seconds")
482
+
483
+ # TODO come up with a better solution
484
+ class DummyDataset:
485
+ def __init__(self, parent):
486
+ self.parent = parent
487
+
488
+ def __len__(self):
489
+ return self.parent.size
490
+
491
+ self.dataset = DummyDataset(self) # Used by NeMo
492
+
493
+ x = time.time()
494
+ self.manifest_processor = ASRManifestProcessor(
495
+ manifest_filepath=manifest_filepath,
496
+ parser=parser,
497
+ max_duration=max_duration,
498
+ min_duration=min_duration,
499
+ max_utts=0,
500
+ bos_id=bos_id,
501
+ eos_id=eos_id,
502
+ pad_id=pad_id,
503
+ index_by_file_id=self.is_tarred_dataset,
504
+ )
505
+ y = time.time()
506
+ logging.info(f"Time to build nemo manifest processor - {(y - x)} seconds")
507
+
508
+ def reset(self):
509
+ self._iter.reset()
510
+
511
+ def __iter__(self):
512
+ return self
513
+
514
+ def next(self):
515
+ return self.__next__()
516
+
517
+ @property
518
+ def size(self):
519
+ return self._iter.size
520
+
521
+ def __len__(self):
522
+ return len(self._iter)
523
+
524
+ def __next__(self):
525
+ outputs = self._iter.next()
526
+ assert len(outputs) == 1
527
+ dali_out = outputs[0]
528
+ manifest_indices = dali_out['manifest_indices'].numpy()
529
+
530
+ out = {}
531
+ out_names = ['processed_signal', 'processed_signal_len', 'audio', 'audio_len']
532
+ for out_name in out_names:
533
+ if out_name in dali_out:
534
+ out[out_name] = dali_out[out_name].detach().clone()
535
+
536
+ text_tokens = []
537
+ text_tokens_len = []
538
+ max_len = 0
539
+ batch_size = manifest_indices.shape[0]
540
+ for i, manifest_index in enumerate(manifest_indices):
541
+
542
+ if not self.is_tarred_dataset:
543
+ # Loose-file dataset. Index is integer based.
544
+ manifest_index = manifest_index[0]
545
+ text, text_length = self.manifest_processor.process_text_by_id(manifest_index)
546
+ else:
547
+ # Tarred-file dataset. Index is filename based.
548
+ resolved_manifest_indices = manifest_index.tobytes().decode().split(":")
549
+ resolved_manifest_index = resolved_manifest_indices[2] # we require just the filename segment
550
+ resolved_manifest_index = os.path.splitext(resolved_manifest_index)[0]  # we don't need the file extension
551
+ text, text_length = self.manifest_processor.process_text_by_file_id(resolved_manifest_index)
552
+
553
+ text_tokens_len.append(text_length)
554
+ text_tokens.append(text)
555
+ if text_length > max_len:
556
+ max_len = text_length
557
+
558
+ transcript_out = torch.full([batch_size, max_len], fill_value=self.manifest_processor.pad_id, dtype=torch.long)
559
+ for i, n in enumerate(text_tokens_len):
560
+ transcript_out[i, :n] = torch.tensor(text_tokens[i], dtype=torch.long)
561
+ transcript_len_out = torch.tensor(text_tokens_len, dtype=torch.long)
562
+
563
+ out['transcript'] = transcript_out
564
+ out['transcript_len'] = transcript_len_out
565
+ return DALIOutputs(out)
566
+
567
+
568
+ class AudioToCharDALIDataset(_AudioTextDALIDataset):
569
+ """
570
+ Character-based NVIDIA DALI pipeline that loads tensors via one or more manifest files, where each line contains a
571
+ sample descriptor in JSON, including audio files, transcripts, and durations (in seconds).
572
+ Here's an example:
573
+ {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
574
+ ...
575
+ {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
576
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
577
+
578
+ Args:
579
+ manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths.
580
+ device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'.
581
+ batch_size (int): Number of samples in a batch.
582
+ labels (List[str]): String containing all the possible characters to map to.
583
+ sample_rate (int): Sample rate to resample loaded audio to.
584
+ num_threads (int): Number of CPU processing threads to be created by the DALI pipeline.
585
+ max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files.
586
+ min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files.
587
+ blank_index (int): blank character index, default = -1
588
+ unk_index (int): unk_character index, default = -1
589
+ normalize (bool): whether to normalize transcript text. Defaults to True.
590
+ bos_id (int): Id of beginning of sequence symbol to append if not None
591
+ eos_id (int): Id of end of sequence symbol to append if not None
592
+ pad_id (int): Id used to pad the input. Defaults to 0 if not provided.
593
+ trim (bool): If True, it will extract the nonsilent region of the loaded audio signal.
594
+ shuffle (bool): If set to True, the dataset will be shuffled after loading.
595
+ drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size.
596
+ If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
597
+ parser (str, callable): A str for an inbuilt parser, or a callable with signature f(str) -> List[int].
598
+ device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
599
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
600
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
601
+ preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
602
+ return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
603
+ """
604
+
605
+ def __init__(
606
+ self,
607
+ manifest_filepath: str,
608
+ device: str,
609
+ batch_size: int,
610
+ labels: Union[str, List[str]],
611
+ sample_rate: int = 16000,
612
+ audio_tar_filepaths: Optional[Union[str, List[str]]] = None,
613
+ audio_tar_index_filepaths: Optional[Union[str, List[str]]] = None,
614
+ num_threads: int = 4,
615
+ max_duration: float = 0.0,
616
+ min_duration: float = 0.0,
617
+ blank_index: int = -1,
618
+ unk_index: int = -1,
619
+ normalize: bool = True,
620
+ bos_id: Optional[int] = None,
621
+ eos_id: Optional[int] = None,
622
+ pad_id: int = 0,
623
+ trim: bool = False,
624
+ shuffle: bool = False,
625
+ drop_last: bool = False,
626
+ parser: Union[str, Callable] = 'en',
627
+ shard_strategy: str = "scatter",
628
+ device_id: int = 0,
629
+ global_rank: int = 0,
630
+ world_size: int = 1,
631
+ preprocessor_cfg: DictConfig = None,
632
+ return_sample_id: bool = False,
633
+ ):
634
+ self.labels = labels
635
+
636
+ parser = parsers.make_parser(
637
+ labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
638
+ )
639
+
640
+ super().__init__(
641
+ manifest_filepath=manifest_filepath,
642
+ device=device,
643
+ batch_size=batch_size,
644
+ audio_tar_filepaths=audio_tar_filepaths,
645
+ audio_tar_index_filepaths=audio_tar_index_filepaths,
646
+ sample_rate=sample_rate,
647
+ num_threads=num_threads,
648
+ max_duration=max_duration,
649
+ min_duration=min_duration,
650
+ bos_id=bos_id,
651
+ eos_id=eos_id,
652
+ pad_id=pad_id,
653
+ trim=trim,
654
+ shuffle=shuffle,
655
+ drop_last=drop_last,
656
+ parser=parser,
657
+ shard_strategy=shard_strategy,
658
+ device_id=device_id,
659
+ global_rank=global_rank,
660
+ world_size=world_size,
661
+ preprocessor_cfg=preprocessor_cfg,
662
+ return_sample_id=return_sample_id,
663
+ )
664
+
665
+
666
+ class AudioToBPEDALIDataset(_AudioTextDALIDataset):
667
+ """
668
+ Subword-based NVIDIA DALI pipeline that loads tensors via one or more manifest files, where each line contains a
669
+ sample descriptor in JSON, including audio files, transcripts, and durations (in seconds).
670
+ Here's an example:
671
+ {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
672
+ ...
673
+ {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
674
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
675
+
676
+ Args:
677
+ manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths.
678
+ tokenizer (TokenizerSpec): A TokenizerSpec implementation that wraps a tokenization implementation.
679
+ device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'.
680
+ batch_size (int): Number of samples in a batch.
681
+ sample_rate (int): Sample rate to resample loaded audio to.
682
+ num_threads (int): Number of CPU processing threads to be created by the DALI pipeline.
683
+ max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files.
684
+ min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files.
685
+ bos_id (int): Id of beginning of sequence symbol to append if not None. Injected from the tokenizer.
686
+ eos_id (int): Id of end of sequence symbol to append if not None. Injected from the tokenizer.
687
+ pad_id (int): Id used to pad the input. Defaults to 0 if not provided. Injected from the tokenizer.
688
+ trim (bool): If True, it will extract the nonsilent region of the loaded audio signal.
689
+ shuffle (bool): If set to True, the dataset will be shuffled after loading.
690
+ drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size.
691
+ If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
692
+
693
+ device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
694
+ global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
695
+ world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
696
+ preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
697
+ use_start_end_token (bool): Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and
698
+ ending of speech respectively.
699
+ return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
700
+ """
701
+
702
+ def __init__(
703
+ self,
704
+ manifest_filepath: str,
705
+ tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
706
+ device: str,
707
+ batch_size: int,
708
+ sample_rate: int = 16000,
709
+ audio_tar_filepaths: Optional[Union[str, List[str]]] = None,
710
+ audio_tar_index_filepaths: Optional[Union[str, List[str]]] = None,
711
+ num_threads: int = 4,
712
+ max_duration: float = 0.0,
713
+ min_duration: float = 0.0,
714
+ trim: bool = False,
715
+ shuffle: bool = False,
716
+ drop_last: bool = False,
717
+ shard_strategy: str = "scatter",
718
+ device_id: int = 0,
719
+ global_rank: int = 0,
720
+ world_size: int = 1,
721
+ preprocessor_cfg: DictConfig = None,
722
+ use_start_end_token: bool = True,
723
+ return_sample_id: bool = False,
724
+ ):
725
+
726
+ if use_start_end_token and hasattr(tokenizer, 'bos_token'):
727
+ bos_id = tokenizer.bos_id
728
+ else:
729
+ bos_id = None
730
+
731
+ if use_start_end_token and hasattr(tokenizer, 'eos_token'):
732
+ eos_id = tokenizer.eos_id
733
+ else:
734
+ eos_id = None
735
+
736
+ if hasattr(tokenizer, 'pad_token'):
737
+ pad_id = tokenizer.pad_id
738
+ else:
739
+ pad_id = 0
740
+
741
+ class TokenizerWrapper:
742
+ def __init__(self, tokenizer):
743
+ self._tokenizer = tokenizer
744
+
745
+ def __call__(self, text):
746
+ t = self._tokenizer.text_to_ids(text)
747
+ return t
748
+
749
+ super().__init__(
750
+ manifest_filepath=manifest_filepath,
751
+ device=device,
752
+ batch_size=batch_size,
753
+ sample_rate=sample_rate,
754
+ audio_tar_filepaths=audio_tar_filepaths,
755
+ audio_tar_index_filepaths=audio_tar_index_filepaths,
756
+ num_threads=num_threads,
757
+ max_duration=max_duration,
758
+ min_duration=min_duration,
759
+ bos_id=bos_id,
760
+ eos_id=eos_id,
761
+ pad_id=pad_id,
762
+ trim=trim,
763
+ shuffle=shuffle,
764
+ drop_last=drop_last,
765
+ parser=TokenizerWrapper(tokenizer),
766
+ shard_strategy=shard_strategy,
767
+ device_id=device_id,
768
+ global_rank=global_rank,
769
+ world_size=world_size,
770
+ preprocessor_cfg=preprocessor_cfg,
771
+ return_sample_id=return_sample_id,
772
+ )
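Editor's note: a minimal usage sketch for the DALI datasets above. This is an illustration, not part of the uploaded file; the manifest path and label set are hypothetical.

# Editor's sketch: construct the character-level DALI dataset and iterate over batches.
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset

dataset = AudioToCharDALIDataset(
    manifest_filepath="train_manifest.json",  # hypothetical manifest
    device="gpu",                             # or "cpu"
    batch_size=16,
    labels=[" ", "a", "b", "c"],              # truncated label set, for illustration only
    sample_rate=16000,
    shuffle=True,
)
for batch in dataset:  # each item is a DALIOutputs batch with transcripts attached
    pass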
SoundScribe/SpeakerID/nemo/collections/asr/data/audio_to_text_dataset.py ADDED
@@ -0,0 +1,950 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import json
17
+ import random
18
+ from math import isclose
19
+ from typing import Any, List, Optional, Union
20
+
21
+ import torch
22
+ from omegaconf import DictConfig, OmegaConf, open_dict
23
+ from omegaconf.listconfig import ListConfig
24
+ from pytorch_lightning.callbacks import BasePredictionWriter
25
+ from torch.utils.data import ChainDataset
26
+
27
+ from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
28
+ from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
29
+ from nemo.collections.common.data.dataset import CodeSwitchedDataset, ConcatDataset
30
+ from nemo.utils import logging
31
+
32
+
33
+ def inject_dataloader_value_from_model_config(model_cfg: dict, dataloader_cfg: DictConfig, key: str):
34
+ """
35
+ Extracts the label set provided at the top level of the model, and propagates it to the dataloader
36
+ config.
37
+
38
+ Args:
39
+ model_cfg: A DictConfig representing the model's config.
40
+ dataloader_cfg: A DictConfig representing the individual data loader
41
+ key: A str value representing a key in the model_cfg whose value will be propagated to the
42
+ dataloader config.
43
+ """
44
+ if key not in model_cfg:
45
+ logging.info(
46
+ f"Model level config does not contain `{key}`, please explicitly provide `{key}` to the dataloaders."
47
+ )
48
+ return
49
+
50
+ if not isinstance(dataloader_cfg, DictConfig):
51
+ dataloader_cfg = DictConfig(dataloader_cfg)
52
+
53
+ # If key exists in the data loader config (either set explicitly or as a placeholder (via None))
54
+ if key in dataloader_cfg:
55
+ # Dataloader `labels` is provided and is non-null
56
+ if dataloader_cfg[key] is not None and model_cfg[key] != dataloader_cfg[key]:
57
+ # Model level `labels` don't match Dataloader level `labels`
58
+ logging.warning(
59
+ f'`{key}` is explicitly provided to the data loader, and is different from '
60
+ f'the `{key}` provided at the model level config.\n'
61
+ f'If this is incorrect, please set the dataloader\'s `{key}` to None.'
62
+ )
63
+
64
+ else:
65
+ # Dataloader `key` is None or values match
66
+ # Propagate from model level `key` (even if they match)
67
+ with open_dict(dataloader_cfg):
68
+ dataloader_cfg[key] = model_cfg[key]
69
+
70
+ else:
71
+ # If the key doesn't even exist in dataloader_cfg, inject it explicitly
72
+ with open_dict(dataloader_cfg):
73
+ dataloader_cfg[key] = model_cfg[key]
74
+
75
+
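Editor's note: a small, hypothetical illustration of the helper above. A dataloader config whose `labels` is left as None picks it up from the model-level config; the values below are made up.

# Editor's sketch, not from the commit.
from omegaconf import DictConfig

from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config

model_cfg = {"labels": [" ", "a", "b"], "sample_rate": 16000}
dl_cfg = DictConfig({"manifest_filepath": "train.json", "labels": None})

inject_dataloader_value_from_model_config(model_cfg, dl_cfg, key="labels")
# dl_cfg.labels is now [" ", "a", "b"], copied from the model-level config.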
76
+ def get_concat_char_dataset(
77
+ config: dict, global_rank: int, world_size: int, augmentor: Optional['AudioAugmentor'] = None
78
+ ) -> ConcatDataset:
79
+ """
80
+ Instantiates an instance of ConcatDataset containing one or more instances of
81
+ Character Encoding based AudioToCharDataset.
82
+
83
+ Args:
84
+ config: Config of the AudioToCharDataset.
85
+ global_rank: Global rank of this device.
86
+ world_size: Global world size in the training method.
87
+ augmentor: Optional AudioAugmentor object for augmentations on audio data.
88
+
89
+ Returns:
90
+ An instance of ConcatDataset containing one or more instances of AudioToCharDataset.
91
+ """
92
+ if 'labels' not in config:
93
+ logging.warning(f"dataset does not have explicitly defined labels")
94
+
95
+ manifest_filepaths = config['manifest_filepath']
96
+ datasets = []
97
+
98
+ # needed to support validation Concat Datasets that arrive here as
99
+ # [[dataset1,dataset2]] otherwise ModelPT would interfere
100
+ if len(manifest_filepaths) == 1 and not isinstance(manifest_filepaths[0], str):
101
+ logging.info(f"removing an extra nesting level from {manifest_filepaths}")
102
+ manifest_filepaths = config['manifest_filepath'][0]
103
+
104
+ for manifest_filepath in manifest_filepaths:
105
+ conf = copy.deepcopy(config)
106
+ conf['manifest_filepath'] = manifest_filepath
107
+
108
+ dataset = get_char_dataset(config=conf, augmentor=augmentor)
109
+ datasets.append(dataset)
110
+
111
+ dataset = ConcatDataset(
112
+ datasets,
113
+ sampling_technique=config.get('concat_sampling_technique', 'temperature'),
114
+ sampling_temperature=config.get('concat_sampling_temperature', 5),
115
+ sampling_scale=config.get('concat_sampling_scale', 1),
116
+ sampling_probabilities=config.get('concat_sampling_probabilities', None),
117
+ shuffle=config.get('concat_shuffle', True),
118
+ seed=config.get('concat_sampling_seed', None),
119
+ global_rank=global_rank,
120
+ world_size=world_size,
121
+ )
122
+ return dataset
123
+
124
+
125
+ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None) -> audio_to_text.AudioToCharDataset:
126
+ """
127
+ Instantiates a Character Encoding based AudioToCharDataset.
128
+
129
+ Args:
130
+ config: Config of the AudioToCharDataset.
131
+ augmentor: Optional AudioAugmentor object for augmentations on audio data.
132
+
133
+ Returns:
134
+ An instance of AudioToCharDataset.
135
+ """
136
+ if 'labels' not in config:
137
+ logging.warning(f"dataset does not have explicitly defined labels")
138
+
139
+ dataset = audio_to_text.AudioToCharDataset(
140
+ manifest_filepath=config['manifest_filepath'],
141
+ labels=config.get('labels', None),
142
+ sample_rate=config['sample_rate'],
143
+ int_values=config.get('int_values', False),
144
+ augmentor=augmentor,
145
+ max_duration=config.get('max_duration', None),
146
+ min_duration=config.get('min_duration', None),
147
+ max_utts=config.get('max_utts', 0),
148
+ blank_index=config.get('blank_index', -1),
149
+ unk_index=config.get('unk_index', -1),
150
+ normalize=config.get('normalize_transcripts', False),
151
+ trim=config.get('trim_silence', False),
152
+ parser=config.get('parser', 'en'),
153
+ return_sample_id=config.get('return_sample_id', False),
154
+ channel_selector=config.get('channel_selector', None),
155
+ )
156
+ return dataset
157
+
158
+
159
+ def get_concat_bpe_dataset(
160
+ config: dict,
161
+ tokenizer: 'TokenizerSpec',
162
+ global_rank: int,
163
+ world_size: int,
164
+ augmentor: Optional['AudioAugmentor'] = None,
165
+ ) -> ConcatDataset:
166
+ """
167
+ Instantiates a ConcatDataset based on several Byte Pair Encoding / Word Piece Encoding based AudioToBPEDatasets.
168
+
169
+ Args:
170
+ config: Config of the AudioToBPEDataset.
171
+ tokenizer: An instance of a TokenizerSpec object.
172
+ global_rank: Global rank of this device.
173
+ world_size: Global world size in the training method.
174
+ augmentor: Optional AudioAugmentor object for augmentations on audio data.
175
+
176
+ Returns:
177
+ An instance of ConcatDataset containing several instances of AudioToBPEDataset.
178
+ """
179
+ manifest_filepaths = config['manifest_filepath']
180
+ datasets = []
181
+
182
+ # needed to support validation Concat Datasets that arrive here as
183
+ # [[dataset1,dataset2]] otherwise ModelPT would interfere
184
+ if len(manifest_filepaths) == 1 and not isinstance(manifest_filepaths[0], str):
185
+ logging.info(f"removing an extra nesting level from {manifest_filepaths}")
186
+ manifest_filepaths = config['manifest_filepath'][0]
187
+
188
+ for manifest_filepath in manifest_filepaths:
189
+ conf = copy.deepcopy(config)
190
+ conf['manifest_filepath'] = manifest_filepath
191
+ dataset = get_bpe_dataset(config=conf, tokenizer=tokenizer, augmentor=augmentor)
192
+ datasets.append(dataset)
193
+
194
+ dataset = ConcatDataset(
195
+ datasets,
196
+ sampling_technique=config.get('concat_sampling_technique', 'temperature'),
197
+ sampling_temperature=config.get('concat_sampling_temperature', 5),
198
+ sampling_scale=config.get('concat_sampling_scale', 1),
199
+ sampling_probabilities=config.get('concat_sampling_probabilities', None),
200
+ shuffle=config.get('concat_shuffle', True),
201
+ seed=config.get('concat_sampling_seed', None),
202
+ global_rank=global_rank,
203
+ world_size=world_size,
204
+ )
205
+ return dataset
206
+
207
+
208
+ def get_bpe_dataset(
209
+ config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['AudioAugmentor'] = None
210
+ ) -> audio_to_text.AudioToBPEDataset:
211
+ """
212
+ Instantiates a Byte Pair Encoding / Word Piece Encoding based AudioToBPEDataset.
213
+
214
+ Args:
215
+ config: Config of the AudioToBPEDataset.
216
+ tokenizer: An instance of a TokenizerSpec object.
217
+ augmentor: Optional AudioAugmentor object for augmentations on audio data.
218
+
219
+ Returns:
220
+ An instance of AudioToBPEDataset.
221
+ """
222
+ dataset = audio_to_text.AudioToBPEDataset(
223
+ manifest_filepath=config['manifest_filepath'],
224
+ tokenizer=tokenizer,
225
+ sample_rate=config['sample_rate'],
226
+ int_values=config.get('int_values', False),
227
+ augmentor=augmentor,
228
+ max_duration=config.get('max_duration', None),
229
+ min_duration=config.get('min_duration', None),
230
+ max_utts=config.get('max_utts', 0),
231
+ trim=config.get('trim_silence', False),
232
+ use_start_end_token=config.get('use_start_end_token', True),
233
+ return_sample_id=config.get('return_sample_id', False),
234
+ channel_selector=config.get('channel_selector', None),
235
+ )
236
+ return dataset
237
+
238
+
239
+ def get_concat_tarred_dataset(
240
+ config: dict,
241
+ shuffle_n: int,
242
+ global_rank: int,
243
+ world_size: int,
244
+ tokenizer: Optional['TokenizerSpec'] = None,
245
+ augmentor: Optional['AudioAugmentor'] = None,
246
+ ) -> ConcatDataset:
247
+ """
248
+ Instantiates a ConcatDataset containing multiple Word Piece/BPE Encoding based TarredAudioToBPEDatasets or char based TarredAudioToCharDatasets.
249
+
250
+ Args:
251
+ config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset.
252
+ shuffle_n: How many samples to look ahead and load to be shuffled.
253
+ See WebDataset documentation for more details.
254
+ tokenizer: An instance of a TokenizerSpec object if a BPE dataset is needed.
255
+ Passing None returns a char-based dataset.
256
+ global_rank: Global rank of this device.
257
+ world_size: Global world size in the training method.
258
+ augmentor: Optional AudioAugmentor object for augmentations on audio data.
259
+
260
+ Returns:
261
+ An instance of ConcatDataset containing one or more TarredAudioToBPEDatasets or TarredAudioToCharDatasets.
262
+ """
263
+
264
+ tarred_audio_filepaths = config['tarred_audio_filepaths']
265
+ manifest_filepaths = config['manifest_filepath']
266
+ datasets = []
267
+ for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
268
+ zip(tarred_audio_filepaths, manifest_filepaths)
269
+ ):
270
+ conf = copy.deepcopy(config)
271
+ conf['manifest_filepath'] = manifest_filepath
272
+ conf['tarred_audio_filepaths'] = tarred_audio_filepath
273
+ dataset = get_tarred_dataset(
274
+ config=conf,
275
+ tokenizer=tokenizer,
276
+ shuffle_n=shuffle_n,
277
+ global_rank=global_rank,
278
+ world_size=world_size,
279
+ augmentor=augmentor,
280
+ )
281
+ datasets.append(dataset)
282
+
283
+ dataset = ConcatDataset(
284
+ datasets,
285
+ sampling_technique=config.get('concat_sampling_technique', 'temperature'),
286
+ sampling_temperature=config.get('concat_sampling_temperature', 5),
287
+ sampling_scale=config.get('concat_sampling_scale', 1),
288
+ sampling_probabilities=config.get('concat_sampling_probabilities', None),
289
+ shuffle=config.get('concat_shuffle', True),
290
+ seed=config.get('concat_sampling_seed', None),
291
+ global_rank=global_rank,
292
+ world_size=world_size,
293
+ )
294
+ return dataset
295
+
296
+
297
+ def get_tarred_dataset(
298
+ config: dict,
299
+ shuffle_n: int,
300
+ global_rank: int,
301
+ world_size: int,
302
+ tokenizer: Optional['TokenizerSpec'] = None,
303
+ augmentor: Optional['AudioAugmentor'] = None,
304
+ ) -> Union[audio_to_text.TarredAudioToBPEDataset, audio_to_text.TarredAudioToCharDataset]:
305
+ """
306
+ Instantiates a Word Piece/BPE Encoding based TarredAudioToBPEDataset or a char based TarredAudioToCharDataset.
307
+
308
+ Args:
309
+ config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset.
310
+ shuffle_n: How many samples to look ahead and load to be shuffled.
311
+ See WebDataset documentation for more details.
312
+ tokenizer: An instance of a TokenizerSpec object if a BPE dataset is needed.
313
+ Passing None returns a char-based dataset.
314
+ global_rank: Global rank of this device.
315
+ world_size: Global world size in the training method.
316
+ augmentor: Optional AudioAugmentor object for augmentations on audio data.
317
+
318
+ Returns:
319
+ An instance of TarredAudioToBPEDataset or TarredAudioToCharDataset.
320
+ """
321
+ tarred_audio_filepaths = config['tarred_audio_filepaths']
322
+ manifest_filepaths = config['manifest_filepath']
323
+ datasets = []
324
+ tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
325
+ manifest_filepaths = convert_to_config_list(manifest_filepaths)
326
+
327
+ bucketing_weights = config.get('bucketing_weights', None) # For upsampling buckets
328
+ if bucketing_weights:
329
+ for idx, weight in enumerate(bucketing_weights):
330
+ if not isinstance(weight, int) or weight <= 0:
331
+ raise ValueError(f"bucket weights must be positive integers")
332
+
333
+ if len(manifest_filepaths) != len(tarred_audio_filepaths):
334
+ raise ValueError(
335
+ f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
336
+ )
337
+
338
+ if 'labels' not in config:
339
+ logging.warning(f"dataset does not have explicitly defined labels")
340
+
341
+ if 'max_utts' in config:
342
+ raise ValueError('"max_utts" parameter is not supported for tarred datasets')
343
+
344
+ for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
345
+ zip(tarred_audio_filepaths, manifest_filepaths)
346
+ ):
347
+ if len(tarred_audio_filepath) == 1:
348
+ tarred_audio_filepath = tarred_audio_filepath[0]
349
+ if len(manifest_filepath) == 1:
350
+ manifest_filepath = manifest_filepath[0]
351
+
352
+ if tokenizer is None:
353
+ dataset = audio_to_text.TarredAudioToCharDataset(
354
+ audio_tar_filepaths=tarred_audio_filepath,
355
+ manifest_filepath=manifest_filepath,
356
+ labels=config.get('labels', None),
357
+ sample_rate=config['sample_rate'],
358
+ int_values=config.get('int_values', False),
359
+ augmentor=augmentor,
360
+ shuffle_n=shuffle_n,
361
+ max_duration=config.get('max_duration', None),
362
+ min_duration=config.get('min_duration', None),
363
+ blank_index=config.get('blank_index', -1),
364
+ unk_index=config.get('unk_index', -1),
365
+ normalize=config.get('normalize_transcripts', False),
366
+ trim=config.get('trim_silence', False),
367
+ parser=config.get('parser', 'en'),
368
+ shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
369
+ shard_manifests=config.get('shard_manifests', False),
370
+ global_rank=global_rank,
371
+ world_size=world_size,
372
+ return_sample_id=config.get('return_sample_id', False),
373
+ )
374
+ else:
375
+ dataset = audio_to_text.TarredAudioToBPEDataset(
376
+ audio_tar_filepaths=tarred_audio_filepath,
377
+ manifest_filepath=manifest_filepath,
378
+ tokenizer=tokenizer,
379
+ sample_rate=config['sample_rate'],
380
+ int_values=config.get('int_values', False),
381
+ augmentor=augmentor,
382
+ shuffle_n=shuffle_n,
383
+ max_duration=config.get('max_duration', None),
384
+ min_duration=config.get('min_duration', None),
385
+ trim=config.get('trim_silence', False),
386
+ use_start_end_token=config.get('use_start_end_token', True),
387
+ shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
388
+ shard_manifests=config.get('shard_manifests', False),
389
+ global_rank=global_rank,
390
+ world_size=world_size,
391
+ return_sample_id=config.get('return_sample_id', False),
392
+ )
393
+ if bucketing_weights:
394
+ [datasets.append(dataset) for _ in range(bucketing_weights[dataset_idx])]
395
+ else:
396
+ datasets.append(dataset)
397
+
398
+ return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
399
+
400
+
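Editor's note: for reference, a hypothetical config for the tarred path above; the keys mirror the ones read by `get_tarred_dataset`, and all paths and values are placeholders.

# Editor's sketch, not from the commit.
from nemo.collections.asr.data.audio_to_text_dataset import get_tarred_dataset

tarred_config = {
    "manifest_filepath": "tarred/tarred_audio_manifest.json",   # placeholder path
    "tarred_audio_filepaths": "tarred/audio_{0..63}.tar",       # placeholder shard pattern
    "sample_rate": 16000,
    "batch_size": 32,
    "tarred_shard_strategy": "scatter",
}
dataset = get_tarred_dataset(
    config=tarred_config, shuffle_n=128, global_rank=0, world_size=1, tokenizer=None,
)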
401
+ def get_code_switched_dataset(
402
+ config: dict,
403
+ shuffle_n: int,
404
+ global_rank: int,
405
+ world_size: int,
406
+ tokenizer: Optional['TokenizerSpec'] = None,
407
+ augmentor: Optional['AudioAugmentor'] = None,
408
+ ) -> CodeSwitchedDataset:
409
+
410
+ if 'manifest_filepath' not in config:
411
+ raise ValueError("`manifest_filepath` must be provided in the dataset config if `is_code_switched=True`")
412
+ if 'code_switched' not in config:
413
+ raise ValueError("`code_switched` param group must be in the dataset config if `is_code_switched=True`")
414
+
415
+ manifest_filepaths = config['manifest_filepath']
416
+ tarred_audio_filepaths = config.get('tarred_audio_filepaths', None)
417
+
418
+ cs_config = OmegaConf.to_container(config['code_switched'])
419
+
420
+ # needed to support validation Datasets that arrive here as
421
+ # [[dataset1,dataset2]] otherwise ModelPT would interfere
422
+ if len(manifest_filepaths) == 1 and not isinstance(manifest_filepaths[0], str):
423
+ manifest_filepaths = config['manifest_filepath'][0]
424
+ if tarred_audio_filepaths is None:
425
+ tarred_audio_filepaths = [None] * len(manifest_filepaths)
426
+
427
+ if len(manifest_filepaths) != len(tarred_audio_filepaths):
428
+ raise ValueError(
429
+ f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of items."
430
+ )
431
+
432
+ datasets = []
433
+ for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
434
+ zip(tarred_audio_filepaths, manifest_filepaths)
435
+ ):
436
+ conf = copy.deepcopy(config)
437
+ conf['manifest_filepath'] = manifest_filepath
438
+ with open_dict(conf):
439
+ conf['tarred_audio_filepaths'] = tarred_audio_filepath
440
+ if tarred_audio_filepath is None or len(tarred_audio_filepath) == 0:
441
+ if tokenizer is None:
442
+ dataset = get_char_dataset(config=conf, augmentor=None)
443
+ else:
444
+ dataset = get_bpe_dataset(config=conf, tokenizer=tokenizer, augmentor=None)
445
+ else:
446
+ dataset = get_tarred_dataset(
447
+ config=conf,
448
+ tokenizer=tokenizer,
449
+ shuffle_n=shuffle_n,
450
+ global_rank=global_rank,
451
+ world_size=world_size,
452
+ augmentor=None,
453
+ )
454
+ datasets.append(dataset)
455
+
456
+ config = OmegaConf.to_container(config)
457
+
458
+ dataset = CodeSwitchedDataset(
459
+ datasets,
460
+ shuffle=cs_config.get('shuffle', True),
461
+ min_duration=cs_config.get('min_duration', 4),
462
+ max_duration=cs_config.get('max_duration', 20),
463
+ min_monolingual=cs_config.get('min_monolingual', 0.3),
464
+ lang_probs=cs_config.get('probs', None),
465
+ db_norm=cs_config.get('db_norm', -25.0),
466
+ pause_start=cs_config.get('pause_start', 0),
467
+ pause_join=cs_config.get('pause_join', 0),
468
+ pause_end=cs_config.get('pause_end', 0),
469
+ sampling_scales=cs_config.get('sampling_scales', None),
470
+ seed=cs_config.get('seed', None),
471
+ global_rank=global_rank,
472
+ world_size=world_size,
473
+ pure_random=cs_config.get('pure_random', False),
474
+ force_monochannel=cs_config.get('force_monochannel', True),
475
+ infinity_mode=cs_config.get('infinity_mode', False),
476
+ sample_rate=config['sample_rate'],
477
+ augmentor=augmentor,
478
+ )
479
+
480
+ return dataset
481
+
482
+
483
+ def get_dali_char_dataset(
484
+ config: dict,
485
+ shuffle: bool,
486
+ device_id: int,
487
+ global_rank: int,
488
+ world_size: int,
489
+ preprocessor_cfg: Optional[DictConfig] = None,
490
+ ) -> audio_to_text_dali.AudioToCharDALIDataset:
491
+ """
492
+ Instantiates a Character Encoding based AudioToCharDALIDataset.
493
+
494
+ Args:
495
+ config: Config of the AudioToCharDALIDataset.
496
+ shuffle: Bool flag whether to shuffle the dataset.
497
+ device_id: Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
498
+ global_rank: Global rank of this device.
499
+ world_size: Global world size in the training method.
500
+ augmentor: Optional AudioAugmentor object for augmentations on audio data.
501
+ preprocessor_cfg: Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
502
+
503
+ Returns:
504
+ An instance of AudioToCharDALIDataset.
505
+ """
506
+ device = 'gpu' if torch.cuda.is_available() else 'cpu'
507
+ dataset = audio_to_text_dali.AudioToCharDALIDataset(
508
+ manifest_filepath=config['manifest_filepath'],
509
+ device=device,
510
+ batch_size=config['batch_size'],
511
+ labels=config['labels'],
512
+ sample_rate=config['sample_rate'],
513
+ audio_tar_filepaths=config.get('tarred_audio_filepaths', None),
514
+ audio_tar_index_filepaths=config.get('tarred_audio_index_filepaths', None),
515
+ max_duration=config.get('max_duration', None),
516
+ min_duration=config.get('min_duration', None),
517
+ blank_index=config.get('blank_index', -1),
518
+ unk_index=config.get('unk_index', -1),
519
+ normalize=config.get('normalize_transcripts', False),
520
+ trim=config.get('trim_silence', False),
521
+ parser=config.get('parser', 'en'),
522
+ shuffle=shuffle,
523
+ shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
524
+ device_id=device_id,
525
+ global_rank=global_rank,
526
+ world_size=world_size,
527
+ preprocessor_cfg=preprocessor_cfg,
528
+ return_sample_id=config.get('return_sample_id', False),
529
+ )
530
+ return dataset
531
+
532
+
533
+ def get_dali_bpe_dataset(
534
+ config: dict,
535
+ tokenizer,
536
+ shuffle: bool,
537
+ device_id: int,
538
+ global_rank: int,
539
+ world_size: int,
540
+ preprocessor_cfg: Optional[DictConfig] = None,
541
+ ) -> audio_to_text_dali.AudioToBPEDALIDataset:
542
+ """
543
+ Instantiates a Subword Encoding based AudioToBPEDALIDataset.
544
+
545
+ Args:
546
+ config: Config of the AudioToBPEDALIDataset.
547
+ tokenizer: An implementation of NeMo TokenizerSpec.
548
+ shuffle: Bool flag whether to shuffle the dataset.
549
+ device_id: Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0.
550
+ global_rank: Global rank of this device.
551
+ world_size: Global world size in the training method.
552
+ preprocessor_cfg: Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
553
+
554
+ Returns:
555
+ An instance of AudioToBPEDALIDataset.
556
+ """
557
+ device = 'gpu' if torch.cuda.is_available() else 'cpu'
558
+ dataset = audio_to_text_dali.AudioToBPEDALIDataset(
559
+ manifest_filepath=config['manifest_filepath'],
560
+ tokenizer=tokenizer,
561
+ device=device,
562
+ batch_size=config['batch_size'],
563
+ sample_rate=config['sample_rate'],
564
+ audio_tar_filepaths=config.get('tarred_audio_filepaths', None),
565
+ audio_tar_index_filepaths=config.get('tarred_audio_index_filepaths', None),
566
+ max_duration=config.get('max_duration', None),
567
+ min_duration=config.get('min_duration', None),
568
+ trim=config.get('trim_silence', False),
569
+ use_start_end_token=config.get('use_start_end_token', True),
570
+ shuffle=shuffle,
571
+ shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
572
+ device_id=device_id,
573
+ global_rank=global_rank,
574
+ world_size=world_size,
575
+ preprocessor_cfg=preprocessor_cfg,
576
+ return_sample_id=config.get('return_sample_id', False),
577
+ )
578
+ return dataset
579
+
580
+
581
+ def get_audio_to_text_char_dataset_from_config(
582
+ config, local_rank: int, global_rank: int, world_size: int, preprocessor_cfg: Optional[DictConfig] = None
583
+ ):
584
+ """
585
+ Construct Audio-To-Text Char dataset from a config.
586
+ Args:
587
+ config: dataset config
588
+ local_rank: model local rank
589
+ global_rank: model global rank
590
+ world_size: world size
591
+ preprocessor_cfg: preprocessor config, for DALI dataset
592
+
593
+ Returns:
594
+ constructed dataset or None if dataset config is invalid or nothing to load
595
+ """
596
+ if 'augmentor' in config:
597
+ augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size)
598
+ else:
599
+ augmentor = None
600
+
601
+ is_concat = config.get('is_concat', False)
602
+ if is_concat:
603
+ if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
604
+ logging.warning(
605
+ f"Concat dataset requires `concat_sampling_technique` but it was not provided. Config: {config}"
606
+ )
607
+ return None
608
+ if config['concat_sampling_technique'] == 'random':
609
+ if not 'concat_sampling_probabilities' in config:
610
+ logging.warning(f"Concat dataset requires `concat_sampling_probabilities` list. Config: {config}")
611
+ return None
612
+ else:
613
+ if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
614
+ logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
615
+ return None
616
+
617
+ shuffle = config['shuffle']
618
+ device = 'gpu' if torch.cuda.is_available() else 'cpu'
619
+ if config.get('use_dali', False):
620
+ device_id = local_rank if device == 'gpu' else None
621
+ dataset = get_dali_char_dataset(
622
+ config=config,
623
+ shuffle=shuffle,
624
+ device_id=device_id,
625
+ global_rank=global_rank,
626
+ world_size=world_size,
627
+ preprocessor_cfg=preprocessor_cfg,
628
+ )
629
+ return dataset
630
+
631
+ # Instantiate a code-switched dataset if config is present
632
+ if config.get('is_code_switched', False):
633
+ if 'manifest_filepath' in config and config['manifest_filepath'] is None:
634
+ logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
635
+ return None
636
+ if not ('code_switched' in config and config['code_switched'] is not None):
637
+ logging.warning(
638
+ f"Code switched dataset requires `*_ds.code_switched.*` dict but it was not provided. Config: {config}"
639
+ )
640
+ return None
641
+ if (
642
+ ('probs' in config['code_switched'])
643
+ and (config['code_switched']['probs'] is not None)
644
+ and (not isclose(sum(config['code_switched']['probs']), 1, abs_tol=1e-6))
645
+ ):
646
+ logging.warning(f"`.code_switched.probs` need to sum to 1. Config: {config['code_switched']}")
647
+ return None
648
+
649
+ shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
650
+ dataset = get_code_switched_dataset(
651
+ config=config,
652
+ shuffle_n=shuffle_n,
653
+ global_rank=global_rank,
654
+ world_size=world_size,
655
+ tokenizer=None,
656
+ augmentor=augmentor,
657
+ )
658
+ # Instantiate tarred dataset loader or normal dataset loader
659
+ elif config.get('is_tarred', False):
660
+ if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
661
+ 'manifest_filepath' in config and config['manifest_filepath'] is None
662
+ ):
663
+ logging.warning(
664
+ "Could not load dataset as `manifest_filepath` was None or "
665
+ f"`tarred_audio_filepaths` is None. Provided config : {config}"
666
+ )
667
+ return None
668
+
669
+ shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
670
+ if is_concat:
671
+ dataset = get_concat_tarred_dataset(
672
+ config=config,
673
+ shuffle_n=shuffle_n,
674
+ global_rank=global_rank,
675
+ world_size=world_size,
676
+ augmentor=augmentor,
677
+ )
678
+ else:
679
+ dataset = get_tarred_dataset(
680
+ config=config,
681
+ shuffle_n=shuffle_n,
682
+ global_rank=global_rank,
683
+ world_size=world_size,
684
+ augmentor=augmentor,
685
+ )
686
+ else:
687
+ if 'manifest_filepath' in config and config['manifest_filepath'] is None:
688
+ logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
689
+ return None
690
+ if is_concat:
691
+ dataset = get_concat_char_dataset(
692
+ config=config, global_rank=global_rank, world_size=world_size, augmentor=augmentor
693
+ )
694
+ else:
695
+ dataset = get_char_dataset(config=config, augmentor=augmentor)
696
+ return dataset
697
+
698
+
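Editor's note: a minimal, hypothetical config for the plain (non-tarred, non-DALI) character path above; only keys read by this factory are shown, and the values are placeholders.

# Editor's sketch, not from the commit.
from nemo.collections.asr.data.audio_to_text_dataset import get_audio_to_text_char_dataset_from_config

char_config = {
    "manifest_filepath": "train_manifest.json",  # placeholder manifest
    "sample_rate": 16000,
    "labels": [" ", "a", "b", "c"],              # truncated label set
    "batch_size": 32,
    "shuffle": True,
}
ds = get_audio_to_text_char_dataset_from_config(
    config=char_config, local_rank=0, global_rank=0, world_size=1,
)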
699
+ def get_audio_to_text_bpe_dataset_from_config(
700
+ config,
701
+ local_rank: int,
702
+ global_rank: int,
703
+ world_size: int,
704
+ tokenizer,
705
+ preprocessor_cfg: Optional[DictConfig] = None,
706
+ ):
707
+ """
708
+ Construct Audio-To-Text BPE dataset from a config.
709
+ Args:
710
+ config: BPE dataset config
711
+ local_rank: model local rank
712
+ global_rank: model global rank
713
+ world_size: world size
714
+ tokenizer: BPE tokenizer
715
+ preprocessor_cfg: preprocessor config, for DALI BPE dataset
716
+
717
+ Returns:
718
+ constructed dataset or None if dataset config is invalid or nothing to load
719
+ """
720
+ if 'augmentor' in config:
721
+ augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size)
722
+ else:
723
+ augmentor = None
724
+
725
+ is_concat = config.get('is_concat', False)
726
+ if is_concat:
727
+ if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
728
+ logging.warning(
729
+ f"Concat dataset requires `concat_sampling_technique` but it was not provided. Config: {config}"
730
+ )
731
+ return None
732
+
733
+ if config['concat_sampling_technique'] == 'random':
734
+ if not 'concat_sampling_probabilities' in config:
735
+ logging.warning(f"Concat dataset requires `concat_sampling_probabilities` list. Config: {config}")
736
+ return None
737
+ else:
738
+ if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
739
+ logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
740
+ return None
741
+
742
+ shuffle = config['shuffle']
743
+ device = 'gpu' if torch.cuda.is_available() else 'cpu'
744
+ if config.get('use_dali', False):
745
+ device_id = local_rank if device == 'gpu' else None
746
+ dataset = get_dali_bpe_dataset(
747
+ config=config,
748
+ tokenizer=tokenizer,
749
+ shuffle=shuffle,
750
+ device_id=device_id,
751
+ global_rank=global_rank,
752
+ world_size=world_size,
753
+ preprocessor_cfg=preprocessor_cfg,
754
+ )
755
+ return dataset
756
+
757
+ # Instantiate a code-switched dataset if config is present
758
+ if config.get('is_code_switched', False):
759
+ if 'manifest_filepath' in config and config['manifest_filepath'] is None:
760
+ logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
761
+ return None
762
+ if not ('code_switched' in config and config['code_switched'] is not None):
763
+ logging.warning(
764
+ f"Code switched dataset requires `*_ds.code_switched.*` dict but it was not provided. Config: {config}"
765
+ )
766
+ return None
767
+ if (
768
+ ('probs' in config['code_switched'])
769
+ and (config['code_switched']['probs'] is not None)
770
+ and (not isclose(sum(config['code_switched']['probs']), 1, abs_tol=1e-6))
771
+ ):
772
+ logging.warning(f"`.code_switched.probs` need to sum to 1. Config: {config['code_switched']}")
773
+ return None
774
+
775
+ shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
776
+ dataset = get_code_switched_dataset(
777
+ config=config,
778
+ shuffle_n=shuffle_n,
779
+ global_rank=global_rank,
780
+ world_size=world_size,
781
+ tokenizer=tokenizer,
782
+ augmentor=augmentor,
783
+ )
784
+ # Instantiate tarred dataset loader or normal dataset loader
785
+ elif config.get('is_tarred', False):
786
+ if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
787
+ 'manifest_filepath' in config and config['manifest_filepath'] is None
788
+ ):
789
+ logging.warning(
790
+ "Could not load dataset as `manifest_filepath` was None or "
791
+ f"`tarred_audio_filepaths` is None. Provided config : {config}"
792
+ )
793
+ return None
794
+
795
+ shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
796
+ if is_concat:
797
+ dataset = get_concat_tarred_dataset(
798
+ config=config,
799
+ tokenizer=tokenizer,
800
+ shuffle_n=shuffle_n,
801
+ global_rank=global_rank,
802
+ world_size=world_size,
803
+ augmentor=augmentor,
804
+ )
805
+ else:
806
+ dataset = get_tarred_dataset(
807
+ config=config,
808
+ tokenizer=tokenizer,
809
+ shuffle_n=shuffle_n,
810
+ global_rank=global_rank,
811
+ world_size=world_size,
812
+ augmentor=augmentor,
813
+ )
814
+ else:
815
+ if 'manifest_filepath' in config and config['manifest_filepath'] is None:
816
+ logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
817
+ return None
818
+ if is_concat:
819
+ dataset = get_concat_bpe_dataset(
820
+ config=config,
821
+ global_rank=global_rank,
822
+ world_size=world_size,
823
+ tokenizer=tokenizer,
824
+ augmentor=augmentor,
825
+ )
826
+ else:
827
+ dataset = get_bpe_dataset(config=config, tokenizer=tokenizer, augmentor=augmentor)
828
+ return dataset
829
+
830
+
831
+ class ASRPredictionWriter(BasePredictionWriter):
832
+ def __init__(self, dataset, output_file: str):
833
+ super().__init__(write_interval="batch")
834
+ self.outf = open(output_file, 'w', encoding='utf-8')
835
+ self.dataset = dataset
836
+ self.samples_num = 0
837
+
838
+ def write_on_batch_end(
839
+ self,
840
+ trainer,
841
+ pl_module: 'LightningModule',
842
+ prediction: Any,
843
+ batch_indices: List[int],
844
+ batch: Any,
845
+ batch_idx: int,
846
+ dataloader_idx: int,
847
+ ):
848
+ for sample_id, transcribed_text in prediction:
849
+ item = {}
850
+ sample = self.dataset.get_manifest_sample(sample_id)
851
+ item["audio_filepath"] = sample.audio_file
852
+ item["offset"] = sample.offset
853
+ item["duration"] = sample.duration
854
+ item["text"] = sample.text_raw
855
+ item["pred_text"] = transcribed_text
856
+ self.outf.write(json.dumps(item) + "\n")
857
+ self.samples_num += 1
858
+ return
859
+
860
+ def close_output_file(self):
861
+ self.outf.close()
862
+ return self.samples_num
863
+
864
+
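A minimal usage sketch for ASRPredictionWriter, assuming a PyTorch Lightning Trainer, a dataset exposing get_manifest_sample() (such as the datasets built above), and a model whose predict_step yields (sample_id, transcribed_text) pairs; pred_dataset, pred_dataloader, and asr_model are placeholders.

# Hedged sketch: ASRPredictionWriter is a Lightning callback, so it plugs into trainer.predict().
import pytorch_lightning as pl

writer = ASRPredictionWriter(dataset=pred_dataset, output_file="predictions.jsonl")
trainer = pl.Trainer(devices=1, accelerator="auto", callbacks=[writer])
trainer.predict(asr_model, dataloaders=pred_dataloader, return_predictions=False)
num_samples = writer.close_output_file()   # closes the file and returns how many lines were written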
865
+ def convert_to_config_list(initial_list):
866
+ if type(initial_list) is str:
867
+ initial_list = initial_list.split(",")
868
+ if initial_list is None or initial_list == []:
869
+ raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.")
870
+ if not isinstance(initial_list, ListConfig):
871
+ initial_list = ListConfig([initial_list])
872
+
873
+ for list_idx, list_val in enumerate(initial_list):
874
+ if type(list_val) != type(initial_list[0]):
875
+ raise ValueError(
876
+ "manifest_filepaths and tarred_audio_filepaths need to be a list of lists for bucketing or just a list of strings"
877
+ )
878
+ if type(initial_list[0]) is not ListConfig:
879
+ initial_list = ListConfig([initial_list])
880
+ return initial_list
881
+
882
+
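A quick illustration of convert_to_config_list (assuming omegaconf is available): a comma-separated string of manifests is normalized into a ListConfig containing a single "bucket", which is itself a list of manifest paths.

from omegaconf import ListConfig

cfg = convert_to_config_list("train_a.json,train_b.json")
assert isinstance(cfg, ListConfig)
assert list(cfg[0]) == ["train_a.json", "train_b.json"]   # one bucket holding both manifests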
883
+ def get_chain_dataset(datasets, ds_config, rank=0):
884
+ if len(datasets) > 1:
885
+ if ds_config.get('bucketing_batch_size', None) is not None:
886
+ bucketing_batch_sizes = calc_bucketing_batch_sizes(ds_config, len(datasets))
887
+ logging.info(
888
+ f"Batch bucketing is enabled for {len(datasets)} buckets with adaptive batch sizes of {bucketing_batch_sizes}!"
889
+ )
890
+ for idx, dataset in enumerate(datasets):
891
+ datasets[idx] = audio_to_text.BucketingDataset(
892
+ dataset=dataset, bucketing_batch_size=bucketing_batch_sizes[idx]
893
+ )
894
+ else:
895
+ logging.info(
896
+ f"Batch bucketing is enabled for {len(datasets)} buckets with fixed batch size of {ds_config['batch_size']}!"
897
+ )
898
+
899
+ if len(datasets) == 1:
900
+ return datasets[0]
901
+ bucketing_strategy = ds_config.get('bucketing_strategy', 'synced_randomized')
902
+ if bucketing_strategy == 'fixed_order':
903
+ return ChainDataset(datasets)
904
+ elif bucketing_strategy == 'synced_randomized':
905
+ return audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=0)
906
+ elif bucketing_strategy == 'fully_randomized':
907
+ return audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=random.randint(0, 30000) + rank)
908
+ else:
909
+ raise ValueError(
910
+ f'bucketing_strategy={bucketing_strategy} is not supported! Supported strategies are [fixed_order, fully_randomized, synced_randomized].'
911
+ )
912
+
913
+
914
+ def calc_bucketing_batch_sizes(ds_config, datasets_len):
915
+ bucketing_batch_size = ds_config['bucketing_batch_size']
916
+ bucketing_weights = ds_config.get('bucketing_weights', None) # To adjust for upsampled buckets
917
+
918
+ bucketing_batch_sizes = []
919
+
920
+ if ds_config['batch_size'] != 1:
921
+ raise ValueError(
922
+ f"batch_size should be set to one when bucketing_batch_size is set and adaptive bucketing is enabled (batch_size={ds_config['batch_size']})!"
923
+ )
924
+ if type(bucketing_batch_size) == int: # linear scaling
925
+ if bucketing_weights: # Want same batchsize for the same duplicated bucket
926
+ for idx, weight in enumerate(bucketing_weights):
927
+ scale_factor = datasets_len - idx
928
+ bucketing_batch_sizes.extend([scale_factor * bucketing_batch_size] * weight)
929
+ else:
930
+ for idx in range(datasets_len):
931
+ scale_factor = datasets_len - idx
932
+ bucketing_batch_sizes.append(scale_factor * bucketing_batch_size)
933
+ elif isinstance(bucketing_batch_size, ListConfig) or isinstance(
934
+ bucketing_batch_size, list
935
+ ): # assigned bucket sizes
936
+ if bucketing_weights: # Want same batchsize for same duplicated bucket
937
+ for idx, weight in enumerate(bucketing_weights):
938
+ bucketing_batch_sizes.extend([bucketing_batch_size[idx]] * weight)
939
+ else:
940
+ bucketing_batch_sizes = bucketing_batch_size
941
+ else:
942
+ raise ValueError(
943
+ f"bucketing_batch_size should be an integer or a list (bucketing_batch_size={bucketing_batch_size})!"
944
+ )
945
+
946
+ if len(bucketing_batch_sizes) != datasets_len:
947
+ raise ValueError(
948
+ f"bucketing_batch_size should have the same length as the number of buckets ({len(bucketing_batch_sizes)}!={datasets_len})"
949
+ )
950
+ return bucketing_batch_sizes
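To make the adaptive bucketing above concrete, here is a small sketch (the config keys mirror the ones read by calc_bucketing_batch_sizes; the numbers are arbitrary).

# Integer bucketing_batch_size: sizes scale linearly, (n_buckets - idx) * bucketing_batch_size.
ds_config = {'batch_size': 1, 'bucketing_batch_size': 8}
print(calc_bucketing_batch_sizes(ds_config, datasets_len=4))   # -> [32, 24, 16, 8]

# Explicit list: one batch size per bucket, used as given.
ds_config = {'batch_size': 1, 'bucketing_batch_size': [48, 32, 16, 8]}
print(calc_bucketing_batch_sizes(ds_config, datasets_len=4))   # -> [48, 32, 16, 8]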
SoundScribe/SpeakerID/nemo/collections/asr/data/data_simulation.py ADDED
The diff for this file is too large to render. See raw diff
 
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, List, Optional
15
+
16
+ import torch
17
+
18
+ from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader
19
+ from nemo.collections.common.parts.preprocessing import collections
20
+ from nemo.core.classes import Dataset
21
+ from nemo.core.neural_types import AcousticEncodedRepresentation, LabelsType, LengthsType, NeuralType
22
+ from nemo.utils import logging
23
+
24
+
25
+ def _feature_collate_fn(batch):
26
+ """Collate a batch of (features, feature lengths, labels, label lengths), assuming all features have the same shape.
27
+ Args:
28
+ batch (FloatTensor, LongTensor, LongTensor, LongTensor): A tuple of tuples of feature, feature lengths,
29
+ encoded labels, and encoded labels length.
30
+ """
31
+ packed_batch = list(zip(*batch))
32
+ if len(packed_batch) == 5:
33
+ _, feat_lengths, _, labels_lengths, sample_ids = packed_batch
34
+ elif len(packed_batch) == 4:
35
+ sample_ids = None
36
+ _, feat_lengths, _, labels_lengths = packed_batch
37
+ else:
38
+ raise ValueError("Expects 4 or 5 tensors in the batch!")
39
+
40
+ features, labels = [], []
41
+ for b in batch:
42
+ feat_i, labels_i = b[0], b[2]
43
+ features.append(feat_i)
44
+ labels.append(labels_i)
45
+
46
+ features = torch.stack(features)
47
+ feat_lengths = torch.stack(feat_lengths)
48
+
49
+ labels = torch.stack(labels)
50
+ labels_lengths = torch.stack(labels_lengths)
51
+
52
+ if sample_ids is None:
53
+ return features, feat_lengths, labels, labels_lengths
54
+ else:
55
+ sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
56
+ return features, feat_lengths, labels, labels_lengths, sample_ids
57
+
58
+
59
+ def _audio_feature_collate_fn(batch, feat_pad_val, label_pad_id):
60
+ """Collate a batch of (audio features, feature lengths, labels, label lengths).
61
+ Args:
62
+ batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
63
+ LongTensor): A tuple of tuples of feature, feature lengths,
64
+ labels, and label lengths. This collate func assumes the
65
+ features are torch tensors of Log-Melspectrogram (i.e. [N_MEL, T]).
66
+ """
67
+ packed_batch = list(zip(*batch))
68
+ if len(packed_batch) == 5:
69
+ _, feat_lengths, _, labels_lengths, sample_ids = packed_batch
70
+ elif len(packed_batch) == 4:
71
+ sample_ids = None
72
+ _, feat_lengths, _, labels_lengths = packed_batch
73
+ else:
74
+ raise ValueError("Expects 4 or 5 tensors in the batch!")
75
+ max_feat_len = 0
76
+ has_feat = feat_lengths[0] is not None
77
+ if has_feat:
78
+ max_feat_len = max(feat_lengths).item()
79
+ max_labels_len = max(labels_lengths).item()
80
+
81
+ features, labels = [], []
82
+ for b in batch:
83
+ feat_i, feat_i_len, label_i, label_i_len = b[0], b[1], b[2], b[3]
84
+
85
+ if has_feat:
86
+ feat_i_len = feat_i_len.item()
87
+ if feat_i_len < max_feat_len:
88
+ pad = (0, max_feat_len - feat_i_len)
89
+ feat_i = torch.nn.functional.pad(feat_i, pad, value=feat_pad_val)
90
+ features.append(feat_i)
91
+
92
+ label_i_len = label_i_len.item()
93
+ if label_i_len < max_labels_len:
94
+ pad = (0, max_labels_len - label_i_len)
95
+ label_i = torch.nn.functional.pad(label_i, pad, value=label_pad_id)
96
+ labels.append(label_i)
97
+
98
+ if has_feat:
99
+ features = torch.stack(features)
100
+ feature_lengths = torch.stack(feat_lengths)
101
+ else:
102
+ features, feat_lengths = None, None
103
+ labels = torch.stack(labels)
104
+ labels_lengths = torch.stack(labels_lengths)
105
+
106
+ if sample_ids is None:
107
+ return features, feature_lengths, labels, labels_lengths
108
+ else:
109
+ sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
110
+ return features, feature_lengths, labels, labels_lengths, sample_ids
111
+
112
+
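A toy example of the padding behaviour implemented above (shapes are arbitrary): features of shape [D, T] are right-padded along T to the longest item with feat_pad_val, and labels are right-padded to the longest label sequence with label_pad_id.

batch = [
    (torch.zeros(64, 3), torch.tensor(3), torch.tensor([1, 2]), torch.tensor(2)),
    (torch.zeros(64, 5), torch.tensor(5), torch.tensor([3]), torch.tensor(1)),
]
feats, feat_lens, labels, label_lens = _audio_feature_collate_fn(
    batch, feat_pad_val=-16.635, label_pad_id=0
)
# feats.shape == (2, 64, 5), labels.shape == (2, 2), feat_lens == tensor([3, 5])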
113
+ def _vad_feature_segment_collate_fn(batch, window_length_in_sec, shift_length_in_sec, frame_unit_in_sec):
114
+ """collate batch of audio features, features len, tokens, tokens len
115
+ Args:
116
+ batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
117
+ LongTensor): A tuple of tuples of feature, feature lengths,
118
+ encoded tokens, and encoded tokens length. This collate func
119
+ assumes the features are 2d torch tensors (i.e. [D, T] log-mel features)
120
+ and that the batch size equals 1.
121
+ """
122
+ slice_length = int(window_length_in_sec / frame_unit_in_sec)
123
+ audio_features, feat_lengths, _, tokens_lengths = zip(*batch)
124
+
125
+ slice_length = int(min(slice_length, max(feat_lengths)))
126
+ shift = int(shift_length_in_sec / frame_unit_in_sec)
127
+ has_audio = feat_lengths[0] is not None
128
+
129
+ f_dim = audio_features[0].shape[0]
130
+ audio_features, num_slices, tokens, feat_lengths = [], [], [], []
131
+ append_len_start = torch.div(slice_length, 2, rounding_mode='trunc')
132
+ append_len_end = slice_length - torch.div(slice_length, 2, rounding_mode='trunc')
133
+ for feat_i, feat_i_len, tokens_i, _ in batch:
134
+ start = torch.zeros(f_dim, append_len_start)
135
+ end = torch.zeros(f_dim, append_len_end)
136
+ feat_i = torch.cat((start, feat_i, end), dim=1)
137
+ feat_i_len += slice_length
138
+
139
+ if has_audio:
140
+ slices = max(1, torch.div(feat_i_len - slice_length, shift, rounding_mode='trunc'))
141
+
142
+ for slice_id in range(slices):
143
+ start_idx = slice_id * shift
144
+ end_idx = start_idx + slice_length
145
+ feat_slice = feat_i[:, start_idx:end_idx]
146
+ audio_features.append(feat_slice)
147
+
148
+ num_slices.append(slices)
149
+ tokens.extend([tokens_i] * slices)
150
+ feat_lengths.extend([slice_length] * slices)
151
+
152
+ if has_audio:
153
+ audio_features = torch.stack(audio_features)
154
+ feat_lengths = torch.tensor(feat_lengths)
155
+ else:
156
+ audio_features, feat_lengths = None, None
157
+
158
+ tokens = torch.stack(tokens)
159
+ tokens_lengths = torch.tensor(num_slices)
160
+ return audio_features, feat_lengths, tokens, tokens_lengths
161
+
162
+
163
+ class _FeatureSeqSpeakerLabelDataset(Dataset):
164
+ """
165
+ Dataset that loads tensors via a json file containing paths to feature files and their target
166
+ sequences of labels. Each new line is a different sample. JSON files should be of the
167
+ following format:
168
+ {"feature_filepath": "/path/to/feature_0.p", "seq_label": "speakerA speakerB speakerA ..."}
169
+ ...
170
+ {"feature_filepath": "/path/to/feature_n.p", "seq_label": target_seq_label_n}
171
+ target_seq_label_n is the string of sequence of speaker label, separated by space.
172
+
173
+ Args:
174
+ manifest_filepath (str): Dataset parameter. Path to JSON containing data.
175
+ labels (Optional[list]): Dataset parameter. List of unique labels collected from all samples.
176
+ feature_loader: Dataset parameter. Feature loader used to load the (external) features.
177
+ """
178
+
179
+ @property
180
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
181
+ """Returns definitions of module output ports.
182
+ """
183
+ # TODO output type for external features
184
+ output_types = {
185
+ 'external_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
186
+ 'feat_length': NeuralType(tuple('B'), LengthsType()),
187
+ }
188
+
189
+ if self.is_speaker_emb:
190
+ output_types.update(
191
+ {
192
+ 'embs': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
193
+ 'embs_length': NeuralType(tuple('B'), LengthsType()),
194
+ 'label': NeuralType(('B', 'T'), LabelsType()),
195
+ 'label_length': NeuralType(tuple('B'), LengthsType()),
196
+ }
197
+ )
198
+ else:
199
+ output_types.update(
200
+ {'label': NeuralType(('B', 'T'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),}
201
+ )
202
+
203
+ return output_types
204
+
205
+ def __init__(
206
+ self, *, manifest_filepath: str, labels: List[str], feature_loader, is_speaker_emb: bool = False,
207
+ ):
208
+ super().__init__()
209
+ self.collection = collections.ASRFeatureSequenceLabel(manifests_files=manifest_filepath.split(','),)
210
+
211
+ self.feature_loader = feature_loader
212
+ self.labels = labels if labels else self.collection.uniq_labels
213
+ self.is_speaker_emb = is_speaker_emb
214
+
215
+ self.label2id, self.id2label = {}, {}
216
+ for label_id, label in enumerate(self.labels):
217
+ self.label2id[label] = label_id
218
+ self.id2label[label_id] = label
219
+
220
+ for idx in range(len(self.labels[:5])):
221
+ logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
222
+
223
+ def __len__(self):
224
+ return len(self.collection)
225
+
226
+ def __getitem__(self, index):
227
+ sample = self.collection[index]
228
+
229
+ features = self.feature_loader.process(sample.feature_file)
230
+ f, fl = features, torch.tensor(features.shape[0]).long()
231
+
232
+ t = torch.tensor(sample.seq_label).float()
233
+ tl = torch.tensor(len(sample.seq_label)).long()
234
+
235
+ return f, fl, t, tl
236
+
237
+
238
+ class FeatureToSeqSpeakerLabelDataset(_FeatureSeqSpeakerLabelDataset):
239
+ """
240
+ Dataset that loads tensors via a json file containing paths to feature
241
+ files and sequence of speakers. Each new line is a
242
+ different sample. Example below:
243
+ {"feature_filepath": "/path/to/feature_0.p", "seq_label": "speakerA speakerB speakerA ..."}
244
+ ...
245
+ {"feature_filepath": "/path/to/feature_n.p", "seq_label": target_seq_label_n}
246
+ target_seq_label_n is the string of sequence of speaker label, separated by space.
247
+
248
+ Args:
249
+ manifest_filepath (str): Path to manifest json as described above. Can be comma-separated paths.
250
+ labels (Optional[list]): String containing all the possible labels to map to
251
+ if None then automatically picks from ASRFeatureSequenceLabel collection.
252
+ feature_loader: Feature loader used to load the (external) features.
253
+
254
+ """
255
+
256
+ def _collate_fn(self, batch):
257
+ return _feature_collate_fn(batch)
258
+
259
+
260
+ class FeatureToLabelDataset(Dataset):
261
+ """
262
+ Dataset that loads tensors via a json file containing paths to feature files and their labels.
263
+ Each new line is a different sample. JSON files should be of the
264
+ following format:
265
+ {"feature_filepath": "/path/to/audio_feature.pt", "label": "1"}
266
+ ...
267
+ {"feature_filepath": "/path/to/audio_feature.pt", "label": "0"}
268
+ Args:
269
+ manifest_filepath (str): Path to JSON containing data.
270
+ labels (Optional[list]): List of unique labels collected from all samples.
271
+ augmentor (Optional): feature augmentation
272
+ window_length_in_sec (float): Window length in seconds.
273
+ shift_length_in_sec (float): Shift length in seconds.
274
+ is_regression_task (bool): if True, the labels are treated as for a regression task.
275
+ cal_labels_occurrence (bool): if True, the labels occurrence will be calculated.
276
+ zero_spec_db_val (float): Value to replace non-speech signals in log-melspectrogram.
277
+ min_duration (float): Minimum duration of the audio file in seconds.
278
+ max_duration (float): Maximum duration of the audio file in seconds.
279
+ """
280
+
281
+ ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
282
+ FRAME_UNIT_TIME_SECS = 0.01
283
+
284
+ @property
285
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
286
+ """Returns definitions of module output ports.
287
+ """
288
+ output_types = {
289
+ 'audio_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
290
+ 'feat_length': NeuralType(tuple('B'), LengthsType()),
291
+ 'labels': NeuralType(('B'), LabelsType()),
292
+ 'labels_length': NeuralType(tuple('B'), LengthsType()),
293
+ }
294
+
295
+ return output_types
296
+
297
+ def __init__(
298
+ self,
299
+ *,
300
+ manifest_filepath: str,
301
+ labels: List[str] = None,
302
+ augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
303
+ window_length_in_sec: float = 0.63,
304
+ shift_length_in_sec: float = 0.01,
305
+ is_regression_task: bool = False,
306
+ cal_labels_occurrence: Optional[bool] = False,
307
+ zero_spec_db_val: float = -16.635,
308
+ min_duration: Optional[float] = None,
309
+ max_duration: Optional[float] = None,
310
+ ):
311
+ super().__init__()
312
+ self.window_length_in_sec = window_length_in_sec
313
+ self.shift_length_in_sec = shift_length_in_sec
314
+ self.zero_spec_db_val = zero_spec_db_val
315
+
316
+ if isinstance(manifest_filepath, str):
317
+ manifest_filepath = manifest_filepath.split(',')
318
+
319
+ self.collection = collections.ASRFeatureLabel(
320
+ manifests_files=manifest_filepath,
321
+ is_regression_task=is_regression_task,
322
+ cal_labels_occurrence=cal_labels_occurrence,
323
+ min_duration=min_duration,
324
+ max_duration=max_duration,
325
+ )
326
+
327
+ self.feature_loader = ExternalFeatureLoader(augmentor=augmentor)
328
+ self.labels = labels if labels else self.collection.uniq_labels
329
+
330
+ self.is_regression_task = is_regression_task
331
+
332
+ if not is_regression_task:
333
+ self.labels = labels if labels else self.collection.uniq_labels
334
+ self.num_classes = len(self.labels) if self.labels is not None else 1
335
+ self.label2id, self.id2label = {}, {}
336
+ self.id2occurrence, self.labels_occurrence = {}, []
337
+
338
+ for label_id, label in enumerate(self.labels):
339
+ self.label2id[label] = label_id
340
+ self.id2label[label_id] = label
341
+ if cal_labels_occurrence:
342
+ self.id2occurrence[label_id] = self.collection.labels_occurrence[label]
343
+
344
+ if cal_labels_occurrence:
345
+ self.labels_occurrence = [self.id2occurrence[k] for k in sorted(self.id2occurrence)]
346
+
347
+ for idx in range(len(self.labels[:5])):
348
+ logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
349
+ else:
350
+ self.labels = []
351
+ self.num_classes = 1
352
+
353
+ def __len__(self):
354
+ return len(self.collection)
355
+
356
+ def __getitem__(self, index):
357
+ sample = self.collection[index]
358
+
359
+ features = self.feature_loader.process(sample.feature_file)
360
+ f, fl = features, torch.tensor(features.shape[1]).long()
361
+
362
+ t = torch.tensor(self.label2id[sample.label])
363
+ tl = torch.tensor(1).long()
364
+
365
+ return f, fl, t, tl
366
+
367
+ def _collate_fn(self, batch):
368
+ return _audio_feature_collate_fn(batch, self.zero_spec_db_val, 0)
369
+
370
+ def _vad_segment_collate_fn(self, batch):
371
+ return _vad_feature_segment_collate_fn(
372
+ batch, self.window_length_in_sec, self.shift_length_in_sec, self.FRAME_UNIT_TIME_SECS
373
+ )
374
+
375
+
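A hedged construction sketch for FeatureToLabelDataset; the manifest path and label names are placeholders, and each manifest line follows the format shown in the docstring above.

# Example manifest line: {"feature_filepath": "/feats/utt0.pt", "label": "speech"}
from torch.utils.data import DataLoader

label_ds = FeatureToLabelDataset(
    manifest_filepath="/data/feat_label_manifest.json",
    labels=["noise", "speech"],
)
f, fl, t, tl = label_ds[0]     # f: [D, T] features, t: scalar class id
loader = DataLoader(label_ds, batch_size=8, collate_fn=label_ds._collate_fn)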
376
+ class FeatureToMultiLabelDataset(Dataset):
377
+ """
378
+ Dataset that loads tensors via a json file containing paths to feature files and their labels.
379
+ Each new line is a different sample. JSON files should be of the
380
+ following format:
381
+ {"feature_filepath": "/path/to/audio_feature.pt", "label": "1 1 0 0 1"}
382
+ ...
383
+ {"feature_filepath": "/path/to/audio_feature.pt", "label": "0 1 0 0"}
384
+ Args:
385
+ manifest_filepath (str): Path to JSON containing data.
386
+ labels (Optional[list]): List of unique labels collected from all samples.
387
+ augmentor (Optional): feature augmentation
388
+ delimiter (str): delimiter to split the labels.
389
+ is_regression_task (bool): if True, the labels are treated as for a regression task.
390
+ cal_labels_occurrence (bool): if True, the labels occurrence will be calculated.
391
+ zero_spec_db_val (float): Value to replace non-speech signals in log-melspectrogram.
392
+ min_duration (float): Minimum duration of the audio file in seconds.
393
+ max_duration (float): Maximum duration of the audio file in seconds.
394
+ """
395
+
396
+ ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
397
+
398
+ @property
399
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
400
+ """Returns definitions of module output ports.
401
+ """
402
+ output_types = {
403
+ 'audio_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
404
+ 'feat_length': NeuralType(tuple('B'), LengthsType()),
405
+ 'labels': NeuralType(('B', 'T'), LabelsType()),
406
+ 'labels_length': NeuralType(tuple('B'), LengthsType()),
407
+ }
408
+
409
+ return output_types
410
+
411
+ def __init__(
412
+ self,
413
+ *,
414
+ manifest_filepath: str,
415
+ labels: List[str] = None,
416
+ augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
417
+ delimiter: Optional[str] = None,
418
+ is_regression_task: bool = False,
419
+ cal_labels_occurrence: Optional[bool] = False,
420
+ zero_spec_db_val: float = -16.635,
421
+ min_duration: Optional[float] = None,
422
+ max_duration: Optional[float] = None,
423
+ ):
424
+ super().__init__()
425
+ self.delimiter = delimiter
426
+ self.zero_spec_db_val = zero_spec_db_val
427
+
428
+ if isinstance(manifest_filepath, str):
429
+ manifest_filepath = manifest_filepath.split(',')
430
+
431
+ self.collection = collections.ASRFeatureLabel(
432
+ manifests_files=manifest_filepath,
433
+ is_regression_task=is_regression_task,
434
+ cal_labels_occurrence=cal_labels_occurrence,
435
+ delimiter=delimiter,
436
+ min_duration=min_duration,
437
+ max_duration=max_duration,
438
+ )
439
+
440
+ self.is_regression_task = is_regression_task
441
+ self.feature_loader = ExternalFeatureLoader(augmentor=augmentor)
442
+ self.labels = labels if labels else self.collection.uniq_labels
443
+
444
+ self.label2id, self.id2label = {}, {}
+ self.id2occurrence, self.labels_occurrence = {}, []
445
+ if not is_regression_task:
446
+ self.labels = labels if labels else self._get_label_set()
447
+ self.num_classes = len(self.labels) if self.labels is not None else 1
448
+ self.label2id, self.id2label = {}, {}
449
+ for label_id, label in enumerate(self.labels):
450
+ self.label2id[label] = label_id
451
+ self.id2label[label_id] = label
452
+ if cal_labels_occurrence:
453
+ self.id2occurrence[label_id] = self.collection.labels_occurrence[label]
454
+ self.labels_occurrence.append(self.id2occurrence[label_id])
455
+
456
+ for idx in range(len(self.labels[:5])):
457
+ logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
458
+ else:
459
+ self.labels = []
460
+ self.num_classes = 1
461
+
462
+ def _get_label_set(self):
463
+ labels = []
464
+ for sample in self.collection:
465
+ label_str = sample.label
466
+ if label_str:
467
+ label_str_list = label_str.split(self.delimiter) if self.delimiter else label_str.split()
468
+ labels.extend(label_str_list)
469
+ return sorted(set(labels))
470
+
471
+ def _label_str_to_tensor(self, label_str: str):
472
+ labels = label_str.split(self.delimiter) if self.delimiter else label_str.split()
473
+
474
+ if self.is_regression_task:
475
+ labels = [float(s) for s in labels]
476
+ labels = torch.tensor(labels).float()
477
+ else:
478
+ labels = [self.label2id[s] for s in labels]
479
+ labels = torch.tensor(labels).long()
480
+ return labels
481
+
482
+ def __len__(self):
483
+ return len(self.collection)
484
+
485
+ def __getitem__(self, index):
486
+ sample = self.collection[index]
487
+
488
+ features = self.feature_loader.process(sample.feature_file)
489
+ f, fl = features, torch.tensor(features.shape[1]).long()
490
+
491
+ t = self._label_str_to_tensor(sample.label)
492
+ tl = torch.tensor(t.size(0)).long()
493
+
494
+ return f, fl, t, tl
495
+
496
+ def _collate_fn(self, batch):
497
+ return _audio_feature_collate_fn(batch, self.zero_spec_db_val, 0)
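For the multi-label case just defined, the label string from the manifest is split (on the delimiter, or on whitespace when delimiter is None) and mapped to class ids; a short sketch with placeholder paths and a made-up label set:

# Example manifest line: {"feature_filepath": "/feats/utt0.pt", "label": "0 1 0 0 1"}
multi_ds = FeatureToMultiLabelDataset(
    manifest_filepath="/data/frame_label_manifest.json",
    labels=["0", "1"],
    delimiter=None,            # None -> split on whitespace
)
f, fl, t, tl = multi_ds[0]     # t: LongTensor of per-frame class ids, tl: its length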
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_label_dataset.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from nemo.collections.asr.data import feature_to_label
17
+
18
+
19
+ def get_feature_seq_speakerlabel_dataset(
20
+ feature_loader, config: dict
21
+ ) -> feature_to_label.FeatureToSeqSpeakerLabelDataset:
22
+ """
23
+ Instantiates a FeatureToSeqSpeakerLabelDataset.
24
+ Args:
25
+ config: Config of the FeatureToSeqSpeakerLabelDataset.
26
+
27
+ Returns:
28
+ An instance of FeatureToSeqSpeakerLabelDataset.
29
+ """
30
+ dataset = feature_to_label.FeatureToSeqSpeakerLabelDataset(
31
+ manifest_filepath=config['manifest_filepath'], labels=config['labels'], feature_loader=feature_loader,
32
+ )
33
+ return dataset
34
+
35
+
36
+ def get_feature_label_dataset(
37
+ config: dict, augmentor: Optional['FeatureAugmentor'] = None
38
+ ) -> feature_to_label.FeatureToLabelDataset:
39
+ dataset = feature_to_label.FeatureToLabelDataset(
40
+ manifest_filepath=config['manifest_filepath'],
41
+ labels=config['labels'],
42
+ augmentor=augmentor,
43
+ window_length_in_sec=config.get("window_length_in_sec", 0.63),
44
+ shift_length_in_sec=config.get("shift_length_in_sec", 0.08),
45
+ is_regression_task=config.get("is_regression_task", False),
46
+ cal_labels_occurrence=config.get("cal_labels_occurrence", False),
47
+ zero_spec_db_val=config.get("zero_spec_db_val", -16.635),
48
+ max_duration=config.get('max_duration', None),
49
+ min_duration=config.get('min_duration', None),
50
+ )
51
+ return dataset
52
+
53
+
54
+ def get_feature_multi_label_dataset(
55
+ config: dict, augmentor: Optional['FeatureAugmentor'] = None
56
+ ) -> feature_to_label.FeatureToMultiLabelDataset:
57
+ dataset = feature_to_label.FeatureToMultiLabelDataset(
58
+ manifest_filepath=config['manifest_filepath'],
59
+ labels=config['labels'],
60
+ augmentor=augmentor,
61
+ delimiter=config.get('delimiter', None),
62
+ is_regression_task=config.get("is_regression_task", False),
63
+ cal_labels_occurrence=config.get("cal_labels_occurrence", False),
64
+ zero_spec_db_val=config.get("zero_spec_db_val", -16.635),
65
+ max_duration=config.get('max_duration', None),
66
+ min_duration=config.get('min_duration', None),
67
+ )
68
+ return dataset
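A usage sketch for the factory functions in this file, assuming an omegaconf (or plain dict) config whose keys mirror the ones read above; all values are placeholders.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "manifest_filepath": "/data/feat_label_manifest.json",
    "labels": ["noise", "speech"],
    "window_length_in_sec": 0.63,
    "shift_length_in_sec": 0.08,
})
dataset = get_feature_label_dataset(config=cfg, augmentor=None)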
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text.py ADDED
@@ -0,0 +1,488 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Callable, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+
19
+ from nemo.collections.asr.data.feature_to_label import _audio_feature_collate_fn
20
+ from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader
21
+ from nemo.collections.asr.parts.preprocessing.features import normalize_batch
22
+ from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
23
+ from nemo.collections.asr.parts.utils.vad_utils import load_speech_segments_from_rttm
24
+ from nemo.collections.common import tokenizers
25
+ from nemo.collections.common.parts.preprocessing import collections, parsers
26
+ from nemo.core.classes import Dataset
27
+ from nemo.core.neural_types import AcousticEncodedRepresentation, LabelsType, LengthsType, NeuralType
28
+
29
+
30
+ class ASRFeatureManifestProcessor:
31
+ def __init__(
32
+ self,
33
+ manifest_filepath: str,
34
+ parser: Union[str, Callable],
35
+ max_duration: Optional[float] = None,
36
+ min_duration: Optional[float] = None,
37
+ max_utts: int = 0,
38
+ bos_id: Optional[int] = None,
39
+ eos_id: Optional[int] = None,
40
+ pad_id: int = 0,
41
+ index_by_file_id: bool = False,
42
+ ):
43
+ self.parser = parser
44
+ self.collection = collections.ASRFeatureText(
45
+ manifests_files=manifest_filepath,
46
+ parser=parser,
47
+ min_duration=min_duration,
48
+ max_duration=max_duration,
49
+ max_number=max_utts,
50
+ index_by_file_id=index_by_file_id,
51
+ )
52
+
53
+ self.eos_id = eos_id
54
+ self.bos_id = bos_id
55
+ self.pad_id = pad_id
56
+
57
+ def process_text_by_id(self, index: int) -> Tuple[List[int], int]:
58
+ sample = self.collection[index]
59
+ return self.process_text_by_sample(sample)
60
+
61
+ def process_text_by_file_id(self, file_id: str) -> Tuple[List[int], int]:
62
+ manifest_idx = self.collection.mapping[file_id][0]
63
+ sample = self.collection[manifest_idx]
64
+ return self.process_text_by_sample(sample)
65
+
66
+ def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -> Tuple[List[int], int]:
67
+ t, tl = sample.text_tokens, len(sample.text_tokens)
68
+
69
+ if self.bos_id is not None:
70
+ t = [self.bos_id] + t
71
+ tl += 1
72
+ if self.eos_id is not None:
73
+ t = t + [self.eos_id]
74
+ tl += 1
75
+
76
+ return t, tl
77
+
78
+
79
+ class _FeatureTextDataset(Dataset):
80
+ """
81
+ Dataset that loads tensors via a json file containing paths to audio feature files, transcripts,
82
+ durations (in seconds) and optional RTTM files. Each new line is a different sample. Example below:
83
+ {"feature_filepath": "/path/to/audio_feature.pt", "text_filepath": "/path/to/audio.txt",
84
+ "rttm_filepath": "/path/to/audio_rttm.rttm", "duration": 23.147}
85
+ ...
86
+ {"feature_filepath": "/path/to/audio_feature.pt", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
87
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
88
+ Args:
89
+ manifest_filepath (str): Path to manifest json as described above. Can be comma-separated paths.
90
+ parser: Str for a language specific preprocessor or a callable.
91
+ normalize (Optional[str]): whether and where to normalize the features; must be one of [None, "post_norm", "pre_norm"]
92
+ normalize_type (Union[str, dict]): how to normalize feature, see `nemo.collections.asr.parts.preprocessing.features.normalize_batch`
93
+ use_rttm (bool): whether to use RTTM files if there is any, default to False
94
+ rttm_mode (str): how to use RTTM files, must be one of ['mask', 'drop'], default to 'mask'
95
+ feat_min_len (int): minimum length of feature when rttm_mode="drop", default to 4.
96
+ feat_mask_val (Optional[float]): value used to mask features with RTTM files, default to None to use the zero mel-spectrogram value
97
+ frame_unit_time_secs (float): time in seconds for each frame
98
+ sample_rate (int): Sample rate to resample loaded audio to
99
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
100
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor object used to augment loaded audio
101
+ max_duration (float): If audio exceeds this length, do not include in dataset
102
+ min_duration (float): If audio is less than this length, do not include in dataset
103
+ max_utts (int): Limit number of utterances
104
+ trim (bool): whether or not to trim silence. Defaults to False
105
+ bos_id (int): Id of beginning of sequence symbol to append if not None
106
+ eos_id (int): Id of end of sequence symbol to append if not None
107
+ pad_id (int): Id of pad symbol. Defaults to 0
108
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
109
+ channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
110
+ """
111
+
112
+ ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
113
+ NORM_MODES = ["pre_norm", "post_norm"]
114
+ RTTM_MODES = ["mask", "drop"]
115
+
116
+ @property
117
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
118
+ """Returns definitions of module output ports.
119
+ """
120
+ return {
121
+ 'features': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
122
+ 'feature_length': NeuralType(tuple('B'), LengthsType()),
123
+ 'transcripts': NeuralType(('B', 'T'), LabelsType()),
124
+ 'transcript_length': NeuralType(tuple('B'), LengthsType()),
125
+ 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
126
+ }
127
+
128
+ def __init__(
129
+ self,
130
+ manifest_filepath: str,
131
+ parser: Union[str, Callable],
132
+ normalize: Optional[str] = "post_norm",
133
+ normalize_type: Union[str, dict] = "per_feature",
134
+ use_rttm: bool = False,
135
+ rttm_mode: str = "mask",
136
+ feat_min_len: int = 4,
137
+ feat_mask_val: Optional[float] = None,
138
+ frame_unit_time_secs: float = 0.01,
139
+ sample_rate: Optional[int] = 16000,
140
+ augmentor: 'nemo.collections.asr.parts.perturb.FeatureAugmentor' = None,
141
+ max_duration: Optional[int] = None,
142
+ min_duration: Optional[int] = None,
143
+ max_utts: int = 0,
144
+ trim: bool = False,
145
+ bos_id: Optional[int] = None,
146
+ eos_id: Optional[int] = None,
147
+ pad_id: int = 0,
148
+ return_sample_id: bool = False,
149
+ channel_selector: Optional[ChannelSelectorType] = None,
150
+ ):
151
+ if type(manifest_filepath) == str:
152
+ manifest_filepath = manifest_filepath.split(",")
153
+
154
+ self.sample_rate = sample_rate
155
+ self.normalize = normalize
156
+ self.normalize_type = normalize_type
157
+ self.use_rttm = use_rttm
158
+ self.rttm_mode = rttm_mode
159
+ if self.use_rttm and self.rttm_mode not in self.RTTM_MODES:
160
+ raise ValueError(f"`rttm_mode` must be one of {self.RTTM_MODES}, got `{rttm_mode}` instead")
161
+
162
+ self.feat_min_len = feat_min_len
163
+ if feat_mask_val is not None:
164
+ self.feat_mask_val = feat_mask_val
165
+ elif normalize == "pre_norm":
166
+ self.feat_mask_val = 0.0 # similar to SpectralAugmentation
167
+ else:
168
+ self.feat_mask_val = self.ZERO_LEVEL_SPEC_DB_VAL
169
+
170
+ if normalize is not None and normalize not in self.NORM_MODES:
171
+ raise ValueError(f"`normalize` must be one of {self.NORM_MODES}, got `{normalize}` instead")
172
+
173
+ self.frame_unit_time_secs = frame_unit_time_secs
174
+
175
+ self.manifest_processor = ASRFeatureManifestProcessor(
176
+ manifest_filepath=manifest_filepath,
177
+ parser=parser,
178
+ max_duration=max_duration,
179
+ min_duration=min_duration,
180
+ max_utts=max_utts,
181
+ bos_id=bos_id,
182
+ eos_id=eos_id,
183
+ pad_id=pad_id,
184
+ )
185
+ self.featurizer = ExternalFeatureLoader(augmentor=augmentor)
186
+ self.trim = trim
187
+ self.return_sample_id = return_sample_id
188
+ self.channel_selector = channel_selector
189
+
190
+ def get_manifest_sample(self, sample_id):
191
+ return self.manifest_processor.collection[sample_id]
192
+
193
+ def __getitem__(self, index):
194
+ sample = self.manifest_processor.collection[index]
195
+ offset = sample.offset
196
+
197
+ if offset is None:
198
+ offset = 0
199
+
200
+ features = self.featurizer.process(sample.feature_file)
201
+
202
+ f, fl = features, torch.tensor(features.shape[1]).long()
203
+
204
+ t, tl = self.manifest_processor.process_text_by_sample(sample=sample)
205
+
206
+ # Feature normalization
207
+ if self.normalize is None:
208
+ if self.use_rttm and sample.rttm_file:
209
+ f = self.process_features_with_rttm(f, offset, sample.rttm_file, self.feat_mask_val)
210
+ elif self.normalize == "post_norm":
211
+ # (Optional) Masking based on RTTM file
212
+ if self.use_rttm and sample.rttm_file:
213
+ f = self.process_features_with_rttm(f, offset, sample.rttm_file, self.feat_mask_val)
214
+
215
+ f = self.normalize_feature(f)
216
+ else: # pre-norm
217
+ f = self.normalize_feature(f)
218
+ # (Optional) Masking based on RTTM file
219
+ if self.use_rttm and sample.rttm_file:
220
+ f = self.process_features_with_rttm(f, offset, sample.rttm_file, self.feat_mask_val)
221
+
222
+ if self.return_sample_id:
223
+ output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
224
+ else:
225
+ output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
226
+
227
+ return output
228
+
229
+ def process_features_with_rttm(self, features, offset, rttm_file, mask_val):
230
+ segments = load_speech_segments_from_rttm(rttm_file)
231
+ new_features = features.clone()
232
+ sid, fid = 0, 0
233
+ for i in range(features.size(1)):
234
+ t = offset + i * self.frame_unit_time_secs
235
+ while sid < len(segments) - 1 and segments[sid][1] < t:
236
+ sid += 1
237
+ if segments[sid][1] == 0 or t < segments[sid][0] or t > segments[sid][1]:
238
+ # not in speech segment
239
+ if self.rttm_mode == "drop":
240
+ # drop the frame
241
+ continue
242
+ else:
243
+ # mask the frame with specified value
244
+ new_features[:, i] = mask_val
245
+ fid += 1
246
+ else:
247
+ # in speech segment
248
+ new_features[:, fid] = features[:, i]
249
+ fid += 1
250
+
251
+ if fid < self.feat_min_len and self.rttm_mode == "drop":
252
+ new_features[:, : self.feat_min_len] = mask_val
253
+ return new_features[:, : self.feat_min_len]
254
+ return new_features[:, :fid]
255
+
256
+ def __len__(self):
257
+ return len(self.manifest_processor.collection)
258
+
259
+ def _collate_fn(self, batch):
260
+ return _audio_feature_collate_fn(
261
+ batch, feat_pad_val=self.feat_mask_val, label_pad_id=self.manifest_processor.pad_id
262
+ )
263
+
264
+ def normalize_feature(self, feat):
265
+ """
266
+ Args:
267
+ feat: feature tensor of shape [M, T]
268
+ """
269
+ feat = feat.unsqueeze(0) # add batch dim
270
+ feat, _, _ = normalize_batch(feat, torch.tensor([feat.size(-1)]), self.normalize_type)
271
+ return feat.squeeze(0) # delete batch dim
272
+
273
+
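To summarize the RTTM handling above with a concrete manifest line (paths are placeholders): with use_rttm=True and rttm_mode="mask", frames outside the RTTM speech segments are overwritten with feat_mask_val, while rttm_mode="drop" removes them (falling back to feat_min_len masked frames if nothing is left).

import json

manifest_line = {
    "feature_filepath": "/feats/utt0.pt",      # precomputed [D, T] features
    "text": "the transcription",
    "duration": 1.86,
    "rttm_filepath": "/rttm/utt0.rttm",        # optional; only read when use_rttm=True
}
print(json.dumps(manifest_line))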
274
+ class FeatureToCharDataset(_FeatureTextDataset):
275
+ """
276
+ Dataset that loads tensors via a json file containing paths to audio feature
277
+ files, transcripts, durations (in seconds) and optional RTTM files. Each new line is a
278
+ different sample. Example below:
279
+ {"feature_filepath": "/path/to/audio_feature.pt", "text_filepath":
280
+ "/path/to/audio.txt", "duration": 23.147, "rttm_filepath": "/path/to/audio_rttm.rttm",}
281
+ ...
282
+ {"feature_filepath": "/path/to/audio_feature.pt", "text": "the
283
+ transcription", "offset": 301.75, "duration": 0.82, "utt":
284
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
285
+
286
+ Args:
287
+ manifest_filepath (str): Path to manifest json as described above. Can
288
+ be comma-separated paths.
289
+ labels (str): String containing all the possible characters to map to
290
+ normalize (str): how to normalize feature, must be one of [None, "post_norm", "pre_norm"]
291
+ normalize_type (Union[str, dict]): how to normalize feature, see `nemo.collections.asr.parts.preprocessing.features.normalize_batch`
292
+ use_rttm (bool): whether to use RTTM files if there is any, default to False
293
+ rttm_mode (str): how to use RTTM files, must be one of ['mask', 'drop'], default to 'mask'
294
+ feat_min_len (int): minimum length of feature, default to 4
295
+ feat_mask_val (Optional[float]): value used to mask features with RTTM files, default to None to use the zero mel-spectrogram value
296
+ frame_unit_time_secs: time in seconds for each frame
297
+ sample_rate (int): Sample rate to resample loaded audio to
298
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
299
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
300
+ object used to augment loaded audio
301
+ max_duration: If audio exceeds this length, do not include in dataset
302
+ min_duration: If audio is less than this length, do not include
303
+ in dataset
304
+ max_utts: Limit number of utterances
305
+ blank_index: blank character index, default = -1
306
+ unk_index: unk_character index, default = -1
307
+ bos_id: Id of beginning of sequence symbol to append if not None
308
+ eos_id: Id of end of sequence symbol to append if not None
309
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
310
+ channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
311
+ """
312
+
313
+ def __init__(
314
+ self,
315
+ manifest_filepath: str,
316
+ labels: Union[str, List[str]],
317
+ normalize: Optional[str] = "post_norm",
318
+ normalize_type: Union[str, dict] = "per_feature",
319
+ use_rttm: bool = False,
320
+ rttm_mode: str = "mask",
321
+ feat_min_len: int = 4,
322
+ feat_mask_val: Optional[float] = None,
323
+ frame_unit_time_secs: float = 0.01,
324
+ sample_rate: Optional[int] = 16000,
325
+ augmentor: 'nemo.collections.asr.parts.perturb.FeatureAugmentor' = None,
326
+ max_duration: Optional[int] = None,
327
+ min_duration: Optional[int] = None,
328
+ max_utts: int = 0,
329
+ blank_index: int = -1,
330
+ unk_index: int = -1,
331
+ trim: bool = False,
332
+ bos_id: Optional[int] = None,
333
+ eos_id: Optional[int] = None,
334
+ pad_id: int = 0,
335
+ parser: Union[str, Callable] = 'en',
336
+ return_sample_id: bool = False,
337
+ channel_selector: Optional[ChannelSelectorType] = None,
338
+ ):
339
+ self.labels = labels
340
+
341
+ parser = parsers.make_parser(
342
+ labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize
343
+ )
344
+
345
+ super().__init__(
346
+ manifest_filepath=manifest_filepath,
347
+ parser=parser,
348
+ normalize=normalize,
349
+ normalize_type=normalize_type,
350
+ use_rttm=use_rttm,
351
+ rttm_mode=rttm_mode,
352
+ feat_min_len=feat_min_len,
353
+ feat_mask_val=feat_mask_val,
354
+ frame_unit_time_secs=frame_unit_time_secs,
355
+ sample_rate=sample_rate,
356
+ augmentor=augmentor,
357
+ max_duration=max_duration,
358
+ min_duration=min_duration,
359
+ max_utts=max_utts,
360
+ trim=trim,
361
+ bos_id=bos_id,
362
+ eos_id=eos_id,
363
+ pad_id=pad_id,
364
+ return_sample_id=return_sample_id,
365
+ channel_selector=channel_selector,
366
+ )
367
+
368
+
369
+ class FeatureToBPEDataset(_FeatureTextDataset):
370
+ """
371
+ Dataset that loads tensors via a json file containing paths to audio feature
372
+ files, transcripts, durations (in seconds) and optional RTTM files. Each new line is a different sample.
373
+ Example below:
374
+ {"feature_filepath": "/path/to/audio_feature.pt", "text_filepath":
375
+ "/path/to/audio.txt", "duration": 23.147, "rttm_filepath": "/path/to/audio_rttm.rttm",}
376
+ ...
377
+ {"feature_filepath": "/path/to/audio_feature.pt", "text": "the
378
+ transcription", "offset": 301.75, "duration": 0.82, "utt":
379
+ "utterance_id", "ctm_utt": "en_4156", "side": "A"}
380
+
381
+ In practice, the dataset and manifest used for character encoding and byte pair encoding
382
+ are exactly the same. The only difference lies in how the dataset tokenizes the text in
383
+ the manifest.
384
+
385
+ Args:
386
+ manifest_filepath (str): Path to manifest json as described above. Can
387
+ be comma-separated paths.
388
+ tokenizer: A subclass of the Tokenizer wrapper found in the common collection,
389
+ nemo.collections.common.tokenizers.TokenizerSpec. ASR Models support a subset of
390
+ all available tokenizers.
391
+ normalize (str): how to normalize feature, must be one of [None, "post_norm", "pre_norm"]
392
+ normalize_type (Union[str, dict]): how to normalize feature, see `nemo.collections.asr.parts.preprocessing.features.normalize_batch`
393
+ use_rttm (bool): whether to use RTTM files if there is any, default to False
394
+ rttm_mode (str): how to use RTTM files, must be one of ['mask', 'drop'], default to 'mask'
395
+ feat_min_len (int): minimum length of feature, default to 4
396
+ feat_mask_val (Optional[float]): value used to mask features with RTTM files, default to None to use the zero mel-spectrogram value
397
+ frame_unit_time_secs: time in seconds for each frame
398
+ sample_rate (int): Sample rate to resample loaded audio to
399
+ int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
400
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
401
+ object used to augment loaded audio
402
+ max_duration: If audio exceeds this length, do not include in dataset
403
+ min_duration: If audio is less than this length, do not include
404
+ in dataset
405
+ max_utts: Limit number of utterances
406
+ trim: Whether to trim silence segments
407
+ use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
408
+ tokens to beginning and ending of speech respectively.
409
+ return_sample_id (bool): whether to return the sample_id as a part of each sample
410
+ channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
411
+ """
412
+
413
+ def __init__(
414
+ self,
415
+ manifest_filepath: str,
416
+ tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
417
+ normalize: Optional[str] = "post_norm",
418
+ normalize_type: Union[str, dict] = "per_feature",
419
+ use_rttm: bool = False,
420
+ rttm_mode: str = "mask",
421
+ feat_min_len: int = 4,
422
+ feat_mask_val: Optional[float] = None,
423
+ frame_unit_time_secs: float = 0.01,
424
+ sample_rate: Optional[int] = 16000,
425
+ augmentor: 'nemo.collections.asr.parts.perturb.FeatureAugmentor' = None,
426
+ max_duration: Optional[int] = None,
427
+ min_duration: Optional[int] = None,
428
+ max_utts: int = 0,
429
+ use_start_end_token: bool = True,
430
+ trim: bool = False,
431
+ return_sample_id: bool = False,
432
+ channel_selector: Optional[ChannelSelectorType] = None,
433
+ ):
434
+ if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
435
+ bos_id = tokenizer.bos_id
436
+ else:
437
+ bos_id = None
438
+
439
+ if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0:
440
+ eos_id = tokenizer.eos_id
441
+ else:
442
+ eos_id = None
443
+
444
+ if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0:
445
+ pad_id = tokenizer.pad_id
446
+ else:
447
+ pad_id = 0
448
+
449
+ class TokenizerWrapper:
450
+ def __init__(self, tokenizer):
451
+ if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer):
452
+ self.is_aggregate = True
453
+ else:
454
+ self.is_aggregate = False
455
+ self._tokenizer = tokenizer
456
+
457
+ def __call__(self, *args):
458
+ if isinstance(args[0], List) and self.is_aggregate:
459
+ t = []
460
+ for span in args[0]:
461
+ t.extend(self._tokenizer.text_to_ids(span['str'], span['lang']))
462
+ return t
463
+
464
+ t = self._tokenizer.text_to_ids(*args)
465
+ return t
466
+
467
+ super().__init__(
468
+ manifest_filepath=manifest_filepath,
469
+ parser=TokenizerWrapper(tokenizer),
470
+ normalize=normalize,
471
+ normalize_type=normalize_type,
472
+ use_rttm=use_rttm,
473
+ rttm_mode=rttm_mode,
474
+ feat_min_len=feat_min_len,
475
+ feat_mask_val=feat_mask_val,
476
+ frame_unit_time_secs=frame_unit_time_secs,
477
+ sample_rate=sample_rate,
478
+ augmentor=augmentor,
479
+ max_duration=max_duration,
480
+ min_duration=min_duration,
481
+ max_utts=max_utts,
482
+ trim=trim,
483
+ bos_id=bos_id,
484
+ eos_id=eos_id,
485
+ pad_id=pad_id,
486
+ return_sample_id=return_sample_id,
487
+ channel_selector=channel_selector,
488
+ )
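A hedged construction sketch for FeatureToBPEDataset; the SentencePiece model path is a placeholder, and SentencePieceTokenizer is assumed to be available from nemo.collections.common.tokenizers.

from nemo.collections.common.tokenizers import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer(model_path="/models/tokenizer.model")   # placeholder path
bpe_ds = FeatureToBPEDataset(
    manifest_filepath="/data/feat_text_manifest.json",
    tokenizer=tokenizer,
    normalize="post_norm",
    use_rttm=False,
)
f, fl, tokens, token_len = bpe_ds[0]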
SoundScribe/SpeakerID/nemo/collections/asr/data/feature_to_text_dataset.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from nemo.collections.asr.data.feature_to_text import FeatureToBPEDataset, FeatureToCharDataset
18
+ from nemo.utils import logging
19
+
20
+
21
+ def get_char_dataset(config: dict, augmentor: Optional['FeatureAugmentor'] = None) -> FeatureToCharDataset:
22
+ """
23
+ Instantiates a Character Encoding based FeatureToCharDataset.
24
+
25
+ Args:
26
+ config: Config of the FeatureToCharDataset.
27
+ augmentor: Optional FeatureAugmentor object for augmentations on feature data.
28
+
29
+ Returns:
30
+ An instance of FeatureToCharDataset.
31
+ """
32
+ if 'labels' not in config:
33
+ logging.warning("dataset does not have explicitly defined labels")
34
+
35
+ dataset = FeatureToCharDataset(
36
+ manifest_filepath=config['manifest_filepath'],
37
+ labels=config.get('labels', None),
38
+ normalize=config.get('normalize', 'post_norm'),
39
+ normalize_type=config.get('normalize_type', 'per_feature'),
40
+ use_rttm=config.get('use_rttm', False),
41
+ rttm_mode=config.get('rttm_mode', 'mask'),
42
+ feat_min_len=config.get('feat_min_len', 4),
43
+ feat_mask_val=config.get('feat_mask_val', None),
44
+ frame_unit_time_secs=config.get('frame_unit_time_secs', 0.01),
45
+ sample_rate=config.get('sample_rate', 16000),
46
+ augmentor=augmentor,
47
+ max_duration=config.get('max_duration', None),
48
+ min_duration=config.get('min_duration', None),
49
+ max_utts=config.get('max_utts', 0),
50
+ blank_index=config.get('blank_index', -1),
51
+ unk_index=config.get('unk_index', -1),
52
+ trim=config.get('trim_silence', False),
53
+ parser=config.get('parser', 'en'),
54
+ return_sample_id=config.get('return_sample_id', False),
55
+ channel_selector=config.get('channel_selector', None),
56
+ )
57
+ return dataset
58
+
59
+
60
+ def get_bpe_dataset(
61
+ config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['FeatureAugmentor'] = None
62
+ ) -> FeatureToBPEDataset:
63
+ """
64
+ Instantiates a Byte Pair Encoding / Word Piece Encoding based FeatureToBPEDataset.
65
+
66
+ Args:
67
+ config: Config of the FeatureToBPEDataset.
68
+ tokenizer: An instance of a TokenizerSpec object.
69
+ augmentor: Optional FeatureAugmentor object for augmentations on audio features.
70
+
71
+ Returns:
72
+ An instance of FeatureToBPEDataset.
73
+ """
74
+ dataset = FeatureToBPEDataset(
75
+ manifest_filepath=config['manifest_filepath'],
76
+ tokenizer=tokenizer,
77
+ normalize=config.get('normalize', 'post_norm'),
78
+ normalize_type=config.get('normalize_type', 'per_feature'),
79
+ use_rttm=config.get('use_rttm', False),
80
+ rttm_mode=config.get('rttm_mode', 'mask'),
81
+ feat_min_len=config.get('feat_min_len', 4),
82
+ feat_mask_val=config.get('feat_mask_val', None),
83
+ frame_unit_time_secs=config.get('frame_unit_time_secs', 0.01),
84
+ sample_rate=config.get('sample_rate', 16000),
85
+ augmentor=augmentor,
86
+ max_duration=config.get('max_duration', None),
87
+ min_duration=config.get('min_duration', None),
88
+ max_utts=config.get('max_utts', 0),
89
+ trim=config.get('trim_silence', False),
90
+ use_start_end_token=config.get('use_start_end_token', True),
91
+ return_sample_id=config.get('return_sample_id', False),
92
+ channel_selector=config.get('channel_selector', None),
93
+ )
94
+ return dataset
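A usage sketch for the two factories above with dict-style configs (keys as read here; values are placeholders; the tokenizer is assumed to be built as in feature_to_text.py).

char_cfg = {
    "manifest_filepath": "/data/feat_text_manifest.json",
    "labels": [" ", "a", "b", "c"],     # placeholder character vocabulary
    "use_rttm": True,
    "rttm_mode": "drop",
}
char_ds = get_char_dataset(config=char_cfg, augmentor=None)

bpe_cfg = {"manifest_filepath": "/data/feat_text_manifest.json", "use_start_end_token": True}
bpe_ds = get_bpe_dataset(config=bpe_cfg, tokenizer=tokenizer, augmentor=None)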
SoundScribe/SpeakerID/nemo/collections/asr/data/text_to_text.py ADDED
@@ -0,0 +1,482 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import concurrent.futures
18
+ import copy
19
+ import gc
20
+ import json
21
+ import math
22
+ import random
23
+ from pathlib import Path
24
+ from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.utils.data
29
+ from torch.nn.utils.rnn import pad_sequence
30
+ from tqdm.auto import tqdm
31
+
32
+ from nemo.collections.asr.data.audio_to_text import _speech_collate_fn
33
+ from nemo.collections.common.tokenizers import TokenizerSpec
34
+ from nemo.core.classes import Dataset, IterableDataset
35
+ from nemo.utils import logging
36
+
37
+ try:
38
+ from nemo_text_processing.text_normalization.normalize import Normalizer
39
+ except Exception as e:
40
+ pass # Normalizer imported only for annotation purposes, error can be ignored
41
+
42
+ AnyPath = Union[Path, str]
43
+
44
+
45
+ class TextToTextItem(NamedTuple):
46
+ tts_text: torch.Tensor # normalized and tokenized text for TTS
47
+ transcript: torch.Tensor # tokenized text for ASR
48
+ speaker: int # speaker id for multi-speaker TTS
49
+
50
+
51
+ class TextToTextBatch(NamedTuple):
52
+ tts_texts: torch.Tensor # tokenized texts for tts
53
+ tts_text_lengths: torch.Tensor
54
+ transcripts: torch.Tensor # tokenized texts for ASR
55
+ transcript_lengths: torch.Tensor
56
+ speakers: torch.Tensor # speaker ids for multi-speaker TTS
57
+
58
+ @staticmethod
59
+ def collate_fn(batch: List[TextToTextItem], asr_pad_id: int, tts_text_pad_id: int) -> TextToTextBatch:
60
+ return TextToTextBatch(
61
+ tts_texts=pad_sequence([item.tts_text for item in batch], batch_first=True, padding_value=tts_text_pad_id),
62
+ tts_text_lengths=torch.tensor([item.tts_text.shape[0] for item in batch]).long(),
63
+ transcripts=pad_sequence([item.transcript for item in batch], batch_first=True, padding_value=asr_pad_id),
64
+ transcript_lengths=torch.tensor([item.transcript.shape[0] for item in batch]).long(),
65
+ speakers=torch.tensor([item.speaker for item in batch]).long(),
66
+ )
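
A toy check of the padding behaviour above (illustrative only; the classes used are the ones defined in this file):

```python
import torch

# Two items of different lengths are padded to a common length with the
# supplied pad ids; the length tensors record the original sizes.
items = [
    TextToTextItem(tts_text=torch.tensor([5, 6, 7]), transcript=torch.tensor([1, 2]), speaker=3),
    TextToTextItem(tts_text=torch.tensor([8]), transcript=torch.tensor([4, 5, 6]), speaker=7),
]
batch = TextToTextBatch.collate_fn(items, asr_pad_id=0, tts_text_pad_id=0)
assert batch.tts_texts.shape == (2, 3)              # padded to the longest TTS text
assert batch.transcript_lengths.tolist() == [2, 3]  # original transcript lengths
```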
67
+
68
+
69
+ class TextOrAudioToTextBatch(NamedTuple):
70
+ audio_signals: torch.Tensor
71
+ audio_signal_lengths: torch.Tensor
72
+ tts_texts: torch.Tensor
73
+ tts_text_lengths: torch.Tensor
74
+ speakers: torch.Tensor
75
+ transcripts: torch.Tensor
76
+ transcript_lengths: torch.Tensor
77
+
78
+ @staticmethod
79
+ def collate_fn(
80
+ batch: List[Union[TextToTextItem, tuple]], tts_text_pad_id: int, asr_pad_id: int
81
+ ) -> Union[TextToTextBatch, TextOrAudioToTextBatch, tuple]:
82
+ """
83
+ Collate function for dataloader
84
+ Can accept a mixed batch of text-to-text items and audio-text items (typical for ASR)
85
+ """
86
+ text_items: List[TextToTextItem] = [item for item in batch if isinstance(item, TextToTextItem)]
87
+ if not text_items:
88
+ # pure audio-text batch
89
+ return _speech_collate_fn(batch=batch, pad_id=asr_pad_id)
90
+
91
+ asr_items = [item for item in batch if not isinstance(item, TextToTextItem)]
92
+
93
+ if not asr_items:
94
+ # pure text-to-text batch
95
+ return TextToTextBatch.collate_fn(batch=text_items, asr_pad_id=asr_pad_id, tts_text_pad_id=tts_text_pad_id)
96
+
97
+ # mixed batch
98
+
99
+ # each asr item is a tuple:
100
+ # audio_signal (0), audio_length (1), transcript (2), transcript_length (3), sample_id (4, optional)
101
+ audio_signals = pad_sequence([item[0] for item in asr_items], batch_first=True, padding_value=0.0)
102
+ audio_signal_lengths = torch.tensor([item[1] for item in asr_items]).long()
103
+
104
+ tts_texts = pad_sequence(
105
+ [item.tts_text for item in text_items], batch_first=True, padding_value=tts_text_pad_id
106
+ )
107
+ tts_text_lengths = torch.tensor([item.tts_text.shape[0] for item in text_items]).long()
108
+ speakers = torch.tensor([item.speaker for item in text_items]).long()
109
+
110
+ transcripts = pad_sequence(
111
+ [item.transcript for item in text_items] + [item[2] for item in asr_items],
112
+ batch_first=True,
113
+ padding_value=asr_pad_id,
114
+ )
115
+ transcript_lengths = torch.tensor(
116
+ [item.transcript.shape[0] for item in text_items] + [item[3] for item in asr_items]
117
+ ).long()
118
+
119
+ return TextOrAudioToTextBatch(
120
+ audio_signals=audio_signals,
121
+ audio_signal_lengths=audio_signal_lengths,
122
+ tts_texts=tts_texts,
123
+ tts_text_lengths=tts_text_lengths,
124
+ speakers=speakers,
125
+ transcripts=transcripts,
126
+ transcript_lengths=transcript_lengths,
127
+ )
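
And a toy instance of the mixed path: one text-to-text item combined with one ASR-style tuple of (audio_signal, audio_length, transcript, transcript_length), matching the comment above; the numbers are placeholders.

```python
import torch

# Mixed batch: text transcripts are placed before ASR transcripts, as in the
# pad_sequence call above.
text_item = TextToTextItem(tts_text=torch.tensor([9, 9]), transcript=torch.tensor([1, 2, 3]), speaker=0)
asr_item = (torch.zeros(16000), 16000, torch.tensor([4, 5]), 2)
mixed = TextOrAudioToTextBatch.collate_fn([text_item, asr_item], tts_text_pad_id=0, asr_pad_id=0)
assert mixed.audio_signals.shape == (1, 16000)
assert mixed.transcript_lengths.tolist() == [3, 2]  # text item first, then ASR item
```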
128
+
129
+
130
+ def _asr_text_to_tokens(text: str) -> np.ndarray:
131
+ """
132
+ Helper function for asr tokenization with multiprocessing pool only.
133
+ Must be defined on the top level.
134
+ Expects asr_tokenizer_global, asr_bos_id_global, asr_eos_id_global to exist in the current pool process
135
+ """
136
+ ids = asr_tokenizer_global.text_to_ids(text)
137
+ if asr_bos_id_global is not None:
138
+ ids = [asr_bos_id_global] + ids
139
+ if asr_eos_id_global is not None:
140
+ ids.append(asr_eos_id_global)
141
+ return np.asarray(ids)
142
+
143
+
144
+ def _tts_text_to_tokens(text: str) -> np.ndarray:
145
+ """
146
+ Helper function for tts tokenization with multiprocessing pool only.
147
+ Must be defined on the top level.
148
+ Expects tts_tokenizer_global to exist in the current pool process
149
+ """
150
+ return np.asarray(tts_tokenizer_global(text))
151
+
152
+
153
+ def _iterate_manifest(filepath: AnyPath) -> Iterable[Dict[str, Any]]:
154
+ """
155
+ Helper function to iterate manifest
156
+ """
157
+ with open(filepath, "r", encoding="utf-8") as f:
158
+ for line in f:
159
+ record = json.loads(line)
160
+ yield record
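
For context, each manifest line is one JSON object. The loader in `TextToTextDatasetBase.__init__` below reads `text` for the ASR transcript and prefers `tts_text_normalized` over `tts_text` for the TTS side; a minimal sketch of writing such a line (values are placeholders):

```python
import json

record = {
    "text": "hello world",                 # ASR transcript, tokenized with asr_tokenizer
    "tts_text": "hello world",             # raw TTS text; triggers slow on-the-fly normalization
    "tts_text_normalized": "hello world",  # pre-normalized TTS text, used when present
}
with open("text_to_text_manifest.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")
```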
161
+
162
+
163
+ class TextToTextDatasetBase:
164
+ """
165
+ Base class for loading text-to-text manifests
166
+ Map-style and Iterable datasets should inherit this class
167
+ """
168
+
169
+ asr_pad_id: int
170
+ tts_text_pad_id: int
171
+ asr_bos_id: Optional[int] = None
172
+ asr_eos_id: Optional[int] = None
173
+ data: List[Dict[str, Any]]
174
+
175
+ def __init__(
176
+ self,
177
+ manifest_filepath: Union[AnyPath, List[AnyPath]],
178
+ speakers_filepath: Union[AnyPath, List[AnyPath]],
179
+ asr_tokenizer: TokenizerSpec,
180
+ asr_use_start_end_token: bool,
181
+ tts_parser: Callable,
182
+ tts_text_pad_id: int,
183
+ tts_text_normalizer: "Normalizer",
184
+ tts_text_normalizer_call_kwargs: Dict,
185
+ min_words: int = 1,
186
+ max_words: int = 1_000_000,
187
+ tokenizer_workers: int = 1,
188
+ num_parts: int = 1,
189
+ current_part_index: int = 0,
190
+ ):
191
+ super().__init__()
192
+ # ASR tokenizer setup
193
+ if asr_use_start_end_token and hasattr(asr_tokenizer, 'bos_token'):
194
+ self.asr_bos_id = asr_tokenizer.bos_id
195
+
196
+ if asr_use_start_end_token and hasattr(asr_tokenizer, 'eos_token'):
197
+ self.asr_eos_id = asr_tokenizer.eos_id
198
+
199
+ if hasattr(asr_tokenizer, 'pad_token'):
200
+ self.asr_pad_id = asr_tokenizer.pad_id
201
+ else:
202
+ self.asr_pad_id = 0
203
+
204
+ self.asr_tokenizer = asr_tokenizer
205
+
206
+ # TTS tokenizer setup
207
+ self.tts_parser = tts_parser
208
+ self.tts_normalizer = tts_text_normalizer
209
+ self.tts_normalizer_kwargs = tts_text_normalizer_call_kwargs
210
+ self.tts_text_pad_id = tts_text_pad_id
211
+
212
+ # Load speakers
213
+ if isinstance(speakers_filepath, str):
214
+ speakers_filepath = speakers_filepath.split(",")
215
+ elif isinstance(speakers_filepath, Path):
216
+ speakers_filepath = [speakers_filepath]
217
+ speakers: Set[int] = set()
218
+ for filepath in speakers_filepath:
219
+ with open(Path(filepath).expanduser(), "r") as f:
220
+ speakers.update(map(int, f.read().split()))
221
+ self.speakers = np.asarray(sorted(speakers))
222
+ logging.info(f"Loaded {len(self.speakers)} speakers")
223
+
224
+ # Load manifest
225
+ if isinstance(manifest_filepath, str):
226
+ manifest_filepath = manifest_filepath.split(",")
227
+ elif isinstance(manifest_filepath, Path):
228
+ manifest_filepath = [manifest_filepath]
229
+ self.manifest_paths = [Path(filepath) for filepath in manifest_filepath]
230
+
231
+ num_skipped_words = 0
232
+ num_skipped_utterances = 0
233
+ asr_texts = []
234
+ tts_texts = []
235
+ need_normalization = False
236
+
237
+ for manifest_path in self.manifest_paths:
238
+ for tmp_item in tqdm(_iterate_manifest(manifest_path)):
239
+ text = tmp_item["text"]
240
+ num_words = len(text.split())
241
+ # skip if the number of words is not in the desired range
242
+ # TODO: maybe it would be valuable to sample sub-utterances from long utterances
243
+ if not (min_words <= num_words <= max_words):
244
+ num_skipped_words += num_words
245
+ num_skipped_utterances += 1
246
+ continue
247
+ asr_texts.append(tmp_item["text"])
248
+ if "tts_text_normalized" in tmp_item:
249
+ tts_texts.append(tmp_item["tts_text_normalized"])
250
+ else:
251
+ tts_texts.append(tmp_item["tts_text"])
252
+ need_normalization = True
253
+
254
+ if need_normalization:
255
+ logging.warning("TTS normalization is extremely slow! It is recommended to normalize TTS text in advance")
256
+
257
+ if num_skipped_utterances:
258
+ logging.warning(f"Skipped {num_skipped_utterances} utterances with {num_skipped_words} words")
259
+
260
+ num_utterances = len(asr_texts)
261
+ # preprocessing is very costly; if we need only a part, remove the unnecessary utterances up front
262
+ if num_parts > 1:
263
+ # NB: floor division, full dataset can contain fewer utterances than original, like in tarred dataset
264
+ num_utterances_part = num_utterances // num_parts
265
+ start = num_utterances_part * current_part_index
266
+ end = start + num_utterances_part
267
+ logging.info(
268
+ f"Taking part {current_part_index} of {num_parts} of the dataset: utterances {start} to {end}"
269
+ )
270
+ asr_texts = asr_texts[start:end]
271
+ tts_texts = tts_texts[start:end]
272
+ num_utterances = num_utterances_part
273
+
274
+ self.data = [dict() for _ in range(num_utterances)]
275
+
276
+ if len(asr_texts) == 0:
277
+ # no data was loaded
278
+ logging.warning("Text-to-text dataset is empty")
279
+ return
280
+
281
+ if tokenizer_workers == 1:
282
+ logging.warning(
283
+ "Preprocessing large text with tokenizer_workers=1 may be slow with TTS tokenizer. "
284
+ "Prefer tokenizer_workers=(num_cpu_cores/num_gpus_per_node)"
285
+ )
286
+ for i, tokenized_text in enumerate(
287
+ tqdm((self._asr_text_to_tokens(text) for text in asr_texts), total=len(asr_texts))
288
+ ):
289
+ self.data[i]["asr_text_tokens"] = tokenized_text
290
+ else:
291
+ # Multiprocessing hack: use global variables for every process (not really global in program context)
292
+ def _init_asr_tokenize_process(tokenizer, bos_id, eos_id):
293
+ global asr_tokenizer_global, asr_bos_id_global, asr_eos_id_global # process-global
294
+ # deepcopy to avoid serialization of parent models
295
+ asr_tokenizer_global = copy.deepcopy(tokenizer)
296
+ asr_bos_id_global = copy.deepcopy(bos_id)
297
+ asr_eos_id_global = copy.deepcopy(eos_id)
298
+
299
+ with concurrent.futures.ProcessPoolExecutor(
300
+ initializer=_init_asr_tokenize_process,
301
+ initargs=(asr_tokenizer, self.asr_bos_id, self.asr_eos_id),
302
+ max_workers=tokenizer_workers,
303
+ ) as pool:
304
+ # chunk size for pool map is empirically chosen as a trade-off between speed and responsiveness
305
+ for i, tokenized_text in enumerate(
306
+ tqdm(pool.map(_asr_text_to_tokens, asr_texts, chunksize=1000), total=len(asr_texts))
307
+ ):
308
+ self.data[i]["asr_text_tokens"] = tokenized_text
309
+ # force free memory
310
+ del asr_texts
311
+ gc.collect()
312
+
313
+ if tokenizer_workers == 1:
314
+ logging.warning(
315
+ "Preprocessing large text with tokenizer_workers=1 may be slow with TTS tokenizer. "
316
+ "Prefer tokenizer_workers=(num_cpu_cores/num_gpus_per_node)"
317
+ )
318
+ for i, tokenized_text in enumerate(
319
+ tqdm(
320
+ (self._tts_text_to_tokens(text, normalize=need_normalization) for text in tts_texts),
321
+ total=len(tts_texts),
322
+ )
323
+ ):
324
+ self.data[i]["tts_text_tokens"] = tokenized_text
325
+ else:
326
+ if need_normalization:
327
+ # TODO: implement, if we really need normalization inplace
328
+ raise NotImplementedError(
329
+ "Normalization with tokenizer_workers > 1 is not implemented. "
330
+ "It is not recommended to use normalization on the fly at all, since it's extremely slow"
331
+ )
332
+
333
+ def _init_tts_tokenize_process(tokenizer):
334
+ global tts_tokenizer_global # process-global
335
+ tts_tokenizer_global = copy.deepcopy(tokenizer)
336
+
337
+ with concurrent.futures.ProcessPoolExecutor(
338
+ initializer=_init_tts_tokenize_process, initargs=(tts_parser,), max_workers=tokenizer_workers,
339
+ ) as pool:
340
+ # chunk size for pool map is empirically chosen as a trade-off between speed and responsiveness
341
+ for i, tokenized_text in enumerate(
342
+ tqdm(pool.map(_tts_text_to_tokens, tts_texts, chunksize=1000), total=len(tts_texts))
343
+ ):
344
+ self.data[i]["tts_text_tokens"] = tokenized_text
345
+ # force free memory
346
+ del tts_texts
347
+ gc.collect()
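
The `tokenizer_workers > 1` branches above rely on a process-global variable set once per worker by the pool initializer, so the tokenizer is not re-pickled for every task. A stripped-down sketch of that pattern, independent of NeMo (the `_init`/`_work` names and the integer payload are illustrative only):

```python
import concurrent.futures


def _init(offset):
    # Runs once in each worker process; stores the heavyweight object
    # (here just an int) as a module-level global for later calls.
    global offset_global
    offset_global = offset


def _work(x):
    return x + offset_global


if __name__ == "__main__":
    with concurrent.futures.ProcessPoolExecutor(initializer=_init, initargs=(10,), max_workers=2) as pool:
        print(list(pool.map(_work, range(5), chunksize=2)))  # [10, 11, 12, 13, 14]
```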
348
+
349
+ def _asr_text_to_tokens(self, text: str) -> np.ndarray:
350
+ ids = self.asr_tokenizer.text_to_ids(text)
351
+ if self.asr_bos_id is not None:
352
+ ids = [self.asr_bos_id] + ids
353
+ if self.asr_eos_id is not None:
354
+ ids.append(self.asr_eos_id)
355
+ return np.asarray(ids)
356
+
357
+ def _tts_text_to_tokens(self, text: str, normalize=True) -> np.ndarray:
358
+ if normalize:
359
+ text = self.tts_normalizer.normalize(text, **self.tts_normalizer_kwargs)
360
+ tokens = self.tts_parser(text)
361
+ return np.asarray(tokens)
362
+
363
+ def __getitem__(self, index):
364
+ item = self.data[index]
365
+ return TextToTextItem(
366
+ transcript=torch.from_numpy(item["asr_text_tokens"]).long(),
367
+ tts_text=torch.from_numpy(item["tts_text_tokens"]).long(),
368
+ speaker=random.choice(self.speakers),
369
+ )
370
+
371
+ def __len__(self):
372
+ return len(self.data)
373
+
374
+
375
+ class TextToTextDataset(TextToTextDatasetBase, Dataset):
376
+ """Text-to-Text Map-style Dataset for hybrid ASR-TTS models"""
377
+
378
+ def __init__(
379
+ self,
380
+ manifest_filepath: Union[AnyPath, List[AnyPath]],
381
+ speakers_filepath: Union[AnyPath, List[AnyPath]],
382
+ asr_tokenizer: TokenizerSpec,
383
+ asr_use_start_end_token: bool,
384
+ tts_parser: Callable,
385
+ tts_text_pad_id: int,
386
+ tts_text_normalizer: "Normalizer",
387
+ tts_text_normalizer_call_kwargs: Dict,
388
+ min_words: int = 1,
389
+ max_words: int = 1_000_000,
390
+ tokenizer_workers: int = 1,
391
+ ):
392
+ super().__init__(
393
+ manifest_filepath=manifest_filepath,
394
+ speakers_filepath=speakers_filepath,
395
+ asr_tokenizer=asr_tokenizer,
396
+ asr_use_start_end_token=asr_use_start_end_token,
397
+ tts_parser=tts_parser,
398
+ tts_text_pad_id=tts_text_pad_id,
399
+ tts_text_normalizer=tts_text_normalizer,
400
+ tts_text_normalizer_call_kwargs=tts_text_normalizer_call_kwargs,
401
+ min_words=min_words,
402
+ max_words=max_words,
403
+ tokenizer_workers=tokenizer_workers,
404
+ num_parts=1,
405
+ )
406
+
407
+ def collate_fn(
408
+ self, batch: List[Union[TextToTextItem, tuple]]
409
+ ) -> Union[TextToTextBatch, TextOrAudioToTextBatch, tuple]:
410
+ """
411
+ Collate function for dataloader
412
+ Can accept a mixed batch of text-to-text items and audio-text items (typical for ASR)
413
+ """
414
+ return TextOrAudioToTextBatch.collate_fn(
415
+ batch=batch, asr_pad_id=self.asr_pad_id, tts_text_pad_id=self.tts_text_pad_id
416
+ )
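
A hedged wiring sketch for the map-style dataset: `asr_tokenizer`, `tts_parser`, and `tts_normalizer` are assumed to come from an existing hybrid ASR-TTS setup and are only named here for illustration; the file paths are placeholders.

```python
from torch.utils.data import DataLoader

dataset = TextToTextDataset(
    manifest_filepath="text_to_text_manifest.json",  # placeholder manifest path
    speakers_filepath="speakers.txt",                # one integer speaker id per line
    asr_tokenizer=asr_tokenizer,                     # assumed TokenizerSpec instance
    asr_use_start_end_token=False,
    tts_parser=tts_parser,                           # assumed TTS text-to-token callable
    tts_text_pad_id=0,
    tts_text_normalizer=tts_normalizer,              # assumed nemo_text_processing Normalizer
    tts_text_normalizer_call_kwargs={},
)
loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
```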
417
+
418
+
419
+ class TextToTextIterableDataset(TextToTextDatasetBase, IterableDataset):
420
+ """
421
+ Text-to-Text Iterable Dataset for hybrid ASR-TTS models
422
+ Only the part of the data necessary for the current process is loaded and stored
423
+ """
424
+
425
+ def __init__(
426
+ self,
427
+ manifest_filepath: Union[AnyPath, List[AnyPath]],
428
+ speakers_filepath: Union[AnyPath, List[AnyPath]],
429
+ asr_tokenizer: TokenizerSpec,
430
+ asr_use_start_end_token: bool,
431
+ tts_parser: Callable,
432
+ tts_text_pad_id: int,
433
+ tts_text_normalizer: "Normalizer",
434
+ tts_text_normalizer_call_kwargs: Dict,
435
+ min_words: int = 1,
436
+ max_words: int = 1_000_000,
437
+ tokenizer_workers: int = 1,
438
+ num_parts: int = 1,
439
+ current_part_index: int = 0,
440
+ ):
441
+ super().__init__(
442
+ manifest_filepath=manifest_filepath,
443
+ speakers_filepath=speakers_filepath,
444
+ asr_tokenizer=asr_tokenizer,
445
+ asr_use_start_end_token=asr_use_start_end_token,
446
+ tts_parser=tts_parser,
447
+ tts_text_pad_id=tts_text_pad_id,
448
+ tts_text_normalizer=tts_text_normalizer,
449
+ tts_text_normalizer_call_kwargs=tts_text_normalizer_call_kwargs,
450
+ min_words=min_words,
451
+ max_words=max_words,
452
+ tokenizer_workers=tokenizer_workers,
453
+ num_parts=num_parts,
454
+ current_part_index=current_part_index,
455
+ )
456
+
457
+ def __iter__(self):
458
+ # Implementation based on docs: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
459
+ worker_info = torch.utils.data.get_worker_info()
460
+ if worker_info is None: # single-process data loading, return the full iterator
461
+ start = 0
462
+ end = len(self)
463
+ else: # in a worker process
464
+ # split workload
465
+ per_worker = int(math.ceil(len(self) / float(worker_info.num_workers)))
466
+ worker_id = worker_info.id
467
+ start = worker_id * per_worker
468
+ end = min(start + per_worker, len(self))
469
+ indices = np.arange(start, end)
470
+ np.random.shuffle(indices)
471
+ return map(self.__getitem__, indices)
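
A worked instance of the sharding above (illustrative numbers): with 10 items and 3 dataloader workers, `per_worker` is 4 and the workers cover index ranges [0, 4), [4, 8), and [8, 10).

```python
import math

dataset_len, num_workers = 10, 3
per_worker = int(math.ceil(dataset_len / float(num_workers)))  # 4
shards = [(w * per_worker, min((w + 1) * per_worker, dataset_len)) for w in range(num_workers)]
print(shards)  # [(0, 4), (4, 8), (8, 10)]
```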
472
+
473
+ def collate_fn(
474
+ self, batch: List[Union[TextToTextItem, tuple]]
475
+ ) -> Union[TextToTextBatch, TextOrAudioToTextBatch, tuple]:
476
+ """
477
+ Collate function for dataloader
478
+ Can accept a mixed batch of text-to-text items and audio-text items (typical for ASR)
479
+ """
480
+ return TextOrAudioToTextBatch.collate_fn(
481
+ batch=batch, asr_pad_id=self.asr_pad_id, tts_text_pad_id=self.tts_text_pad_id
482
+ )
SoundScribe/SpeakerID/nemo/collections/asr/losses/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
16
+ from nemo.collections.asr.losses.audio_losses import SDRLoss
17
+ from nemo.collections.asr.losses.ctc import CTCLoss
18
+ from nemo.collections.asr.losses.lattice_losses import LatticeLoss
19
+ from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss
20
+ from nemo.collections.asr.losses.ssl_losses.ctc import CTCLossForSSL
21
+ from nemo.collections.asr.losses.ssl_losses.mlm import MLMLoss
22
+ from nemo.collections.asr.losses.ssl_losses.rnnt import RNNTLossForSSL
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (784 Bytes). View file
 
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/angularloss.cpython-310.pyc ADDED
Binary file (2.43 kB). View file
 
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/audio_losses.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
SoundScribe/SpeakerID/nemo/collections/asr/losses/__pycache__/ctc.cpython-310.pyc ADDED
Binary file (2.29 kB). View file