File size: 4,240 Bytes
9d3cb0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn.functional as F
from torch import nn

from ...core import AudioSignal
from ...core import STFTParams
from ...core import util


class SpectralGate(nn.Module):
    """Spectral gating algorithm for noise reduction,
    as in Audacity/Ocenaudio. The steps are as follows:

    1.  An FFT is calculated over the noise audio clip
    2.  Statistics are calculated over FFT of the the noise
        (in frequency)
    3.  A threshold is calculated based upon the statistics
        of the noise (and the desired sensitivity of the algorithm)
    4.  An FFT is calculated over the signal
    5.  A mask is determined by comparing the signal FFT to the
        threshold
    6.  The mask is smoothed with a filter over frequency and time
    7.  The mask is appled to the FFT of the signal, and is inverted

    Implementation inspired by Tim Sainburg's noisereduce:

    https://timsainburg.com/noise-reduction-python.html

    Parameters
    ----------
    n_freq : int, optional
        Number of frequency bins to smooth by, by default 3
    n_time : int, optional
        Number of time bins to smooth by, by default 5
    """

    def __init__(self, n_freq: int = 3, n_time: int = 5):
        super().__init__()

        smoothing_filter = torch.outer(
            torch.cat(
                [
                    torch.linspace(0, 1, n_freq + 2)[:-1],
                    torch.linspace(1, 0, n_freq + 2),
                ]
            )[..., 1:-1],
            torch.cat(
                [
                    torch.linspace(0, 1, n_time + 2)[:-1],
                    torch.linspace(1, 0, n_time + 2),
                ]
            )[..., 1:-1],
        )
        smoothing_filter = smoothing_filter / smoothing_filter.sum()
        smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
        self.register_buffer("smoothing_filter", smoothing_filter)

    def forward(
        self,
        audio_signal: AudioSignal,
        nz_signal: AudioSignal,
        denoise_amount: float = 1.0,
        n_std: float = 3.0,
        win_length: int = 2048,
        hop_length: int = 512,
    ):
        """Perform noise reduction.

        Parameters
        ----------
        audio_signal : AudioSignal
            Audio signal that noise will be removed from.
        nz_signal : AudioSignal, optional
            Noise signal to compute noise statistics from.
        denoise_amount : float, optional
            Amount to denoise by, by default 1.0
        n_std : float, optional
            Number of standard deviations above which to consider
            noise, by default 3.0
        win_length : int, optional
            Length of window for STFT, by default 2048
        hop_length : int, optional
            Hop length for STFT, by default 512

        Returns
        -------
        AudioSignal
            Denoised audio signal.
        """
        stft_params = STFTParams(win_length, hop_length, "sqrt_hann")

        audio_signal = audio_signal.clone()
        audio_signal.stft_data = None
        audio_signal.stft_params = stft_params

        nz_signal = nz_signal.clone()
        nz_signal.stft_params = stft_params

        nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
        nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
        nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)

        nz_thresh = nz_freq_mean + nz_freq_std * n_std

        stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
        nb, nac, nf, nt = stft_db.shape
        db_thresh = nz_thresh.expand(nb, nac, -1, nt)

        stft_mask = (stft_db < db_thresh).float()
        shape = stft_mask.shape

        stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
        pad_tuple = (
            self.smoothing_filter.shape[-2] // 2,
            self.smoothing_filter.shape[-1] // 2,
        )
        stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
        stft_mask = stft_mask.reshape(*shape)
        stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
            audio_signal.device
        )
        stft_mask = 1 - stft_mask

        audio_signal.stft_data *= stft_mask
        audio_signal.istft()

        return audio_signal