Faridmaruf commited on
Commit
ed91a72
1 Parent(s): 94be86c

Delete infer_pack

Browse files
infer_pack/attentions.py DELETED
@@ -1,417 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
- from infer_pack import commons
9
- from infer_pack import modules
10
- from infer_pack.modules import LayerNorm
11
-
12
-
13
- class Encoder(nn.Module):
14
- def __init__(
15
- self,
16
- hidden_channels,
17
- filter_channels,
18
- n_heads,
19
- n_layers,
20
- kernel_size=1,
21
- p_dropout=0.0,
22
- window_size=10,
23
- **kwargs
24
- ):
25
- super().__init__()
26
- self.hidden_channels = hidden_channels
27
- self.filter_channels = filter_channels
28
- self.n_heads = n_heads
29
- self.n_layers = n_layers
30
- self.kernel_size = kernel_size
31
- self.p_dropout = p_dropout
32
- self.window_size = window_size
33
-
34
- self.drop = nn.Dropout(p_dropout)
35
- self.attn_layers = nn.ModuleList()
36
- self.norm_layers_1 = nn.ModuleList()
37
- self.ffn_layers = nn.ModuleList()
38
- self.norm_layers_2 = nn.ModuleList()
39
- for i in range(self.n_layers):
40
- self.attn_layers.append(
41
- MultiHeadAttention(
42
- hidden_channels,
43
- hidden_channels,
44
- n_heads,
45
- p_dropout=p_dropout,
46
- window_size=window_size,
47
- )
48
- )
49
- self.norm_layers_1.append(LayerNorm(hidden_channels))
50
- self.ffn_layers.append(
51
- FFN(
52
- hidden_channels,
53
- hidden_channels,
54
- filter_channels,
55
- kernel_size,
56
- p_dropout=p_dropout,
57
- )
58
- )
59
- self.norm_layers_2.append(LayerNorm(hidden_channels))
60
-
61
- def forward(self, x, x_mask):
62
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63
- x = x * x_mask
64
- for i in range(self.n_layers):
65
- y = self.attn_layers[i](x, x, attn_mask)
66
- y = self.drop(y)
67
- x = self.norm_layers_1[i](x + y)
68
-
69
- y = self.ffn_layers[i](x, x_mask)
70
- y = self.drop(y)
71
- x = self.norm_layers_2[i](x + y)
72
- x = x * x_mask
73
- return x
74
-
75
-
76
- class Decoder(nn.Module):
77
- def __init__(
78
- self,
79
- hidden_channels,
80
- filter_channels,
81
- n_heads,
82
- n_layers,
83
- kernel_size=1,
84
- p_dropout=0.0,
85
- proximal_bias=False,
86
- proximal_init=True,
87
- **kwargs
88
- ):
89
- super().__init__()
90
- self.hidden_channels = hidden_channels
91
- self.filter_channels = filter_channels
92
- self.n_heads = n_heads
93
- self.n_layers = n_layers
94
- self.kernel_size = kernel_size
95
- self.p_dropout = p_dropout
96
- self.proximal_bias = proximal_bias
97
- self.proximal_init = proximal_init
98
-
99
- self.drop = nn.Dropout(p_dropout)
100
- self.self_attn_layers = nn.ModuleList()
101
- self.norm_layers_0 = nn.ModuleList()
102
- self.encdec_attn_layers = nn.ModuleList()
103
- self.norm_layers_1 = nn.ModuleList()
104
- self.ffn_layers = nn.ModuleList()
105
- self.norm_layers_2 = nn.ModuleList()
106
- for i in range(self.n_layers):
107
- self.self_attn_layers.append(
108
- MultiHeadAttention(
109
- hidden_channels,
110
- hidden_channels,
111
- n_heads,
112
- p_dropout=p_dropout,
113
- proximal_bias=proximal_bias,
114
- proximal_init=proximal_init,
115
- )
116
- )
117
- self.norm_layers_0.append(LayerNorm(hidden_channels))
118
- self.encdec_attn_layers.append(
119
- MultiHeadAttention(
120
- hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121
- )
122
- )
123
- self.norm_layers_1.append(LayerNorm(hidden_channels))
124
- self.ffn_layers.append(
125
- FFN(
126
- hidden_channels,
127
- hidden_channels,
128
- filter_channels,
129
- kernel_size,
130
- p_dropout=p_dropout,
131
- causal=True,
132
- )
133
- )
134
- self.norm_layers_2.append(LayerNorm(hidden_channels))
135
-
136
- def forward(self, x, x_mask, h, h_mask):
137
- """
138
- x: decoder input
139
- h: encoder output
140
- """
141
- self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142
- device=x.device, dtype=x.dtype
143
- )
144
- encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145
- x = x * x_mask
146
- for i in range(self.n_layers):
147
- y = self.self_attn_layers[i](x, x, self_attn_mask)
148
- y = self.drop(y)
149
- x = self.norm_layers_0[i](x + y)
150
-
151
- y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152
- y = self.drop(y)
153
- x = self.norm_layers_1[i](x + y)
154
-
155
- y = self.ffn_layers[i](x, x_mask)
156
- y = self.drop(y)
157
- x = self.norm_layers_2[i](x + y)
158
- x = x * x_mask
159
- return x
160
-
161
-
162
- class MultiHeadAttention(nn.Module):
163
- def __init__(
164
- self,
165
- channels,
166
- out_channels,
167
- n_heads,
168
- p_dropout=0.0,
169
- window_size=None,
170
- heads_share=True,
171
- block_length=None,
172
- proximal_bias=False,
173
- proximal_init=False,
174
- ):
175
- super().__init__()
176
- assert channels % n_heads == 0
177
-
178
- self.channels = channels
179
- self.out_channels = out_channels
180
- self.n_heads = n_heads
181
- self.p_dropout = p_dropout
182
- self.window_size = window_size
183
- self.heads_share = heads_share
184
- self.block_length = block_length
185
- self.proximal_bias = proximal_bias
186
- self.proximal_init = proximal_init
187
- self.attn = None
188
-
189
- self.k_channels = channels // n_heads
190
- self.conv_q = nn.Conv1d(channels, channels, 1)
191
- self.conv_k = nn.Conv1d(channels, channels, 1)
192
- self.conv_v = nn.Conv1d(channels, channels, 1)
193
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
194
- self.drop = nn.Dropout(p_dropout)
195
-
196
- if window_size is not None:
197
- n_heads_rel = 1 if heads_share else n_heads
198
- rel_stddev = self.k_channels**-0.5
199
- self.emb_rel_k = nn.Parameter(
200
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201
- * rel_stddev
202
- )
203
- self.emb_rel_v = nn.Parameter(
204
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205
- * rel_stddev
206
- )
207
-
208
- nn.init.xavier_uniform_(self.conv_q.weight)
209
- nn.init.xavier_uniform_(self.conv_k.weight)
210
- nn.init.xavier_uniform_(self.conv_v.weight)
211
- if proximal_init:
212
- with torch.no_grad():
213
- self.conv_k.weight.copy_(self.conv_q.weight)
214
- self.conv_k.bias.copy_(self.conv_q.bias)
215
-
216
- def forward(self, x, c, attn_mask=None):
217
- q = self.conv_q(x)
218
- k = self.conv_k(c)
219
- v = self.conv_v(c)
220
-
221
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
-
223
- x = self.conv_o(x)
224
- return x
225
-
226
- def attention(self, query, key, value, mask=None):
227
- # reshape [b, d, t] -> [b, n_h, t, d_k]
228
- b, d, t_s, t_t = (*key.size(), query.size(2))
229
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232
-
233
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234
- if self.window_size is not None:
235
- assert (
236
- t_s == t_t
237
- ), "Relative attention is only available for self-attention."
238
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239
- rel_logits = self._matmul_with_relative_keys(
240
- query / math.sqrt(self.k_channels), key_relative_embeddings
241
- )
242
- scores_local = self._relative_position_to_absolute_position(rel_logits)
243
- scores = scores + scores_local
244
- if self.proximal_bias:
245
- assert t_s == t_t, "Proximal bias is only available for self-attention."
246
- scores = scores + self._attention_bias_proximal(t_s).to(
247
- device=scores.device, dtype=scores.dtype
248
- )
249
- if mask is not None:
250
- scores = scores.masked_fill(mask == 0, -1e4)
251
- if self.block_length is not None:
252
- assert (
253
- t_s == t_t
254
- ), "Local attention is only available for self-attention."
255
- block_mask = (
256
- torch.ones_like(scores)
257
- .triu(-self.block_length)
258
- .tril(self.block_length)
259
- )
260
- scores = scores.masked_fill(block_mask == 0, -1e4)
261
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262
- p_attn = self.drop(p_attn)
263
- output = torch.matmul(p_attn, value)
264
- if self.window_size is not None:
265
- relative_weights = self._absolute_position_to_relative_position(p_attn)
266
- value_relative_embeddings = self._get_relative_embeddings(
267
- self.emb_rel_v, t_s
268
- )
269
- output = output + self._matmul_with_relative_values(
270
- relative_weights, value_relative_embeddings
271
- )
272
- output = (
273
- output.transpose(2, 3).contiguous().view(b, d, t_t)
274
- ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275
- return output, p_attn
276
-
277
- def _matmul_with_relative_values(self, x, y):
278
- """
279
- x: [b, h, l, m]
280
- y: [h or 1, m, d]
281
- ret: [b, h, l, d]
282
- """
283
- ret = torch.matmul(x, y.unsqueeze(0))
284
- return ret
285
-
286
- def _matmul_with_relative_keys(self, x, y):
287
- """
288
- x: [b, h, l, d]
289
- y: [h or 1, m, d]
290
- ret: [b, h, l, m]
291
- """
292
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293
- return ret
294
-
295
- def _get_relative_embeddings(self, relative_embeddings, length):
296
- max_relative_position = 2 * self.window_size + 1
297
- # Pad first before slice to avoid using cond ops.
298
- pad_length = max(length - (self.window_size + 1), 0)
299
- slice_start_position = max((self.window_size + 1) - length, 0)
300
- slice_end_position = slice_start_position + 2 * length - 1
301
- if pad_length > 0:
302
- padded_relative_embeddings = F.pad(
303
- relative_embeddings,
304
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305
- )
306
- else:
307
- padded_relative_embeddings = relative_embeddings
308
- used_relative_embeddings = padded_relative_embeddings[
309
- :, slice_start_position:slice_end_position
310
- ]
311
- return used_relative_embeddings
312
-
313
- def _relative_position_to_absolute_position(self, x):
314
- """
315
- x: [b, h, l, 2*l-1]
316
- ret: [b, h, l, l]
317
- """
318
- batch, heads, length, _ = x.size()
319
- # Concat columns of pad to shift from relative to absolute indexing.
320
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321
-
322
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
- x_flat = x.view([batch, heads, length * 2 * length])
324
- x_flat = F.pad(
325
- x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326
- )
327
-
328
- # Reshape and slice out the padded elements.
329
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330
- :, :, :length, length - 1 :
331
- ]
332
- return x_final
333
-
334
- def _absolute_position_to_relative_position(self, x):
335
- """
336
- x: [b, h, l, l]
337
- ret: [b, h, l, 2*l-1]
338
- """
339
- batch, heads, length, _ = x.size()
340
- # padd along column
341
- x = F.pad(
342
- x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343
- )
344
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345
- # add 0's in the beginning that will skew the elements after reshape
346
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
- return x_final
349
-
350
- def _attention_bias_proximal(self, length):
351
- """Bias for self-attention to encourage attention to close positions.
352
- Args:
353
- length: an integer scalar.
354
- Returns:
355
- a Tensor with shape [1, 1, length, length]
356
- """
357
- r = torch.arange(length, dtype=torch.float32)
358
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
-
361
-
362
- class FFN(nn.Module):
363
- def __init__(
364
- self,
365
- in_channels,
366
- out_channels,
367
- filter_channels,
368
- kernel_size,
369
- p_dropout=0.0,
370
- activation=None,
371
- causal=False,
372
- ):
373
- super().__init__()
374
- self.in_channels = in_channels
375
- self.out_channels = out_channels
376
- self.filter_channels = filter_channels
377
- self.kernel_size = kernel_size
378
- self.p_dropout = p_dropout
379
- self.activation = activation
380
- self.causal = causal
381
-
382
- if causal:
383
- self.padding = self._causal_padding
384
- else:
385
- self.padding = self._same_padding
386
-
387
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389
- self.drop = nn.Dropout(p_dropout)
390
-
391
- def forward(self, x, x_mask):
392
- x = self.conv_1(self.padding(x * x_mask))
393
- if self.activation == "gelu":
394
- x = x * torch.sigmoid(1.702 * x)
395
- else:
396
- x = torch.relu(x)
397
- x = self.drop(x)
398
- x = self.conv_2(self.padding(x * x_mask))
399
- return x * x_mask
400
-
401
- def _causal_padding(self, x):
402
- if self.kernel_size == 1:
403
- return x
404
- pad_l = self.kernel_size - 1
405
- pad_r = 0
406
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407
- x = F.pad(x, commons.convert_pad_shape(padding))
408
- return x
409
-
410
- def _same_padding(self, x):
411
- if self.kernel_size == 1:
412
- return x
413
- pad_l = (self.kernel_size - 1) // 2
414
- pad_r = self.kernel_size // 2
415
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416
- x = F.pad(x, commons.convert_pad_shape(padding))
417
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/commons.py DELETED
@@ -1,166 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
-
8
- def init_weights(m, mean=0.0, std=0.01):
9
- classname = m.__class__.__name__
10
- if classname.find("Conv") != -1:
11
- m.weight.data.normal_(mean, std)
12
-
13
-
14
- def get_padding(kernel_size, dilation=1):
15
- return int((kernel_size * dilation - dilation) / 2)
16
-
17
-
18
- def convert_pad_shape(pad_shape):
19
- l = pad_shape[::-1]
20
- pad_shape = [item for sublist in l for item in sublist]
21
- return pad_shape
22
-
23
-
24
- def kl_divergence(m_p, logs_p, m_q, logs_q):
25
- """KL(P||Q)"""
26
- kl = (logs_q - logs_p) - 0.5
27
- kl += (
28
- 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
29
- )
30
- return kl
31
-
32
-
33
- def rand_gumbel(shape):
34
- """Sample from the Gumbel distribution, protect from overflows."""
35
- uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
36
- return -torch.log(-torch.log(uniform_samples))
37
-
38
-
39
- def rand_gumbel_like(x):
40
- g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
41
- return g
42
-
43
-
44
- def slice_segments(x, ids_str, segment_size=4):
45
- ret = torch.zeros_like(x[:, :, :segment_size])
46
- for i in range(x.size(0)):
47
- idx_str = ids_str[i]
48
- idx_end = idx_str + segment_size
49
- ret[i] = x[i, :, idx_str:idx_end]
50
- return ret
51
-
52
-
53
- def slice_segments2(x, ids_str, segment_size=4):
54
- ret = torch.zeros_like(x[:, :segment_size])
55
- for i in range(x.size(0)):
56
- idx_str = ids_str[i]
57
- idx_end = idx_str + segment_size
58
- ret[i] = x[i, idx_str:idx_end]
59
- return ret
60
-
61
-
62
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
63
- b, d, t = x.size()
64
- if x_lengths is None:
65
- x_lengths = t
66
- ids_str_max = x_lengths - segment_size + 1
67
- ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
68
- ret = slice_segments(x, ids_str, segment_size)
69
- return ret, ids_str
70
-
71
-
72
- def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
73
- position = torch.arange(length, dtype=torch.float)
74
- num_timescales = channels // 2
75
- log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
76
- num_timescales - 1
77
- )
78
- inv_timescales = min_timescale * torch.exp(
79
- torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
80
- )
81
- scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
82
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
83
- signal = F.pad(signal, [0, 0, 0, channels % 2])
84
- signal = signal.view(1, channels, length)
85
- return signal
86
-
87
-
88
- def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
89
- b, channels, length = x.size()
90
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
91
- return x + signal.to(dtype=x.dtype, device=x.device)
92
-
93
-
94
- def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
95
- b, channels, length = x.size()
96
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
97
- return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
98
-
99
-
100
- def subsequent_mask(length):
101
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
102
- return mask
103
-
104
-
105
- @torch.jit.script
106
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
107
- n_channels_int = n_channels[0]
108
- in_act = input_a + input_b
109
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
110
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
111
- acts = t_act * s_act
112
- return acts
113
-
114
-
115
- def convert_pad_shape(pad_shape):
116
- l = pad_shape[::-1]
117
- pad_shape = [item for sublist in l for item in sublist]
118
- return pad_shape
119
-
120
-
121
- def shift_1d(x):
122
- x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
123
- return x
124
-
125
-
126
- def sequence_mask(length, max_length=None):
127
- if max_length is None:
128
- max_length = length.max()
129
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
130
- return x.unsqueeze(0) < length.unsqueeze(1)
131
-
132
-
133
- def generate_path(duration, mask):
134
- """
135
- duration: [b, 1, t_x]
136
- mask: [b, 1, t_y, t_x]
137
- """
138
- device = duration.device
139
-
140
- b, _, t_y, t_x = mask.shape
141
- cum_duration = torch.cumsum(duration, -1)
142
-
143
- cum_duration_flat = cum_duration.view(b * t_x)
144
- path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
145
- path = path.view(b, t_x, t_y)
146
- path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
147
- path = path.unsqueeze(1).transpose(2, 3) * mask
148
- return path
149
-
150
-
151
- def clip_grad_value_(parameters, clip_value, norm_type=2):
152
- if isinstance(parameters, torch.Tensor):
153
- parameters = [parameters]
154
- parameters = list(filter(lambda p: p.grad is not None, parameters))
155
- norm_type = float(norm_type)
156
- if clip_value is not None:
157
- clip_value = float(clip_value)
158
-
159
- total_norm = 0
160
- for p in parameters:
161
- param_norm = p.grad.data.norm(norm_type)
162
- total_norm += param_norm.item() ** norm_type
163
- if clip_value is not None:
164
- p.grad.data.clamp_(min=-clip_value, max=clip_value)
165
- total_norm = total_norm ** (1.0 / norm_type)
166
- return total_norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/models.py DELETED
@@ -1,1124 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
-
16
-
17
- class TextEncoder256(nn.Module):
18
- def __init__(
19
- self,
20
- out_channels,
21
- hidden_channels,
22
- filter_channels,
23
- n_heads,
24
- n_layers,
25
- kernel_size,
26
- p_dropout,
27
- f0=True,
28
- ):
29
- super().__init__()
30
- self.out_channels = out_channels
31
- self.hidden_channels = hidden_channels
32
- self.filter_channels = filter_channels
33
- self.n_heads = n_heads
34
- self.n_layers = n_layers
35
- self.kernel_size = kernel_size
36
- self.p_dropout = p_dropout
37
- self.emb_phone = nn.Linear(256, hidden_channels)
38
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
- if f0 == True:
40
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
- self.encoder = attentions.Encoder(
42
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
- )
44
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
-
46
- def forward(self, phone, pitch, lengths):
47
- if pitch == None:
48
- x = self.emb_phone(phone)
49
- else:
50
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
- x = self.lrelu(x)
53
- x = torch.transpose(x, 1, -1) # [b, h, t]
54
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
- x.dtype
56
- )
57
- x = self.encoder(x * x_mask, x_mask)
58
- stats = self.proj(x) * x_mask
59
-
60
- m, logs = torch.split(stats, self.out_channels, dim=1)
61
- return m, logs, x_mask
62
-
63
-
64
- class TextEncoder768(nn.Module):
65
- def __init__(
66
- self,
67
- out_channels,
68
- hidden_channels,
69
- filter_channels,
70
- n_heads,
71
- n_layers,
72
- kernel_size,
73
- p_dropout,
74
- f0=True,
75
- ):
76
- super().__init__()
77
- self.out_channels = out_channels
78
- self.hidden_channels = hidden_channels
79
- self.filter_channels = filter_channels
80
- self.n_heads = n_heads
81
- self.n_layers = n_layers
82
- self.kernel_size = kernel_size
83
- self.p_dropout = p_dropout
84
- self.emb_phone = nn.Linear(768, hidden_channels)
85
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
- if f0 == True:
87
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
- self.encoder = attentions.Encoder(
89
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
- )
91
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
92
-
93
- def forward(self, phone, pitch, lengths):
94
- if pitch == None:
95
- x = self.emb_phone(phone)
96
- else:
97
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
- x = self.lrelu(x)
100
- x = torch.transpose(x, 1, -1) # [b, h, t]
101
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
- x.dtype
103
- )
104
- x = self.encoder(x * x_mask, x_mask)
105
- stats = self.proj(x) * x_mask
106
-
107
- m, logs = torch.split(stats, self.out_channels, dim=1)
108
- return m, logs, x_mask
109
-
110
-
111
- class ResidualCouplingBlock(nn.Module):
112
- def __init__(
113
- self,
114
- channels,
115
- hidden_channels,
116
- kernel_size,
117
- dilation_rate,
118
- n_layers,
119
- n_flows=4,
120
- gin_channels=0,
121
- ):
122
- super().__init__()
123
- self.channels = channels
124
- self.hidden_channels = hidden_channels
125
- self.kernel_size = kernel_size
126
- self.dilation_rate = dilation_rate
127
- self.n_layers = n_layers
128
- self.n_flows = n_flows
129
- self.gin_channels = gin_channels
130
-
131
- self.flows = nn.ModuleList()
132
- for i in range(n_flows):
133
- self.flows.append(
134
- modules.ResidualCouplingLayer(
135
- channels,
136
- hidden_channels,
137
- kernel_size,
138
- dilation_rate,
139
- n_layers,
140
- gin_channels=gin_channels,
141
- mean_only=True,
142
- )
143
- )
144
- self.flows.append(modules.Flip())
145
-
146
- def forward(self, x, x_mask, g=None, reverse=False):
147
- if not reverse:
148
- for flow in self.flows:
149
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
- else:
151
- for flow in reversed(self.flows):
152
- x = flow(x, x_mask, g=g, reverse=reverse)
153
- return x
154
-
155
- def remove_weight_norm(self):
156
- for i in range(self.n_flows):
157
- self.flows[i * 2].remove_weight_norm()
158
-
159
-
160
- class PosteriorEncoder(nn.Module):
161
- def __init__(
162
- self,
163
- in_channels,
164
- out_channels,
165
- hidden_channels,
166
- kernel_size,
167
- dilation_rate,
168
- n_layers,
169
- gin_channels=0,
170
- ):
171
- super().__init__()
172
- self.in_channels = in_channels
173
- self.out_channels = out_channels
174
- self.hidden_channels = hidden_channels
175
- self.kernel_size = kernel_size
176
- self.dilation_rate = dilation_rate
177
- self.n_layers = n_layers
178
- self.gin_channels = gin_channels
179
-
180
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
181
- self.enc = modules.WN(
182
- hidden_channels,
183
- kernel_size,
184
- dilation_rate,
185
- n_layers,
186
- gin_channels=gin_channels,
187
- )
188
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
189
-
190
- def forward(self, x, x_lengths, g=None):
191
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
192
- x.dtype
193
- )
194
- x = self.pre(x) * x_mask
195
- x = self.enc(x, x_mask, g=g)
196
- stats = self.proj(x) * x_mask
197
- m, logs = torch.split(stats, self.out_channels, dim=1)
198
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
199
- return z, m, logs, x_mask
200
-
201
- def remove_weight_norm(self):
202
- self.enc.remove_weight_norm()
203
-
204
-
205
- class Generator(torch.nn.Module):
206
- def __init__(
207
- self,
208
- initial_channel,
209
- resblock,
210
- resblock_kernel_sizes,
211
- resblock_dilation_sizes,
212
- upsample_rates,
213
- upsample_initial_channel,
214
- upsample_kernel_sizes,
215
- gin_channels=0,
216
- ):
217
- super(Generator, self).__init__()
218
- self.num_kernels = len(resblock_kernel_sizes)
219
- self.num_upsamples = len(upsample_rates)
220
- self.conv_pre = Conv1d(
221
- initial_channel, upsample_initial_channel, 7, 1, padding=3
222
- )
223
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
224
-
225
- self.ups = nn.ModuleList()
226
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
227
- self.ups.append(
228
- weight_norm(
229
- ConvTranspose1d(
230
- upsample_initial_channel // (2**i),
231
- upsample_initial_channel // (2 ** (i + 1)),
232
- k,
233
- u,
234
- padding=(k - u) // 2,
235
- )
236
- )
237
- )
238
-
239
- self.resblocks = nn.ModuleList()
240
- for i in range(len(self.ups)):
241
- ch = upsample_initial_channel // (2 ** (i + 1))
242
- for j, (k, d) in enumerate(
243
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
244
- ):
245
- self.resblocks.append(resblock(ch, k, d))
246
-
247
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
248
- self.ups.apply(init_weights)
249
-
250
- if gin_channels != 0:
251
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
252
-
253
- def forward(self, x, g=None):
254
- x = self.conv_pre(x)
255
- if g is not None:
256
- x = x + self.cond(g)
257
-
258
- for i in range(self.num_upsamples):
259
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
260
- x = self.ups[i](x)
261
- xs = None
262
- for j in range(self.num_kernels):
263
- if xs is None:
264
- xs = self.resblocks[i * self.num_kernels + j](x)
265
- else:
266
- xs += self.resblocks[i * self.num_kernels + j](x)
267
- x = xs / self.num_kernels
268
- x = F.leaky_relu(x)
269
- x = self.conv_post(x)
270
- x = torch.tanh(x)
271
-
272
- return x
273
-
274
- def remove_weight_norm(self):
275
- for l in self.ups:
276
- remove_weight_norm(l)
277
- for l in self.resblocks:
278
- l.remove_weight_norm()
279
-
280
-
281
- class SineGen(torch.nn.Module):
282
- """Definition of sine generator
283
- SineGen(samp_rate, harmonic_num = 0,
284
- sine_amp = 0.1, noise_std = 0.003,
285
- voiced_threshold = 0,
286
- flag_for_pulse=False)
287
- samp_rate: sampling rate in Hz
288
- harmonic_num: number of harmonic overtones (default 0)
289
- sine_amp: amplitude of sine-wavefrom (default 0.1)
290
- noise_std: std of Gaussian noise (default 0.003)
291
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
292
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
293
- Note: when flag_for_pulse is True, the first time step of a voiced
294
- segment is always sin(np.pi) or cos(0)
295
- """
296
-
297
- def __init__(
298
- self,
299
- samp_rate,
300
- harmonic_num=0,
301
- sine_amp=0.1,
302
- noise_std=0.003,
303
- voiced_threshold=0,
304
- flag_for_pulse=False,
305
- ):
306
- super(SineGen, self).__init__()
307
- self.sine_amp = sine_amp
308
- self.noise_std = noise_std
309
- self.harmonic_num = harmonic_num
310
- self.dim = self.harmonic_num + 1
311
- self.sampling_rate = samp_rate
312
- self.voiced_threshold = voiced_threshold
313
-
314
- def _f02uv(self, f0):
315
- # generate uv signal
316
- uv = torch.ones_like(f0)
317
- uv = uv * (f0 > self.voiced_threshold)
318
- return uv
319
-
320
- def forward(self, f0, upp):
321
- """sine_tensor, uv = forward(f0)
322
- input F0: tensor(batchsize=1, length, dim=1)
323
- f0 for unvoiced steps should be 0
324
- output sine_tensor: tensor(batchsize=1, length, dim)
325
- output uv: tensor(batchsize=1, length, 1)
326
- """
327
- with torch.no_grad():
328
- f0 = f0[:, None].transpose(1, 2)
329
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
330
- # fundamental component
331
- f0_buf[:, :, 0] = f0[:, :, 0]
332
- for idx in np.arange(self.harmonic_num):
333
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
334
- idx + 2
335
- ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
336
- rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
337
- rand_ini = torch.rand(
338
- f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
339
- )
340
- rand_ini[:, 0] = 0
341
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
342
- tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
343
- tmp_over_one *= upp
344
- tmp_over_one = F.interpolate(
345
- tmp_over_one.transpose(2, 1),
346
- scale_factor=upp,
347
- mode="linear",
348
- align_corners=True,
349
- ).transpose(2, 1)
350
- rad_values = F.interpolate(
351
- rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
352
- ).transpose(
353
- 2, 1
354
- ) #######
355
- tmp_over_one %= 1
356
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
357
- cumsum_shift = torch.zeros_like(rad_values)
358
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
359
- sine_waves = torch.sin(
360
- torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
361
- )
362
- sine_waves = sine_waves * self.sine_amp
363
- uv = self._f02uv(f0)
364
- uv = F.interpolate(
365
- uv.transpose(2, 1), scale_factor=upp, mode="nearest"
366
- ).transpose(2, 1)
367
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
368
- noise = noise_amp * torch.randn_like(sine_waves)
369
- sine_waves = sine_waves * uv + noise
370
- return sine_waves, uv, noise
371
-
372
-
373
- class SourceModuleHnNSF(torch.nn.Module):
374
- """SourceModule for hn-nsf
375
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
376
- add_noise_std=0.003, voiced_threshod=0)
377
- sampling_rate: sampling_rate in Hz
378
- harmonic_num: number of harmonic above F0 (default: 0)
379
- sine_amp: amplitude of sine source signal (default: 0.1)
380
- add_noise_std: std of additive Gaussian noise (default: 0.003)
381
- note that amplitude of noise in unvoiced is decided
382
- by sine_amp
383
- voiced_threshold: threhold to set U/V given F0 (default: 0)
384
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
385
- F0_sampled (batchsize, length, 1)
386
- Sine_source (batchsize, length, 1)
387
- noise_source (batchsize, length 1)
388
- uv (batchsize, length, 1)
389
- """
390
-
391
- def __init__(
392
- self,
393
- sampling_rate,
394
- harmonic_num=0,
395
- sine_amp=0.1,
396
- add_noise_std=0.003,
397
- voiced_threshod=0,
398
- is_half=True,
399
- ):
400
- super(SourceModuleHnNSF, self).__init__()
401
-
402
- self.sine_amp = sine_amp
403
- self.noise_std = add_noise_std
404
- self.is_half = is_half
405
- # to produce sine waveforms
406
- self.l_sin_gen = SineGen(
407
- sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
408
- )
409
-
410
- # to merge source harmonics into a single excitation
411
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
412
- self.l_tanh = torch.nn.Tanh()
413
-
414
- def forward(self, x, upp=None):
415
- sine_wavs, uv, _ = self.l_sin_gen(x, upp)
416
- if self.is_half:
417
- sine_wavs = sine_wavs.half()
418
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
419
- return sine_merge, None, None # noise, uv
420
-
421
-
422
- class GeneratorNSF(torch.nn.Module):
423
- def __init__(
424
- self,
425
- initial_channel,
426
- resblock,
427
- resblock_kernel_sizes,
428
- resblock_dilation_sizes,
429
- upsample_rates,
430
- upsample_initial_channel,
431
- upsample_kernel_sizes,
432
- gin_channels,
433
- sr,
434
- is_half=False,
435
- ):
436
- super(GeneratorNSF, self).__init__()
437
- self.num_kernels = len(resblock_kernel_sizes)
438
- self.num_upsamples = len(upsample_rates)
439
-
440
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
441
- self.m_source = SourceModuleHnNSF(
442
- sampling_rate=sr, harmonic_num=0, is_half=is_half
443
- )
444
- self.noise_convs = nn.ModuleList()
445
- self.conv_pre = Conv1d(
446
- initial_channel, upsample_initial_channel, 7, 1, padding=3
447
- )
448
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
449
-
450
- self.ups = nn.ModuleList()
451
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
452
- c_cur = upsample_initial_channel // (2 ** (i + 1))
453
- self.ups.append(
454
- weight_norm(
455
- ConvTranspose1d(
456
- upsample_initial_channel // (2**i),
457
- upsample_initial_channel // (2 ** (i + 1)),
458
- k,
459
- u,
460
- padding=(k - u) // 2,
461
- )
462
- )
463
- )
464
- if i + 1 < len(upsample_rates):
465
- stride_f0 = np.prod(upsample_rates[i + 1 :])
466
- self.noise_convs.append(
467
- Conv1d(
468
- 1,
469
- c_cur,
470
- kernel_size=stride_f0 * 2,
471
- stride=stride_f0,
472
- padding=stride_f0 // 2,
473
- )
474
- )
475
- else:
476
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
477
-
478
- self.resblocks = nn.ModuleList()
479
- for i in range(len(self.ups)):
480
- ch = upsample_initial_channel // (2 ** (i + 1))
481
- for j, (k, d) in enumerate(
482
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
483
- ):
484
- self.resblocks.append(resblock(ch, k, d))
485
-
486
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
487
- self.ups.apply(init_weights)
488
-
489
- if gin_channels != 0:
490
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
491
-
492
- self.upp = np.prod(upsample_rates)
493
-
494
- def forward(self, x, f0, g=None):
495
- har_source, noi_source, uv = self.m_source(f0, self.upp)
496
- har_source = har_source.transpose(1, 2)
497
- x = self.conv_pre(x)
498
- if g is not None:
499
- x = x + self.cond(g)
500
-
501
- for i in range(self.num_upsamples):
502
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
503
- x = self.ups[i](x)
504
- x_source = self.noise_convs[i](har_source)
505
- x = x + x_source
506
- xs = None
507
- for j in range(self.num_kernels):
508
- if xs is None:
509
- xs = self.resblocks[i * self.num_kernels + j](x)
510
- else:
511
- xs += self.resblocks[i * self.num_kernels + j](x)
512
- x = xs / self.num_kernels
513
- x = F.leaky_relu(x)
514
- x = self.conv_post(x)
515
- x = torch.tanh(x)
516
- return x
517
-
518
- def remove_weight_norm(self):
519
- for l in self.ups:
520
- remove_weight_norm(l)
521
- for l in self.resblocks:
522
- l.remove_weight_norm()
523
-
524
-
525
- sr2sr = {
526
- "32k": 32000,
527
- "40k": 40000,
528
- "48k": 48000,
529
- }
530
-
531
-
532
- class SynthesizerTrnMs256NSFsid(nn.Module):
533
- def __init__(
534
- self,
535
- spec_channels,
536
- segment_size,
537
- inter_channels,
538
- hidden_channels,
539
- filter_channels,
540
- n_heads,
541
- n_layers,
542
- kernel_size,
543
- p_dropout,
544
- resblock,
545
- resblock_kernel_sizes,
546
- resblock_dilation_sizes,
547
- upsample_rates,
548
- upsample_initial_channel,
549
- upsample_kernel_sizes,
550
- spk_embed_dim,
551
- gin_channels,
552
- sr,
553
- **kwargs
554
- ):
555
- super().__init__()
556
- if type(sr) == type("strr"):
557
- sr = sr2sr[sr]
558
- self.spec_channels = spec_channels
559
- self.inter_channels = inter_channels
560
- self.hidden_channels = hidden_channels
561
- self.filter_channels = filter_channels
562
- self.n_heads = n_heads
563
- self.n_layers = n_layers
564
- self.kernel_size = kernel_size
565
- self.p_dropout = p_dropout
566
- self.resblock = resblock
567
- self.resblock_kernel_sizes = resblock_kernel_sizes
568
- self.resblock_dilation_sizes = resblock_dilation_sizes
569
- self.upsample_rates = upsample_rates
570
- self.upsample_initial_channel = upsample_initial_channel
571
- self.upsample_kernel_sizes = upsample_kernel_sizes
572
- self.segment_size = segment_size
573
- self.gin_channels = gin_channels
574
- # self.hop_length = hop_length#
575
- self.spk_embed_dim = spk_embed_dim
576
- self.enc_p = TextEncoder256(
577
- inter_channels,
578
- hidden_channels,
579
- filter_channels,
580
- n_heads,
581
- n_layers,
582
- kernel_size,
583
- p_dropout,
584
- )
585
- self.dec = GeneratorNSF(
586
- inter_channels,
587
- resblock,
588
- resblock_kernel_sizes,
589
- resblock_dilation_sizes,
590
- upsample_rates,
591
- upsample_initial_channel,
592
- upsample_kernel_sizes,
593
- gin_channels=gin_channels,
594
- sr=sr,
595
- is_half=kwargs["is_half"],
596
- )
597
- self.enc_q = PosteriorEncoder(
598
- spec_channels,
599
- inter_channels,
600
- hidden_channels,
601
- 5,
602
- 1,
603
- 16,
604
- gin_channels=gin_channels,
605
- )
606
- self.flow = ResidualCouplingBlock(
607
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
608
- )
609
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
610
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
611
-
612
- def remove_weight_norm(self):
613
- self.dec.remove_weight_norm()
614
- self.flow.remove_weight_norm()
615
- self.enc_q.remove_weight_norm()
616
-
617
- def forward(
618
- self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
619
- ): # 这里ds是id,[bs,1]
620
- # print(1,pitch.shape)#[bs,t]
621
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
622
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
623
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
624
- z_p = self.flow(z, y_mask, g=g)
625
- z_slice, ids_slice = commons.rand_slice_segments(
626
- z, y_lengths, self.segment_size
627
- )
628
- # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
629
- pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
630
- # print(-2,pitchf.shape,z_slice.shape)
631
- o = self.dec(z_slice, pitchf, g=g)
632
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
633
-
634
- def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
635
- g = self.emb_g(sid).unsqueeze(-1)
636
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
637
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
638
- z = self.flow(z_p, x_mask, g=g, reverse=True)
639
- o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
640
- return o, x_mask, (z, z_p, m_p, logs_p)
641
-
642
-
643
- class SynthesizerTrnMs768NSFsid(nn.Module):
644
- def __init__(
645
- self,
646
- spec_channels,
647
- segment_size,
648
- inter_channels,
649
- hidden_channels,
650
- filter_channels,
651
- n_heads,
652
- n_layers,
653
- kernel_size,
654
- p_dropout,
655
- resblock,
656
- resblock_kernel_sizes,
657
- resblock_dilation_sizes,
658
- upsample_rates,
659
- upsample_initial_channel,
660
- upsample_kernel_sizes,
661
- spk_embed_dim,
662
- gin_channels,
663
- sr,
664
- **kwargs
665
- ):
666
- super().__init__()
667
- if type(sr) == type("strr"):
668
- sr = sr2sr[sr]
669
- self.spec_channels = spec_channels
670
- self.inter_channels = inter_channels
671
- self.hidden_channels = hidden_channels
672
- self.filter_channels = filter_channels
673
- self.n_heads = n_heads
674
- self.n_layers = n_layers
675
- self.kernel_size = kernel_size
676
- self.p_dropout = p_dropout
677
- self.resblock = resblock
678
- self.resblock_kernel_sizes = resblock_kernel_sizes
679
- self.resblock_dilation_sizes = resblock_dilation_sizes
680
- self.upsample_rates = upsample_rates
681
- self.upsample_initial_channel = upsample_initial_channel
682
- self.upsample_kernel_sizes = upsample_kernel_sizes
683
- self.segment_size = segment_size
684
- self.gin_channels = gin_channels
685
- # self.hop_length = hop_length#
686
- self.spk_embed_dim = spk_embed_dim
687
- self.enc_p = TextEncoder768(
688
- inter_channels,
689
- hidden_channels,
690
- filter_channels,
691
- n_heads,
692
- n_layers,
693
- kernel_size,
694
- p_dropout,
695
- )
696
- self.dec = GeneratorNSF(
697
- inter_channels,
698
- resblock,
699
- resblock_kernel_sizes,
700
- resblock_dilation_sizes,
701
- upsample_rates,
702
- upsample_initial_channel,
703
- upsample_kernel_sizes,
704
- gin_channels=gin_channels,
705
- sr=sr,
706
- is_half=kwargs["is_half"],
707
- )
708
- self.enc_q = PosteriorEncoder(
709
- spec_channels,
710
- inter_channels,
711
- hidden_channels,
712
- 5,
713
- 1,
714
- 16,
715
- gin_channels=gin_channels,
716
- )
717
- self.flow = ResidualCouplingBlock(
718
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
719
- )
720
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
721
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
722
-
723
- def remove_weight_norm(self):
724
- self.dec.remove_weight_norm()
725
- self.flow.remove_weight_norm()
726
- self.enc_q.remove_weight_norm()
727
-
728
- def forward(
729
- self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
730
- ): # 这里ds是id,[bs,1]
731
- # print(1,pitch.shape)#[bs,t]
732
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
733
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
734
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
735
- z_p = self.flow(z, y_mask, g=g)
736
- z_slice, ids_slice = commons.rand_slice_segments(
737
- z, y_lengths, self.segment_size
738
- )
739
- # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
740
- pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
741
- # print(-2,pitchf.shape,z_slice.shape)
742
- o = self.dec(z_slice, pitchf, g=g)
743
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
744
-
745
- def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
746
- g = self.emb_g(sid).unsqueeze(-1)
747
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
748
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
749
- z = self.flow(z_p, x_mask, g=g, reverse=True)
750
- o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
751
- return o, x_mask, (z, z_p, m_p, logs_p)
752
-
753
-
754
- class SynthesizerTrnMs256NSFsid_nono(nn.Module):
755
- def __init__(
756
- self,
757
- spec_channels,
758
- segment_size,
759
- inter_channels,
760
- hidden_channels,
761
- filter_channels,
762
- n_heads,
763
- n_layers,
764
- kernel_size,
765
- p_dropout,
766
- resblock,
767
- resblock_kernel_sizes,
768
- resblock_dilation_sizes,
769
- upsample_rates,
770
- upsample_initial_channel,
771
- upsample_kernel_sizes,
772
- spk_embed_dim,
773
- gin_channels,
774
- sr=None,
775
- **kwargs
776
- ):
777
- super().__init__()
778
- self.spec_channels = spec_channels
779
- self.inter_channels = inter_channels
780
- self.hidden_channels = hidden_channels
781
- self.filter_channels = filter_channels
782
- self.n_heads = n_heads
783
- self.n_layers = n_layers
784
- self.kernel_size = kernel_size
785
- self.p_dropout = p_dropout
786
- self.resblock = resblock
787
- self.resblock_kernel_sizes = resblock_kernel_sizes
788
- self.resblock_dilation_sizes = resblock_dilation_sizes
789
- self.upsample_rates = upsample_rates
790
- self.upsample_initial_channel = upsample_initial_channel
791
- self.upsample_kernel_sizes = upsample_kernel_sizes
792
- self.segment_size = segment_size
793
- self.gin_channels = gin_channels
794
- # self.hop_length = hop_length#
795
- self.spk_embed_dim = spk_embed_dim
796
- self.enc_p = TextEncoder256(
797
- inter_channels,
798
- hidden_channels,
799
- filter_channels,
800
- n_heads,
801
- n_layers,
802
- kernel_size,
803
- p_dropout,
804
- f0=False,
805
- )
806
- self.dec = Generator(
807
- inter_channels,
808
- resblock,
809
- resblock_kernel_sizes,
810
- resblock_dilation_sizes,
811
- upsample_rates,
812
- upsample_initial_channel,
813
- upsample_kernel_sizes,
814
- gin_channels=gin_channels,
815
- )
816
- self.enc_q = PosteriorEncoder(
817
- spec_channels,
818
- inter_channels,
819
- hidden_channels,
820
- 5,
821
- 1,
822
- 16,
823
- gin_channels=gin_channels,
824
- )
825
- self.flow = ResidualCouplingBlock(
826
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
827
- )
828
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
829
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
830
-
831
- def remove_weight_norm(self):
832
- self.dec.remove_weight_norm()
833
- self.flow.remove_weight_norm()
834
- self.enc_q.remove_weight_norm()
835
-
836
- def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
837
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
838
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
839
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
840
- z_p = self.flow(z, y_mask, g=g)
841
- z_slice, ids_slice = commons.rand_slice_segments(
842
- z, y_lengths, self.segment_size
843
- )
844
- o = self.dec(z_slice, g=g)
845
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
846
-
847
- def infer(self, phone, phone_lengths, sid, max_len=None):
848
- g = self.emb_g(sid).unsqueeze(-1)
849
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
850
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
851
- z = self.flow(z_p, x_mask, g=g, reverse=True)
852
- o = self.dec((z * x_mask)[:, :, :max_len], g=g)
853
- return o, x_mask, (z, z_p, m_p, logs_p)
854
-
855
-
856
- class SynthesizerTrnMs768NSFsid_nono(nn.Module):
857
- def __init__(
858
- self,
859
- spec_channels,
860
- segment_size,
861
- inter_channels,
862
- hidden_channels,
863
- filter_channels,
864
- n_heads,
865
- n_layers,
866
- kernel_size,
867
- p_dropout,
868
- resblock,
869
- resblock_kernel_sizes,
870
- resblock_dilation_sizes,
871
- upsample_rates,
872
- upsample_initial_channel,
873
- upsample_kernel_sizes,
874
- spk_embed_dim,
875
- gin_channels,
876
- sr=None,
877
- **kwargs
878
- ):
879
- super().__init__()
880
- self.spec_channels = spec_channels
881
- self.inter_channels = inter_channels
882
- self.hidden_channels = hidden_channels
883
- self.filter_channels = filter_channels
884
- self.n_heads = n_heads
885
- self.n_layers = n_layers
886
- self.kernel_size = kernel_size
887
- self.p_dropout = p_dropout
888
- self.resblock = resblock
889
- self.resblock_kernel_sizes = resblock_kernel_sizes
890
- self.resblock_dilation_sizes = resblock_dilation_sizes
891
- self.upsample_rates = upsample_rates
892
- self.upsample_initial_channel = upsample_initial_channel
893
- self.upsample_kernel_sizes = upsample_kernel_sizes
894
- self.segment_size = segment_size
895
- self.gin_channels = gin_channels
896
- # self.hop_length = hop_length#
897
- self.spk_embed_dim = spk_embed_dim
898
- self.enc_p = TextEncoder768(
899
- inter_channels,
900
- hidden_channels,
901
- filter_channels,
902
- n_heads,
903
- n_layers,
904
- kernel_size,
905
- p_dropout,
906
- f0=False,
907
- )
908
- self.dec = Generator(
909
- inter_channels,
910
- resblock,
911
- resblock_kernel_sizes,
912
- resblock_dilation_sizes,
913
- upsample_rates,
914
- upsample_initial_channel,
915
- upsample_kernel_sizes,
916
- gin_channels=gin_channels,
917
- )
918
- self.enc_q = PosteriorEncoder(
919
- spec_channels,
920
- inter_channels,
921
- hidden_channels,
922
- 5,
923
- 1,
924
- 16,
925
- gin_channels=gin_channels,
926
- )
927
- self.flow = ResidualCouplingBlock(
928
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
929
- )
930
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
931
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
932
-
933
- def remove_weight_norm(self):
934
- self.dec.remove_weight_norm()
935
- self.flow.remove_weight_norm()
936
- self.enc_q.remove_weight_norm()
937
-
938
- def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
939
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
940
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
941
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
942
- z_p = self.flow(z, y_mask, g=g)
943
- z_slice, ids_slice = commons.rand_slice_segments(
944
- z, y_lengths, self.segment_size
945
- )
946
- o = self.dec(z_slice, g=g)
947
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
948
-
949
- def infer(self, phone, phone_lengths, sid, max_len=None):
950
- g = self.emb_g(sid).unsqueeze(-1)
951
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
952
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
953
- z = self.flow(z_p, x_mask, g=g, reverse=True)
954
- o = self.dec((z * x_mask)[:, :, :max_len], g=g)
955
- return o, x_mask, (z, z_p, m_p, logs_p)
956
-
957
-
958
- class MultiPeriodDiscriminator(torch.nn.Module):
959
- def __init__(self, use_spectral_norm=False):
960
- super(MultiPeriodDiscriminator, self).__init__()
961
- periods = [2, 3, 5, 7, 11, 17]
962
- # periods = [3, 5, 7, 11, 17, 23, 37]
963
-
964
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
965
- discs = discs + [
966
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
967
- ]
968
- self.discriminators = nn.ModuleList(discs)
969
-
970
- def forward(self, y, y_hat):
971
- y_d_rs = [] #
972
- y_d_gs = []
973
- fmap_rs = []
974
- fmap_gs = []
975
- for i, d in enumerate(self.discriminators):
976
- y_d_r, fmap_r = d(y)
977
- y_d_g, fmap_g = d(y_hat)
978
- # for j in range(len(fmap_r)):
979
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
980
- y_d_rs.append(y_d_r)
981
- y_d_gs.append(y_d_g)
982
- fmap_rs.append(fmap_r)
983
- fmap_gs.append(fmap_g)
984
-
985
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
986
-
987
-
988
- class MultiPeriodDiscriminatorV2(torch.nn.Module):
989
- def __init__(self, use_spectral_norm=False):
990
- super(MultiPeriodDiscriminatorV2, self).__init__()
991
- # periods = [2, 3, 5, 7, 11, 17]
992
- periods = [2, 3, 5, 7, 11, 17, 23, 37]
993
-
994
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
995
- discs = discs + [
996
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
997
- ]
998
- self.discriminators = nn.ModuleList(discs)
999
-
1000
- def forward(self, y, y_hat):
1001
- y_d_rs = [] #
1002
- y_d_gs = []
1003
- fmap_rs = []
1004
- fmap_gs = []
1005
- for i, d in enumerate(self.discriminators):
1006
- y_d_r, fmap_r = d(y)
1007
- y_d_g, fmap_g = d(y_hat)
1008
- # for j in range(len(fmap_r)):
1009
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1010
- y_d_rs.append(y_d_r)
1011
- y_d_gs.append(y_d_g)
1012
- fmap_rs.append(fmap_r)
1013
- fmap_gs.append(fmap_g)
1014
-
1015
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1016
-
1017
-
1018
- class DiscriminatorS(torch.nn.Module):
1019
- def __init__(self, use_spectral_norm=False):
1020
- super(DiscriminatorS, self).__init__()
1021
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1022
- self.convs = nn.ModuleList(
1023
- [
1024
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1025
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1026
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1027
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1028
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1029
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1030
- ]
1031
- )
1032
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1033
-
1034
- def forward(self, x):
1035
- fmap = []
1036
-
1037
- for l in self.convs:
1038
- x = l(x)
1039
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
1040
- fmap.append(x)
1041
- x = self.conv_post(x)
1042
- fmap.append(x)
1043
- x = torch.flatten(x, 1, -1)
1044
-
1045
- return x, fmap
1046
-
1047
-
1048
- class DiscriminatorP(torch.nn.Module):
1049
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
1050
- super(DiscriminatorP, self).__init__()
1051
- self.period = period
1052
- self.use_spectral_norm = use_spectral_norm
1053
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1054
- self.convs = nn.ModuleList(
1055
- [
1056
- norm_f(
1057
- Conv2d(
1058
- 1,
1059
- 32,
1060
- (kernel_size, 1),
1061
- (stride, 1),
1062
- padding=(get_padding(kernel_size, 1), 0),
1063
- )
1064
- ),
1065
- norm_f(
1066
- Conv2d(
1067
- 32,
1068
- 128,
1069
- (kernel_size, 1),
1070
- (stride, 1),
1071
- padding=(get_padding(kernel_size, 1), 0),
1072
- )
1073
- ),
1074
- norm_f(
1075
- Conv2d(
1076
- 128,
1077
- 512,
1078
- (kernel_size, 1),
1079
- (stride, 1),
1080
- padding=(get_padding(kernel_size, 1), 0),
1081
- )
1082
- ),
1083
- norm_f(
1084
- Conv2d(
1085
- 512,
1086
- 1024,
1087
- (kernel_size, 1),
1088
- (stride, 1),
1089
- padding=(get_padding(kernel_size, 1), 0),
1090
- )
1091
- ),
1092
- norm_f(
1093
- Conv2d(
1094
- 1024,
1095
- 1024,
1096
- (kernel_size, 1),
1097
- 1,
1098
- padding=(get_padding(kernel_size, 1), 0),
1099
- )
1100
- ),
1101
- ]
1102
- )
1103
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1104
-
1105
- def forward(self, x):
1106
- fmap = []
1107
-
1108
- # 1d to 2d
1109
- b, c, t = x.shape
1110
- if t % self.period != 0: # pad first
1111
- n_pad = self.period - (t % self.period)
1112
- x = F.pad(x, (0, n_pad), "reflect")
1113
- t = t + n_pad
1114
- x = x.view(b, c, t // self.period, self.period)
1115
-
1116
- for l in self.convs:
1117
- x = l(x)
1118
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
1119
- fmap.append(x)
1120
- x = self.conv_post(x)
1121
- fmap.append(x)
1122
- x = torch.flatten(x, 1, -1)
1123
-
1124
- return x, fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/models_onnx.py DELETED
@@ -1,819 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
-
16
-
17
- class TextEncoder256(nn.Module):
18
- def __init__(
19
- self,
20
- out_channels,
21
- hidden_channels,
22
- filter_channels,
23
- n_heads,
24
- n_layers,
25
- kernel_size,
26
- p_dropout,
27
- f0=True,
28
- ):
29
- super().__init__()
30
- self.out_channels = out_channels
31
- self.hidden_channels = hidden_channels
32
- self.filter_channels = filter_channels
33
- self.n_heads = n_heads
34
- self.n_layers = n_layers
35
- self.kernel_size = kernel_size
36
- self.p_dropout = p_dropout
37
- self.emb_phone = nn.Linear(256, hidden_channels)
38
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
- if f0 == True:
40
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
- self.encoder = attentions.Encoder(
42
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
- )
44
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
-
46
- def forward(self, phone, pitch, lengths):
47
- if pitch == None:
48
- x = self.emb_phone(phone)
49
- else:
50
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
- x = self.lrelu(x)
53
- x = torch.transpose(x, 1, -1) # [b, h, t]
54
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
- x.dtype
56
- )
57
- x = self.encoder(x * x_mask, x_mask)
58
- stats = self.proj(x) * x_mask
59
-
60
- m, logs = torch.split(stats, self.out_channels, dim=1)
61
- return m, logs, x_mask
62
-
63
-
64
- class TextEncoder768(nn.Module):
65
- def __init__(
66
- self,
67
- out_channels,
68
- hidden_channels,
69
- filter_channels,
70
- n_heads,
71
- n_layers,
72
- kernel_size,
73
- p_dropout,
74
- f0=True,
75
- ):
76
- super().__init__()
77
- self.out_channels = out_channels
78
- self.hidden_channels = hidden_channels
79
- self.filter_channels = filter_channels
80
- self.n_heads = n_heads
81
- self.n_layers = n_layers
82
- self.kernel_size = kernel_size
83
- self.p_dropout = p_dropout
84
- self.emb_phone = nn.Linear(768, hidden_channels)
85
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
- if f0 == True:
87
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
- self.encoder = attentions.Encoder(
89
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
- )
91
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
92
-
93
- def forward(self, phone, pitch, lengths):
94
- if pitch == None:
95
- x = self.emb_phone(phone)
96
- else:
97
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
- x = self.lrelu(x)
100
- x = torch.transpose(x, 1, -1) # [b, h, t]
101
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
- x.dtype
103
- )
104
- x = self.encoder(x * x_mask, x_mask)
105
- stats = self.proj(x) * x_mask
106
-
107
- m, logs = torch.split(stats, self.out_channels, dim=1)
108
- return m, logs, x_mask
109
-
110
-
111
- class ResidualCouplingBlock(nn.Module):
112
- def __init__(
113
- self,
114
- channels,
115
- hidden_channels,
116
- kernel_size,
117
- dilation_rate,
118
- n_layers,
119
- n_flows=4,
120
- gin_channels=0,
121
- ):
122
- super().__init__()
123
- self.channels = channels
124
- self.hidden_channels = hidden_channels
125
- self.kernel_size = kernel_size
126
- self.dilation_rate = dilation_rate
127
- self.n_layers = n_layers
128
- self.n_flows = n_flows
129
- self.gin_channels = gin_channels
130
-
131
- self.flows = nn.ModuleList()
132
- for i in range(n_flows):
133
- self.flows.append(
134
- modules.ResidualCouplingLayer(
135
- channels,
136
- hidden_channels,
137
- kernel_size,
138
- dilation_rate,
139
- n_layers,
140
- gin_channels=gin_channels,
141
- mean_only=True,
142
- )
143
- )
144
- self.flows.append(modules.Flip())
145
-
146
- def forward(self, x, x_mask, g=None, reverse=False):
147
- if not reverse:
148
- for flow in self.flows:
149
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
- else:
151
- for flow in reversed(self.flows):
152
- x = flow(x, x_mask, g=g, reverse=reverse)
153
- return x
154
-
155
- def remove_weight_norm(self):
156
- for i in range(self.n_flows):
157
- self.flows[i * 2].remove_weight_norm()
158
-
159
-
160
- class PosteriorEncoder(nn.Module):
161
- def __init__(
162
- self,
163
- in_channels,
164
- out_channels,
165
- hidden_channels,
166
- kernel_size,
167
- dilation_rate,
168
- n_layers,
169
- gin_channels=0,
170
- ):
171
- super().__init__()
172
- self.in_channels = in_channels
173
- self.out_channels = out_channels
174
- self.hidden_channels = hidden_channels
175
- self.kernel_size = kernel_size
176
- self.dilation_rate = dilation_rate
177
- self.n_layers = n_layers
178
- self.gin_channels = gin_channels
179
-
180
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
181
- self.enc = modules.WN(
182
- hidden_channels,
183
- kernel_size,
184
- dilation_rate,
185
- n_layers,
186
- gin_channels=gin_channels,
187
- )
188
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
189
-
190
- def forward(self, x, x_lengths, g=None):
191
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
192
- x.dtype
193
- )
194
- x = self.pre(x) * x_mask
195
- x = self.enc(x, x_mask, g=g)
196
- stats = self.proj(x) * x_mask
197
- m, logs = torch.split(stats, self.out_channels, dim=1)
198
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
199
- return z, m, logs, x_mask
200
-
201
- def remove_weight_norm(self):
202
- self.enc.remove_weight_norm()
203
-
204
-
205
- class Generator(torch.nn.Module):
206
- def __init__(
207
- self,
208
- initial_channel,
209
- resblock,
210
- resblock_kernel_sizes,
211
- resblock_dilation_sizes,
212
- upsample_rates,
213
- upsample_initial_channel,
214
- upsample_kernel_sizes,
215
- gin_channels=0,
216
- ):
217
- super(Generator, self).__init__()
218
- self.num_kernels = len(resblock_kernel_sizes)
219
- self.num_upsamples = len(upsample_rates)
220
- self.conv_pre = Conv1d(
221
- initial_channel, upsample_initial_channel, 7, 1, padding=3
222
- )
223
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
224
-
225
- self.ups = nn.ModuleList()
226
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
227
- self.ups.append(
228
- weight_norm(
229
- ConvTranspose1d(
230
- upsample_initial_channel // (2**i),
231
- upsample_initial_channel // (2 ** (i + 1)),
232
- k,
233
- u,
234
- padding=(k - u) // 2,
235
- )
236
- )
237
- )
238
-
239
- self.resblocks = nn.ModuleList()
240
- for i in range(len(self.ups)):
241
- ch = upsample_initial_channel // (2 ** (i + 1))
242
- for j, (k, d) in enumerate(
243
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
244
- ):
245
- self.resblocks.append(resblock(ch, k, d))
246
-
247
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
248
- self.ups.apply(init_weights)
249
-
250
- if gin_channels != 0:
251
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
252
-
253
- def forward(self, x, g=None):
254
- x = self.conv_pre(x)
255
- if g is not None:
256
- x = x + self.cond(g)
257
-
258
- for i in range(self.num_upsamples):
259
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
260
- x = self.ups[i](x)
261
- xs = None
262
- for j in range(self.num_kernels):
263
- if xs is None:
264
- xs = self.resblocks[i * self.num_kernels + j](x)
265
- else:
266
- xs += self.resblocks[i * self.num_kernels + j](x)
267
- x = xs / self.num_kernels
268
- x = F.leaky_relu(x)
269
- x = self.conv_post(x)
270
- x = torch.tanh(x)
271
-
272
- return x
273
-
274
- def remove_weight_norm(self):
275
- for l in self.ups:
276
- remove_weight_norm(l)
277
- for l in self.resblocks:
278
- l.remove_weight_norm()
279
-
280
-
281
- class SineGen(torch.nn.Module):
282
- """Definition of sine generator
283
- SineGen(samp_rate, harmonic_num = 0,
284
- sine_amp = 0.1, noise_std = 0.003,
285
- voiced_threshold = 0,
286
- flag_for_pulse=False)
287
- samp_rate: sampling rate in Hz
288
- harmonic_num: number of harmonic overtones (default 0)
289
- sine_amp: amplitude of sine-wavefrom (default 0.1)
290
- noise_std: std of Gaussian noise (default 0.003)
291
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
292
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
293
- Note: when flag_for_pulse is True, the first time step of a voiced
294
- segment is always sin(np.pi) or cos(0)
295
- """
296
-
297
- def __init__(
298
- self,
299
- samp_rate,
300
- harmonic_num=0,
301
- sine_amp=0.1,
302
- noise_std=0.003,
303
- voiced_threshold=0,
304
- flag_for_pulse=False,
305
- ):
306
- super(SineGen, self).__init__()
307
- self.sine_amp = sine_amp
308
- self.noise_std = noise_std
309
- self.harmonic_num = harmonic_num
310
- self.dim = self.harmonic_num + 1
311
- self.sampling_rate = samp_rate
312
- self.voiced_threshold = voiced_threshold
313
-
314
- def _f02uv(self, f0):
315
- # generate uv signal
316
- uv = torch.ones_like(f0)
317
- uv = uv * (f0 > self.voiced_threshold)
318
- return uv
319
-
320
- def forward(self, f0, upp):
321
- """sine_tensor, uv = forward(f0)
322
- input F0: tensor(batchsize=1, length, dim=1)
323
- f0 for unvoiced steps should be 0
324
- output sine_tensor: tensor(batchsize=1, length, dim)
325
- output uv: tensor(batchsize=1, length, 1)
326
- """
327
- with torch.no_grad():
328
- f0 = f0[:, None].transpose(1, 2)
329
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
330
- # fundamental component
331
- f0_buf[:, :, 0] = f0[:, :, 0]
332
- for idx in np.arange(self.harmonic_num):
333
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
334
- idx + 2
335
- ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
336
- rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
337
- rand_ini = torch.rand(
338
- f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
339
- )
340
- rand_ini[:, 0] = 0
341
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
342
- tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
343
- tmp_over_one *= upp
344
- tmp_over_one = F.interpolate(
345
- tmp_over_one.transpose(2, 1),
346
- scale_factor=upp,
347
- mode="linear",
348
- align_corners=True,
349
- ).transpose(2, 1)
350
- rad_values = F.interpolate(
351
- rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
352
- ).transpose(
353
- 2, 1
354
- ) #######
355
- tmp_over_one %= 1
356
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
357
- cumsum_shift = torch.zeros_like(rad_values)
358
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
359
- sine_waves = torch.sin(
360
- torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
361
- )
362
- sine_waves = sine_waves * self.sine_amp
363
- uv = self._f02uv(f0)
364
- uv = F.interpolate(
365
- uv.transpose(2, 1), scale_factor=upp, mode="nearest"
366
- ).transpose(2, 1)
367
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
368
- noise = noise_amp * torch.randn_like(sine_waves)
369
- sine_waves = sine_waves * uv + noise
370
- return sine_waves, uv, noise
371
-
372
-
373
- class SourceModuleHnNSF(torch.nn.Module):
374
- """SourceModule for hn-nsf
375
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
376
- add_noise_std=0.003, voiced_threshod=0)
377
- sampling_rate: sampling_rate in Hz
378
- harmonic_num: number of harmonic above F0 (default: 0)
379
- sine_amp: amplitude of sine source signal (default: 0.1)
380
- add_noise_std: std of additive Gaussian noise (default: 0.003)
381
- note that amplitude of noise in unvoiced is decided
382
- by sine_amp
383
- voiced_threshold: threhold to set U/V given F0 (default: 0)
384
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
385
- F0_sampled (batchsize, length, 1)
386
- Sine_source (batchsize, length, 1)
387
- noise_source (batchsize, length 1)
388
- uv (batchsize, length, 1)
389
- """
390
-
391
- def __init__(
392
- self,
393
- sampling_rate,
394
- harmonic_num=0,
395
- sine_amp=0.1,
396
- add_noise_std=0.003,
397
- voiced_threshod=0,
398
- is_half=True,
399
- ):
400
- super(SourceModuleHnNSF, self).__init__()
401
-
402
- self.sine_amp = sine_amp
403
- self.noise_std = add_noise_std
404
- self.is_half = is_half
405
- # to produce sine waveforms
406
- self.l_sin_gen = SineGen(
407
- sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
408
- )
409
-
410
- # to merge source harmonics into a single excitation
411
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
412
- self.l_tanh = torch.nn.Tanh()
413
-
414
- def forward(self, x, upp=None):
415
- sine_wavs, uv, _ = self.l_sin_gen(x, upp)
416
- if self.is_half:
417
- sine_wavs = sine_wavs.half()
418
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
419
- return sine_merge, None, None # noise, uv
420
-
421
-
422
- class GeneratorNSF(torch.nn.Module):
423
- def __init__(
424
- self,
425
- initial_channel,
426
- resblock,
427
- resblock_kernel_sizes,
428
- resblock_dilation_sizes,
429
- upsample_rates,
430
- upsample_initial_channel,
431
- upsample_kernel_sizes,
432
- gin_channels,
433
- sr,
434
- is_half=False,
435
- ):
436
- super(GeneratorNSF, self).__init__()
437
- self.num_kernels = len(resblock_kernel_sizes)
438
- self.num_upsamples = len(upsample_rates)
439
-
440
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
441
- self.m_source = SourceModuleHnNSF(
442
- sampling_rate=sr, harmonic_num=0, is_half=is_half
443
- )
444
- self.noise_convs = nn.ModuleList()
445
- self.conv_pre = Conv1d(
446
- initial_channel, upsample_initial_channel, 7, 1, padding=3
447
- )
448
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
449
-
450
- self.ups = nn.ModuleList()
451
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
452
- c_cur = upsample_initial_channel // (2 ** (i + 1))
453
- self.ups.append(
454
- weight_norm(
455
- ConvTranspose1d(
456
- upsample_initial_channel // (2**i),
457
- upsample_initial_channel // (2 ** (i + 1)),
458
- k,
459
- u,
460
- padding=(k - u) // 2,
461
- )
462
- )
463
- )
464
- if i + 1 < len(upsample_rates):
465
- stride_f0 = np.prod(upsample_rates[i + 1 :])
466
- self.noise_convs.append(
467
- Conv1d(
468
- 1,
469
- c_cur,
470
- kernel_size=stride_f0 * 2,
471
- stride=stride_f0,
472
- padding=stride_f0 // 2,
473
- )
474
- )
475
- else:
476
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
477
-
478
- self.resblocks = nn.ModuleList()
479
- for i in range(len(self.ups)):
480
- ch = upsample_initial_channel // (2 ** (i + 1))
481
- for j, (k, d) in enumerate(
482
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
483
- ):
484
- self.resblocks.append(resblock(ch, k, d))
485
-
486
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
487
- self.ups.apply(init_weights)
488
-
489
- if gin_channels != 0:
490
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
491
-
492
- self.upp = np.prod(upsample_rates)
493
-
494
- def forward(self, x, f0, g=None):
495
- har_source, noi_source, uv = self.m_source(f0, self.upp)
496
- har_source = har_source.transpose(1, 2)
497
- x = self.conv_pre(x)
498
- if g is not None:
499
- x = x + self.cond(g)
500
-
501
- for i in range(self.num_upsamples):
502
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
503
- x = self.ups[i](x)
504
- x_source = self.noise_convs[i](har_source)
505
- x = x + x_source
506
- xs = None
507
- for j in range(self.num_kernels):
508
- if xs is None:
509
- xs = self.resblocks[i * self.num_kernels + j](x)
510
- else:
511
- xs += self.resblocks[i * self.num_kernels + j](x)
512
- x = xs / self.num_kernels
513
- x = F.leaky_relu(x)
514
- x = self.conv_post(x)
515
- x = torch.tanh(x)
516
- return x
517
-
518
- def remove_weight_norm(self):
519
- for l in self.ups:
520
- remove_weight_norm(l)
521
- for l in self.resblocks:
522
- l.remove_weight_norm()
523
-
524
-
525
- sr2sr = {
526
- "32k": 32000,
527
- "40k": 40000,
528
- "48k": 48000,
529
- }
530
-
531
-
532
- class SynthesizerTrnMsNSFsidM(nn.Module):
533
- def __init__(
534
- self,
535
- spec_channels,
536
- segment_size,
537
- inter_channels,
538
- hidden_channels,
539
- filter_channels,
540
- n_heads,
541
- n_layers,
542
- kernel_size,
543
- p_dropout,
544
- resblock,
545
- resblock_kernel_sizes,
546
- resblock_dilation_sizes,
547
- upsample_rates,
548
- upsample_initial_channel,
549
- upsample_kernel_sizes,
550
- spk_embed_dim,
551
- gin_channels,
552
- sr,
553
- version,
554
- **kwargs
555
- ):
556
- super().__init__()
557
- if type(sr) == type("strr"):
558
- sr = sr2sr[sr]
559
- self.spec_channels = spec_channels
560
- self.inter_channels = inter_channels
561
- self.hidden_channels = hidden_channels
562
- self.filter_channels = filter_channels
563
- self.n_heads = n_heads
564
- self.n_layers = n_layers
565
- self.kernel_size = kernel_size
566
- self.p_dropout = p_dropout
567
- self.resblock = resblock
568
- self.resblock_kernel_sizes = resblock_kernel_sizes
569
- self.resblock_dilation_sizes = resblock_dilation_sizes
570
- self.upsample_rates = upsample_rates
571
- self.upsample_initial_channel = upsample_initial_channel
572
- self.upsample_kernel_sizes = upsample_kernel_sizes
573
- self.segment_size = segment_size
574
- self.gin_channels = gin_channels
575
- # self.hop_length = hop_length#
576
- self.spk_embed_dim = spk_embed_dim
577
- if version == "v1":
578
- self.enc_p = TextEncoder256(
579
- inter_channels,
580
- hidden_channels,
581
- filter_channels,
582
- n_heads,
583
- n_layers,
584
- kernel_size,
585
- p_dropout,
586
- )
587
- else:
588
- self.enc_p = TextEncoder768(
589
- inter_channels,
590
- hidden_channels,
591
- filter_channels,
592
- n_heads,
593
- n_layers,
594
- kernel_size,
595
- p_dropout,
596
- )
597
- self.dec = GeneratorNSF(
598
- inter_channels,
599
- resblock,
600
- resblock_kernel_sizes,
601
- resblock_dilation_sizes,
602
- upsample_rates,
603
- upsample_initial_channel,
604
- upsample_kernel_sizes,
605
- gin_channels=gin_channels,
606
- sr=sr,
607
- is_half=kwargs["is_half"],
608
- )
609
- self.enc_q = PosteriorEncoder(
610
- spec_channels,
611
- inter_channels,
612
- hidden_channels,
613
- 5,
614
- 1,
615
- 16,
616
- gin_channels=gin_channels,
617
- )
618
- self.flow = ResidualCouplingBlock(
619
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
620
- )
621
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
622
- self.speaker_map = None
623
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
624
-
625
- def remove_weight_norm(self):
626
- self.dec.remove_weight_norm()
627
- self.flow.remove_weight_norm()
628
- self.enc_q.remove_weight_norm()
629
-
630
- def construct_spkmixmap(self, n_speaker):
631
- self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
632
- for i in range(n_speaker):
633
- self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
634
- self.speaker_map = self.speaker_map.unsqueeze(0)
635
-
636
- def forward(self, phone, phone_lengths, pitch, nsff0, g, rnd, max_len=None):
637
- if self.speaker_map is not None: # [N, S] * [S, B, 1, H]
638
- g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
639
- g = g * self.speaker_map # [N, S, B, 1, H]
640
- g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
641
- g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
642
- else:
643
- g = g.unsqueeze(0)
644
- g = self.emb_g(g).transpose(1, 2)
645
-
646
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
647
- z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
648
- z = self.flow(z_p, x_mask, g=g, reverse=True)
649
- o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
650
- return o
651
-
652
-
653
- class MultiPeriodDiscriminator(torch.nn.Module):
654
- def __init__(self, use_spectral_norm=False):
655
- super(MultiPeriodDiscriminator, self).__init__()
656
- periods = [2, 3, 5, 7, 11, 17]
657
- # periods = [3, 5, 7, 11, 17, 23, 37]
658
-
659
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
660
- discs = discs + [
661
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
662
- ]
663
- self.discriminators = nn.ModuleList(discs)
664
-
665
- def forward(self, y, y_hat):
666
- y_d_rs = [] #
667
- y_d_gs = []
668
- fmap_rs = []
669
- fmap_gs = []
670
- for i, d in enumerate(self.discriminators):
671
- y_d_r, fmap_r = d(y)
672
- y_d_g, fmap_g = d(y_hat)
673
- # for j in range(len(fmap_r)):
674
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
675
- y_d_rs.append(y_d_r)
676
- y_d_gs.append(y_d_g)
677
- fmap_rs.append(fmap_r)
678
- fmap_gs.append(fmap_g)
679
-
680
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
681
-
682
-
683
- class MultiPeriodDiscriminatorV2(torch.nn.Module):
684
- def __init__(self, use_spectral_norm=False):
685
- super(MultiPeriodDiscriminatorV2, self).__init__()
686
- # periods = [2, 3, 5, 7, 11, 17]
687
- periods = [2, 3, 5, 7, 11, 17, 23, 37]
688
-
689
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
690
- discs = discs + [
691
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
692
- ]
693
- self.discriminators = nn.ModuleList(discs)
694
-
695
- def forward(self, y, y_hat):
696
- y_d_rs = [] #
697
- y_d_gs = []
698
- fmap_rs = []
699
- fmap_gs = []
700
- for i, d in enumerate(self.discriminators):
701
- y_d_r, fmap_r = d(y)
702
- y_d_g, fmap_g = d(y_hat)
703
- # for j in range(len(fmap_r)):
704
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
705
- y_d_rs.append(y_d_r)
706
- y_d_gs.append(y_d_g)
707
- fmap_rs.append(fmap_r)
708
- fmap_gs.append(fmap_g)
709
-
710
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
711
-
712
-
713
- class DiscriminatorS(torch.nn.Module):
714
- def __init__(self, use_spectral_norm=False):
715
- super(DiscriminatorS, self).__init__()
716
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
717
- self.convs = nn.ModuleList(
718
- [
719
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
720
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
721
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
722
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
723
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
724
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
725
- ]
726
- )
727
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
728
-
729
- def forward(self, x):
730
- fmap = []
731
-
732
- for l in self.convs:
733
- x = l(x)
734
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
735
- fmap.append(x)
736
- x = self.conv_post(x)
737
- fmap.append(x)
738
- x = torch.flatten(x, 1, -1)
739
-
740
- return x, fmap
741
-
742
-
743
- class DiscriminatorP(torch.nn.Module):
744
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
745
- super(DiscriminatorP, self).__init__()
746
- self.period = period
747
- self.use_spectral_norm = use_spectral_norm
748
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
749
- self.convs = nn.ModuleList(
750
- [
751
- norm_f(
752
- Conv2d(
753
- 1,
754
- 32,
755
- (kernel_size, 1),
756
- (stride, 1),
757
- padding=(get_padding(kernel_size, 1), 0),
758
- )
759
- ),
760
- norm_f(
761
- Conv2d(
762
- 32,
763
- 128,
764
- (kernel_size, 1),
765
- (stride, 1),
766
- padding=(get_padding(kernel_size, 1), 0),
767
- )
768
- ),
769
- norm_f(
770
- Conv2d(
771
- 128,
772
- 512,
773
- (kernel_size, 1),
774
- (stride, 1),
775
- padding=(get_padding(kernel_size, 1), 0),
776
- )
777
- ),
778
- norm_f(
779
- Conv2d(
780
- 512,
781
- 1024,
782
- (kernel_size, 1),
783
- (stride, 1),
784
- padding=(get_padding(kernel_size, 1), 0),
785
- )
786
- ),
787
- norm_f(
788
- Conv2d(
789
- 1024,
790
- 1024,
791
- (kernel_size, 1),
792
- 1,
793
- padding=(get_padding(kernel_size, 1), 0),
794
- )
795
- ),
796
- ]
797
- )
798
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
799
-
800
- def forward(self, x):
801
- fmap = []
802
-
803
- # 1d to 2d
804
- b, c, t = x.shape
805
- if t % self.period != 0: # pad first
806
- n_pad = self.period - (t % self.period)
807
- x = F.pad(x, (0, n_pad), "reflect")
808
- t = t + n_pad
809
- x = x.view(b, c, t // self.period, self.period)
810
-
811
- for l in self.convs:
812
- x = l(x)
813
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
814
- fmap.append(x)
815
- x = self.conv_post(x)
816
- fmap.append(x)
817
- x = torch.flatten(x, 1, -1)
818
-
819
- return x, fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/models_onnx_moess.py DELETED
@@ -1,849 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
-
16
-
17
- class TextEncoder256(nn.Module):
18
- def __init__(
19
- self,
20
- out_channels,
21
- hidden_channels,
22
- filter_channels,
23
- n_heads,
24
- n_layers,
25
- kernel_size,
26
- p_dropout,
27
- f0=True,
28
- ):
29
- super().__init__()
30
- self.out_channels = out_channels
31
- self.hidden_channels = hidden_channels
32
- self.filter_channels = filter_channels
33
- self.n_heads = n_heads
34
- self.n_layers = n_layers
35
- self.kernel_size = kernel_size
36
- self.p_dropout = p_dropout
37
- self.emb_phone = nn.Linear(256, hidden_channels)
38
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
- if f0 == True:
40
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
- self.encoder = attentions.Encoder(
42
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
- )
44
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
-
46
- def forward(self, phone, pitch, lengths):
47
- if pitch == None:
48
- x = self.emb_phone(phone)
49
- else:
50
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
- x = self.lrelu(x)
53
- x = torch.transpose(x, 1, -1) # [b, h, t]
54
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
- x.dtype
56
- )
57
- x = self.encoder(x * x_mask, x_mask)
58
- stats = self.proj(x) * x_mask
59
-
60
- m, logs = torch.split(stats, self.out_channels, dim=1)
61
- return m, logs, x_mask
62
-
63
-
64
- class TextEncoder256Sim(nn.Module):
65
- def __init__(
66
- self,
67
- out_channels,
68
- hidden_channels,
69
- filter_channels,
70
- n_heads,
71
- n_layers,
72
- kernel_size,
73
- p_dropout,
74
- f0=True,
75
- ):
76
- super().__init__()
77
- self.out_channels = out_channels
78
- self.hidden_channels = hidden_channels
79
- self.filter_channels = filter_channels
80
- self.n_heads = n_heads
81
- self.n_layers = n_layers
82
- self.kernel_size = kernel_size
83
- self.p_dropout = p_dropout
84
- self.emb_phone = nn.Linear(256, hidden_channels)
85
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
- if f0 == True:
87
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
- self.encoder = attentions.Encoder(
89
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
- )
91
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
92
-
93
- def forward(self, phone, pitch, lengths):
94
- if pitch == None:
95
- x = self.emb_phone(phone)
96
- else:
97
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
- x = self.lrelu(x)
100
- x = torch.transpose(x, 1, -1) # [b, h, t]
101
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
- x.dtype
103
- )
104
- x = self.encoder(x * x_mask, x_mask)
105
- x = self.proj(x) * x_mask
106
- return x, x_mask
107
-
108
-
109
- class ResidualCouplingBlock(nn.Module):
110
- def __init__(
111
- self,
112
- channels,
113
- hidden_channels,
114
- kernel_size,
115
- dilation_rate,
116
- n_layers,
117
- n_flows=4,
118
- gin_channels=0,
119
- ):
120
- super().__init__()
121
- self.channels = channels
122
- self.hidden_channels = hidden_channels
123
- self.kernel_size = kernel_size
124
- self.dilation_rate = dilation_rate
125
- self.n_layers = n_layers
126
- self.n_flows = n_flows
127
- self.gin_channels = gin_channels
128
-
129
- self.flows = nn.ModuleList()
130
- for i in range(n_flows):
131
- self.flows.append(
132
- modules.ResidualCouplingLayer(
133
- channels,
134
- hidden_channels,
135
- kernel_size,
136
- dilation_rate,
137
- n_layers,
138
- gin_channels=gin_channels,
139
- mean_only=True,
140
- )
141
- )
142
- self.flows.append(modules.Flip())
143
-
144
- def forward(self, x, x_mask, g=None, reverse=False):
145
- if not reverse:
146
- for flow in self.flows:
147
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
148
- else:
149
- for flow in reversed(self.flows):
150
- x = flow(x, x_mask, g=g, reverse=reverse)
151
- return x
152
-
153
- def remove_weight_norm(self):
154
- for i in range(self.n_flows):
155
- self.flows[i * 2].remove_weight_norm()
156
-
157
-
158
- class PosteriorEncoder(nn.Module):
159
- def __init__(
160
- self,
161
- in_channels,
162
- out_channels,
163
- hidden_channels,
164
- kernel_size,
165
- dilation_rate,
166
- n_layers,
167
- gin_channels=0,
168
- ):
169
- super().__init__()
170
- self.in_channels = in_channels
171
- self.out_channels = out_channels
172
- self.hidden_channels = hidden_channels
173
- self.kernel_size = kernel_size
174
- self.dilation_rate = dilation_rate
175
- self.n_layers = n_layers
176
- self.gin_channels = gin_channels
177
-
178
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
179
- self.enc = modules.WN(
180
- hidden_channels,
181
- kernel_size,
182
- dilation_rate,
183
- n_layers,
184
- gin_channels=gin_channels,
185
- )
186
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
187
-
188
- def forward(self, x, x_lengths, g=None):
189
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
190
- x.dtype
191
- )
192
- x = self.pre(x) * x_mask
193
- x = self.enc(x, x_mask, g=g)
194
- stats = self.proj(x) * x_mask
195
- m, logs = torch.split(stats, self.out_channels, dim=1)
196
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
197
- return z, m, logs, x_mask
198
-
199
- def remove_weight_norm(self):
200
- self.enc.remove_weight_norm()
201
-
202
-
203
- class Generator(torch.nn.Module):
204
- def __init__(
205
- self,
206
- initial_channel,
207
- resblock,
208
- resblock_kernel_sizes,
209
- resblock_dilation_sizes,
210
- upsample_rates,
211
- upsample_initial_channel,
212
- upsample_kernel_sizes,
213
- gin_channels=0,
214
- ):
215
- super(Generator, self).__init__()
216
- self.num_kernels = len(resblock_kernel_sizes)
217
- self.num_upsamples = len(upsample_rates)
218
- self.conv_pre = Conv1d(
219
- initial_channel, upsample_initial_channel, 7, 1, padding=3
220
- )
221
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
222
-
223
- self.ups = nn.ModuleList()
224
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
225
- self.ups.append(
226
- weight_norm(
227
- ConvTranspose1d(
228
- upsample_initial_channel // (2**i),
229
- upsample_initial_channel // (2 ** (i + 1)),
230
- k,
231
- u,
232
- padding=(k - u) // 2,
233
- )
234
- )
235
- )
236
-
237
- self.resblocks = nn.ModuleList()
238
- for i in range(len(self.ups)):
239
- ch = upsample_initial_channel // (2 ** (i + 1))
240
- for j, (k, d) in enumerate(
241
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
242
- ):
243
- self.resblocks.append(resblock(ch, k, d))
244
-
245
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
246
- self.ups.apply(init_weights)
247
-
248
- if gin_channels != 0:
249
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
250
-
251
- def forward(self, x, g=None):
252
- x = self.conv_pre(x)
253
- if g is not None:
254
- x = x + self.cond(g)
255
-
256
- for i in range(self.num_upsamples):
257
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
258
- x = self.ups[i](x)
259
- xs = None
260
- for j in range(self.num_kernels):
261
- if xs is None:
262
- xs = self.resblocks[i * self.num_kernels + j](x)
263
- else:
264
- xs += self.resblocks[i * self.num_kernels + j](x)
265
- x = xs / self.num_kernels
266
- x = F.leaky_relu(x)
267
- x = self.conv_post(x)
268
- x = torch.tanh(x)
269
-
270
- return x
271
-
272
- def remove_weight_norm(self):
273
- for l in self.ups:
274
- remove_weight_norm(l)
275
- for l in self.resblocks:
276
- l.remove_weight_norm()
277
-
278
-
279
- class SineGen(torch.nn.Module):
280
- """Definition of sine generator
281
- SineGen(samp_rate, harmonic_num = 0,
282
- sine_amp = 0.1, noise_std = 0.003,
283
- voiced_threshold = 0,
284
- flag_for_pulse=False)
285
- samp_rate: sampling rate in Hz
286
- harmonic_num: number of harmonic overtones (default 0)
287
- sine_amp: amplitude of sine-wavefrom (default 0.1)
288
- noise_std: std of Gaussian noise (default 0.003)
289
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
290
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
291
- Note: when flag_for_pulse is True, the first time step of a voiced
292
- segment is always sin(np.pi) or cos(0)
293
- """
294
-
295
- def __init__(
296
- self,
297
- samp_rate,
298
- harmonic_num=0,
299
- sine_amp=0.1,
300
- noise_std=0.003,
301
- voiced_threshold=0,
302
- flag_for_pulse=False,
303
- ):
304
- super(SineGen, self).__init__()
305
- self.sine_amp = sine_amp
306
- self.noise_std = noise_std
307
- self.harmonic_num = harmonic_num
308
- self.dim = self.harmonic_num + 1
309
- self.sampling_rate = samp_rate
310
- self.voiced_threshold = voiced_threshold
311
-
312
- def _f02uv(self, f0):
313
- # generate uv signal
314
- uv = torch.ones_like(f0)
315
- uv = uv * (f0 > self.voiced_threshold)
316
- return uv
317
-
318
- def forward(self, f0, upp):
319
- """sine_tensor, uv = forward(f0)
320
- input F0: tensor(batchsize=1, length, dim=1)
321
- f0 for unvoiced steps should be 0
322
- output sine_tensor: tensor(batchsize=1, length, dim)
323
- output uv: tensor(batchsize=1, length, 1)
324
- """
325
- with torch.no_grad():
326
- f0 = f0[:, None].transpose(1, 2)
327
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
328
- # fundamental component
329
- f0_buf[:, :, 0] = f0[:, :, 0]
330
- for idx in np.arange(self.harmonic_num):
331
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
332
- idx + 2
333
- ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
334
- rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
335
- rand_ini = torch.rand(
336
- f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
337
- )
338
- rand_ini[:, 0] = 0
339
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
340
- tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
341
- tmp_over_one *= upp
342
- tmp_over_one = F.interpolate(
343
- tmp_over_one.transpose(2, 1),
344
- scale_factor=upp,
345
- mode="linear",
346
- align_corners=True,
347
- ).transpose(2, 1)
348
- rad_values = F.interpolate(
349
- rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
350
- ).transpose(
351
- 2, 1
352
- ) #######
353
- tmp_over_one %= 1
354
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
355
- cumsum_shift = torch.zeros_like(rad_values)
356
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
357
- sine_waves = torch.sin(
358
- torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
359
- )
360
- sine_waves = sine_waves * self.sine_amp
361
- uv = self._f02uv(f0)
362
- uv = F.interpolate(
363
- uv.transpose(2, 1), scale_factor=upp, mode="nearest"
364
- ).transpose(2, 1)
365
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
366
- noise = noise_amp * torch.randn_like(sine_waves)
367
- sine_waves = sine_waves * uv + noise
368
- return sine_waves, uv, noise
369
-
370
-
371
- class SourceModuleHnNSF(torch.nn.Module):
372
- """SourceModule for hn-nsf
373
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
374
- add_noise_std=0.003, voiced_threshod=0)
375
- sampling_rate: sampling_rate in Hz
376
- harmonic_num: number of harmonic above F0 (default: 0)
377
- sine_amp: amplitude of sine source signal (default: 0.1)
378
- add_noise_std: std of additive Gaussian noise (default: 0.003)
379
- note that amplitude of noise in unvoiced is decided
380
- by sine_amp
381
- voiced_threshold: threhold to set U/V given F0 (default: 0)
382
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
383
- F0_sampled (batchsize, length, 1)
384
- Sine_source (batchsize, length, 1)
385
- noise_source (batchsize, length 1)
386
- uv (batchsize, length, 1)
387
- """
388
-
389
- def __init__(
390
- self,
391
- sampling_rate,
392
- harmonic_num=0,
393
- sine_amp=0.1,
394
- add_noise_std=0.003,
395
- voiced_threshod=0,
396
- is_half=True,
397
- ):
398
- super(SourceModuleHnNSF, self).__init__()
399
-
400
- self.sine_amp = sine_amp
401
- self.noise_std = add_noise_std
402
- self.is_half = is_half
403
- # to produce sine waveforms
404
- self.l_sin_gen = SineGen(
405
- sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
406
- )
407
-
408
- # to merge source harmonics into a single excitation
409
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
410
- self.l_tanh = torch.nn.Tanh()
411
-
412
- def forward(self, x, upp=None):
413
- sine_wavs, uv, _ = self.l_sin_gen(x, upp)
414
- if self.is_half:
415
- sine_wavs = sine_wavs.half()
416
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
417
- return sine_merge, None, None # noise, uv
418
-
419
-
420
- class GeneratorNSF(torch.nn.Module):
421
- def __init__(
422
- self,
423
- initial_channel,
424
- resblock,
425
- resblock_kernel_sizes,
426
- resblock_dilation_sizes,
427
- upsample_rates,
428
- upsample_initial_channel,
429
- upsample_kernel_sizes,
430
- gin_channels,
431
- sr,
432
- is_half=False,
433
- ):
434
- super(GeneratorNSF, self).__init__()
435
- self.num_kernels = len(resblock_kernel_sizes)
436
- self.num_upsamples = len(upsample_rates)
437
-
438
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
439
- self.m_source = SourceModuleHnNSF(
440
- sampling_rate=sr, harmonic_num=0, is_half=is_half
441
- )
442
- self.noise_convs = nn.ModuleList()
443
- self.conv_pre = Conv1d(
444
- initial_channel, upsample_initial_channel, 7, 1, padding=3
445
- )
446
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
447
-
448
- self.ups = nn.ModuleList()
449
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
450
- c_cur = upsample_initial_channel // (2 ** (i + 1))
451
- self.ups.append(
452
- weight_norm(
453
- ConvTranspose1d(
454
- upsample_initial_channel // (2**i),
455
- upsample_initial_channel // (2 ** (i + 1)),
456
- k,
457
- u,
458
- padding=(k - u) // 2,
459
- )
460
- )
461
- )
462
- if i + 1 < len(upsample_rates):
463
- stride_f0 = np.prod(upsample_rates[i + 1 :])
464
- self.noise_convs.append(
465
- Conv1d(
466
- 1,
467
- c_cur,
468
- kernel_size=stride_f0 * 2,
469
- stride=stride_f0,
470
- padding=stride_f0 // 2,
471
- )
472
- )
473
- else:
474
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
475
-
476
- self.resblocks = nn.ModuleList()
477
- for i in range(len(self.ups)):
478
- ch = upsample_initial_channel // (2 ** (i + 1))
479
- for j, (k, d) in enumerate(
480
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
481
- ):
482
- self.resblocks.append(resblock(ch, k, d))
483
-
484
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
485
- self.ups.apply(init_weights)
486
-
487
- if gin_channels != 0:
488
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
489
-
490
- self.upp = np.prod(upsample_rates)
491
-
492
- def forward(self, x, f0, g=None):
493
- har_source, noi_source, uv = self.m_source(f0, self.upp)
494
- har_source = har_source.transpose(1, 2)
495
- x = self.conv_pre(x)
496
- if g is not None:
497
- x = x + self.cond(g)
498
-
499
- for i in range(self.num_upsamples):
500
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
501
- x = self.ups[i](x)
502
- x_source = self.noise_convs[i](har_source)
503
- x = x + x_source
504
- xs = None
505
- for j in range(self.num_kernels):
506
- if xs is None:
507
- xs = self.resblocks[i * self.num_kernels + j](x)
508
- else:
509
- xs += self.resblocks[i * self.num_kernels + j](x)
510
- x = xs / self.num_kernels
511
- x = F.leaky_relu(x)
512
- x = self.conv_post(x)
513
- x = torch.tanh(x)
514
- return x
515
-
516
- def remove_weight_norm(self):
517
- for l in self.ups:
518
- remove_weight_norm(l)
519
- for l in self.resblocks:
520
- l.remove_weight_norm()
521
-
522
-
523
- sr2sr = {
524
- "32k": 32000,
525
- "40k": 40000,
526
- "48k": 48000,
527
- }
528
-
529
-
530
- class SynthesizerTrnMs256NSFsidM(nn.Module):
531
- def __init__(
532
- self,
533
- spec_channels,
534
- segment_size,
535
- inter_channels,
536
- hidden_channels,
537
- filter_channels,
538
- n_heads,
539
- n_layers,
540
- kernel_size,
541
- p_dropout,
542
- resblock,
543
- resblock_kernel_sizes,
544
- resblock_dilation_sizes,
545
- upsample_rates,
546
- upsample_initial_channel,
547
- upsample_kernel_sizes,
548
- spk_embed_dim,
549
- gin_channels,
550
- sr,
551
- **kwargs
552
- ):
553
- super().__init__()
554
- if type(sr) == type("strr"):
555
- sr = sr2sr[sr]
556
- self.spec_channels = spec_channels
557
- self.inter_channels = inter_channels
558
- self.hidden_channels = hidden_channels
559
- self.filter_channels = filter_channels
560
- self.n_heads = n_heads
561
- self.n_layers = n_layers
562
- self.kernel_size = kernel_size
563
- self.p_dropout = p_dropout
564
- self.resblock = resblock
565
- self.resblock_kernel_sizes = resblock_kernel_sizes
566
- self.resblock_dilation_sizes = resblock_dilation_sizes
567
- self.upsample_rates = upsample_rates
568
- self.upsample_initial_channel = upsample_initial_channel
569
- self.upsample_kernel_sizes = upsample_kernel_sizes
570
- self.segment_size = segment_size
571
- self.gin_channels = gin_channels
572
- # self.hop_length = hop_length#
573
- self.spk_embed_dim = spk_embed_dim
574
- self.enc_p = TextEncoder256(
575
- inter_channels,
576
- hidden_channels,
577
- filter_channels,
578
- n_heads,
579
- n_layers,
580
- kernel_size,
581
- p_dropout,
582
- )
583
- self.dec = GeneratorNSF(
584
- inter_channels,
585
- resblock,
586
- resblock_kernel_sizes,
587
- resblock_dilation_sizes,
588
- upsample_rates,
589
- upsample_initial_channel,
590
- upsample_kernel_sizes,
591
- gin_channels=gin_channels,
592
- sr=sr,
593
- is_half=kwargs["is_half"],
594
- )
595
- self.enc_q = PosteriorEncoder(
596
- spec_channels,
597
- inter_channels,
598
- hidden_channels,
599
- 5,
600
- 1,
601
- 16,
602
- gin_channels=gin_channels,
603
- )
604
- self.flow = ResidualCouplingBlock(
605
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
606
- )
607
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
608
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
609
-
610
- def remove_weight_norm(self):
611
- self.dec.remove_weight_norm()
612
- self.flow.remove_weight_norm()
613
- self.enc_q.remove_weight_norm()
614
-
615
- def forward(self, phone, phone_lengths, pitch, nsff0, sid, rnd, max_len=None):
616
- g = self.emb_g(sid).unsqueeze(-1)
617
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
618
- z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
619
- z = self.flow(z_p, x_mask, g=g, reverse=True)
620
- o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
621
- return o
622
-
623
-
624
- class SynthesizerTrnMs256NSFsid_sim(nn.Module):
625
- """
626
- Synthesizer for Training
627
- """
628
-
629
- def __init__(
630
- self,
631
- spec_channels,
632
- segment_size,
633
- inter_channels,
634
- hidden_channels,
635
- filter_channels,
636
- n_heads,
637
- n_layers,
638
- kernel_size,
639
- p_dropout,
640
- resblock,
641
- resblock_kernel_sizes,
642
- resblock_dilation_sizes,
643
- upsample_rates,
644
- upsample_initial_channel,
645
- upsample_kernel_sizes,
646
- spk_embed_dim,
647
- # hop_length,
648
- gin_channels=0,
649
- use_sdp=True,
650
- **kwargs
651
- ):
652
- super().__init__()
653
- self.spec_channels = spec_channels
654
- self.inter_channels = inter_channels
655
- self.hidden_channels = hidden_channels
656
- self.filter_channels = filter_channels
657
- self.n_heads = n_heads
658
- self.n_layers = n_layers
659
- self.kernel_size = kernel_size
660
- self.p_dropout = p_dropout
661
- self.resblock = resblock
662
- self.resblock_kernel_sizes = resblock_kernel_sizes
663
- self.resblock_dilation_sizes = resblock_dilation_sizes
664
- self.upsample_rates = upsample_rates
665
- self.upsample_initial_channel = upsample_initial_channel
666
- self.upsample_kernel_sizes = upsample_kernel_sizes
667
- self.segment_size = segment_size
668
- self.gin_channels = gin_channels
669
- # self.hop_length = hop_length#
670
- self.spk_embed_dim = spk_embed_dim
671
- self.enc_p = TextEncoder256Sim(
672
- inter_channels,
673
- hidden_channels,
674
- filter_channels,
675
- n_heads,
676
- n_layers,
677
- kernel_size,
678
- p_dropout,
679
- )
680
- self.dec = GeneratorNSF(
681
- inter_channels,
682
- resblock,
683
- resblock_kernel_sizes,
684
- resblock_dilation_sizes,
685
- upsample_rates,
686
- upsample_initial_channel,
687
- upsample_kernel_sizes,
688
- gin_channels=gin_channels,
689
- is_half=kwargs["is_half"],
690
- )
691
-
692
- self.flow = ResidualCouplingBlock(
693
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
694
- )
695
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
696
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
697
-
698
- def remove_weight_norm(self):
699
- self.dec.remove_weight_norm()
700
- self.flow.remove_weight_norm()
701
- self.enc_q.remove_weight_norm()
702
-
703
- def forward(
704
- self, phone, phone_lengths, pitch, pitchf, ds, max_len=None
705
- ): # y是spec不需要了现在
706
- g = self.emb_g(ds.unsqueeze(0)).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
707
- x, x_mask = self.enc_p(phone, pitch, phone_lengths)
708
- x = self.flow(x, x_mask, g=g, reverse=True)
709
- o = self.dec((x * x_mask)[:, :, :max_len], pitchf, g=g)
710
- return o
711
-
712
-
713
- class MultiPeriodDiscriminator(torch.nn.Module):
714
- def __init__(self, use_spectral_norm=False):
715
- super(MultiPeriodDiscriminator, self).__init__()
716
- periods = [2, 3, 5, 7, 11, 17]
717
- # periods = [3, 5, 7, 11, 17, 23, 37]
718
-
719
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
720
- discs = discs + [
721
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
722
- ]
723
- self.discriminators = nn.ModuleList(discs)
724
-
725
- def forward(self, y, y_hat):
726
- y_d_rs = [] #
727
- y_d_gs = []
728
- fmap_rs = []
729
- fmap_gs = []
730
- for i, d in enumerate(self.discriminators):
731
- y_d_r, fmap_r = d(y)
732
- y_d_g, fmap_g = d(y_hat)
733
- # for j in range(len(fmap_r)):
734
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
735
- y_d_rs.append(y_d_r)
736
- y_d_gs.append(y_d_g)
737
- fmap_rs.append(fmap_r)
738
- fmap_gs.append(fmap_g)
739
-
740
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
741
-
742
-
743
- class DiscriminatorS(torch.nn.Module):
744
- def __init__(self, use_spectral_norm=False):
745
- super(DiscriminatorS, self).__init__()
746
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
747
- self.convs = nn.ModuleList(
748
- [
749
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
750
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
751
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
752
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
753
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
754
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
755
- ]
756
- )
757
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
758
-
759
- def forward(self, x):
760
- fmap = []
761
-
762
- for l in self.convs:
763
- x = l(x)
764
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
765
- fmap.append(x)
766
- x = self.conv_post(x)
767
- fmap.append(x)
768
- x = torch.flatten(x, 1, -1)
769
-
770
- return x, fmap
771
-
772
-
773
- class DiscriminatorP(torch.nn.Module):
774
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
775
- super(DiscriminatorP, self).__init__()
776
- self.period = period
777
- self.use_spectral_norm = use_spectral_norm
778
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
779
- self.convs = nn.ModuleList(
780
- [
781
- norm_f(
782
- Conv2d(
783
- 1,
784
- 32,
785
- (kernel_size, 1),
786
- (stride, 1),
787
- padding=(get_padding(kernel_size, 1), 0),
788
- )
789
- ),
790
- norm_f(
791
- Conv2d(
792
- 32,
793
- 128,
794
- (kernel_size, 1),
795
- (stride, 1),
796
- padding=(get_padding(kernel_size, 1), 0),
797
- )
798
- ),
799
- norm_f(
800
- Conv2d(
801
- 128,
802
- 512,
803
- (kernel_size, 1),
804
- (stride, 1),
805
- padding=(get_padding(kernel_size, 1), 0),
806
- )
807
- ),
808
- norm_f(
809
- Conv2d(
810
- 512,
811
- 1024,
812
- (kernel_size, 1),
813
- (stride, 1),
814
- padding=(get_padding(kernel_size, 1), 0),
815
- )
816
- ),
817
- norm_f(
818
- Conv2d(
819
- 1024,
820
- 1024,
821
- (kernel_size, 1),
822
- 1,
823
- padding=(get_padding(kernel_size, 1), 0),
824
- )
825
- ),
826
- ]
827
- )
828
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
829
-
830
- def forward(self, x):
831
- fmap = []
832
-
833
- # 1d to 2d
834
- b, c, t = x.shape
835
- if t % self.period != 0: # pad first
836
- n_pad = self.period - (t % self.period)
837
- x = F.pad(x, (0, n_pad), "reflect")
838
- t = t + n_pad
839
- x = x.view(b, c, t // self.period, self.period)
840
-
841
- for l in self.convs:
842
- x = l(x)
843
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
844
- fmap.append(x)
845
- x = self.conv_post(x)
846
- fmap.append(x)
847
- x = torch.flatten(x, 1, -1)
848
-
849
- return x, fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules.py DELETED
@@ -1,522 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import scipy
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
-
9
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
- from torch.nn.utils import weight_norm, remove_weight_norm
11
-
12
- from infer_pack import commons
13
- from infer_pack.commons import init_weights, get_padding
14
- from infer_pack.transforms import piecewise_rational_quadratic_transform
15
-
16
-
17
- LRELU_SLOPE = 0.1
18
-
19
-
20
- class LayerNorm(nn.Module):
21
- def __init__(self, channels, eps=1e-5):
22
- super().__init__()
23
- self.channels = channels
24
- self.eps = eps
25
-
26
- self.gamma = nn.Parameter(torch.ones(channels))
27
- self.beta = nn.Parameter(torch.zeros(channels))
28
-
29
- def forward(self, x):
30
- x = x.transpose(1, -1)
31
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
- return x.transpose(1, -1)
33
-
34
-
35
- class ConvReluNorm(nn.Module):
36
- def __init__(
37
- self,
38
- in_channels,
39
- hidden_channels,
40
- out_channels,
41
- kernel_size,
42
- n_layers,
43
- p_dropout,
44
- ):
45
- super().__init__()
46
- self.in_channels = in_channels
47
- self.hidden_channels = hidden_channels
48
- self.out_channels = out_channels
49
- self.kernel_size = kernel_size
50
- self.n_layers = n_layers
51
- self.p_dropout = p_dropout
52
- assert n_layers > 1, "Number of layers should be larger than 0."
53
-
54
- self.conv_layers = nn.ModuleList()
55
- self.norm_layers = nn.ModuleList()
56
- self.conv_layers.append(
57
- nn.Conv1d(
58
- in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
- )
60
- )
61
- self.norm_layers.append(LayerNorm(hidden_channels))
62
- self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
- for _ in range(n_layers - 1):
64
- self.conv_layers.append(
65
- nn.Conv1d(
66
- hidden_channels,
67
- hidden_channels,
68
- kernel_size,
69
- padding=kernel_size // 2,
70
- )
71
- )
72
- self.norm_layers.append(LayerNorm(hidden_channels))
73
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
- self.proj.weight.data.zero_()
75
- self.proj.bias.data.zero_()
76
-
77
- def forward(self, x, x_mask):
78
- x_org = x
79
- for i in range(self.n_layers):
80
- x = self.conv_layers[i](x * x_mask)
81
- x = self.norm_layers[i](x)
82
- x = self.relu_drop(x)
83
- x = x_org + self.proj(x)
84
- return x * x_mask
85
-
86
-
87
- class DDSConv(nn.Module):
88
- """
89
- Dialted and Depth-Separable Convolution
90
- """
91
-
92
- def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
- super().__init__()
94
- self.channels = channels
95
- self.kernel_size = kernel_size
96
- self.n_layers = n_layers
97
- self.p_dropout = p_dropout
98
-
99
- self.drop = nn.Dropout(p_dropout)
100
- self.convs_sep = nn.ModuleList()
101
- self.convs_1x1 = nn.ModuleList()
102
- self.norms_1 = nn.ModuleList()
103
- self.norms_2 = nn.ModuleList()
104
- for i in range(n_layers):
105
- dilation = kernel_size**i
106
- padding = (kernel_size * dilation - dilation) // 2
107
- self.convs_sep.append(
108
- nn.Conv1d(
109
- channels,
110
- channels,
111
- kernel_size,
112
- groups=channels,
113
- dilation=dilation,
114
- padding=padding,
115
- )
116
- )
117
- self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
- self.norms_1.append(LayerNorm(channels))
119
- self.norms_2.append(LayerNorm(channels))
120
-
121
- def forward(self, x, x_mask, g=None):
122
- if g is not None:
123
- x = x + g
124
- for i in range(self.n_layers):
125
- y = self.convs_sep[i](x * x_mask)
126
- y = self.norms_1[i](y)
127
- y = F.gelu(y)
128
- y = self.convs_1x1[i](y)
129
- y = self.norms_2[i](y)
130
- y = F.gelu(y)
131
- y = self.drop(y)
132
- x = x + y
133
- return x * x_mask
134
-
135
-
136
- class WN(torch.nn.Module):
137
- def __init__(
138
- self,
139
- hidden_channels,
140
- kernel_size,
141
- dilation_rate,
142
- n_layers,
143
- gin_channels=0,
144
- p_dropout=0,
145
- ):
146
- super(WN, self).__init__()
147
- assert kernel_size % 2 == 1
148
- self.hidden_channels = hidden_channels
149
- self.kernel_size = (kernel_size,)
150
- self.dilation_rate = dilation_rate
151
- self.n_layers = n_layers
152
- self.gin_channels = gin_channels
153
- self.p_dropout = p_dropout
154
-
155
- self.in_layers = torch.nn.ModuleList()
156
- self.res_skip_layers = torch.nn.ModuleList()
157
- self.drop = nn.Dropout(p_dropout)
158
-
159
- if gin_channels != 0:
160
- cond_layer = torch.nn.Conv1d(
161
- gin_channels, 2 * hidden_channels * n_layers, 1
162
- )
163
- self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
-
165
- for i in range(n_layers):
166
- dilation = dilation_rate**i
167
- padding = int((kernel_size * dilation - dilation) / 2)
168
- in_layer = torch.nn.Conv1d(
169
- hidden_channels,
170
- 2 * hidden_channels,
171
- kernel_size,
172
- dilation=dilation,
173
- padding=padding,
174
- )
175
- in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
- self.in_layers.append(in_layer)
177
-
178
- # last one is not necessary
179
- if i < n_layers - 1:
180
- res_skip_channels = 2 * hidden_channels
181
- else:
182
- res_skip_channels = hidden_channels
183
-
184
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
- res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
- self.res_skip_layers.append(res_skip_layer)
187
-
188
- def forward(self, x, x_mask, g=None, **kwargs):
189
- output = torch.zeros_like(x)
190
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
-
192
- if g is not None:
193
- g = self.cond_layer(g)
194
-
195
- for i in range(self.n_layers):
196
- x_in = self.in_layers[i](x)
197
- if g is not None:
198
- cond_offset = i * 2 * self.hidden_channels
199
- g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
- else:
201
- g_l = torch.zeros_like(x_in)
202
-
203
- acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
- acts = self.drop(acts)
205
-
206
- res_skip_acts = self.res_skip_layers[i](acts)
207
- if i < self.n_layers - 1:
208
- res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
- x = (x + res_acts) * x_mask
210
- output = output + res_skip_acts[:, self.hidden_channels :, :]
211
- else:
212
- output = output + res_skip_acts
213
- return output * x_mask
214
-
215
- def remove_weight_norm(self):
216
- if self.gin_channels != 0:
217
- torch.nn.utils.remove_weight_norm(self.cond_layer)
218
- for l in self.in_layers:
219
- torch.nn.utils.remove_weight_norm(l)
220
- for l in self.res_skip_layers:
221
- torch.nn.utils.remove_weight_norm(l)
222
-
223
-
224
- class ResBlock1(torch.nn.Module):
225
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
- super(ResBlock1, self).__init__()
227
- self.convs1 = nn.ModuleList(
228
- [
229
- weight_norm(
230
- Conv1d(
231
- channels,
232
- channels,
233
- kernel_size,
234
- 1,
235
- dilation=dilation[0],
236
- padding=get_padding(kernel_size, dilation[0]),
237
- )
238
- ),
239
- weight_norm(
240
- Conv1d(
241
- channels,
242
- channels,
243
- kernel_size,
244
- 1,
245
- dilation=dilation[1],
246
- padding=get_padding(kernel_size, dilation[1]),
247
- )
248
- ),
249
- weight_norm(
250
- Conv1d(
251
- channels,
252
- channels,
253
- kernel_size,
254
- 1,
255
- dilation=dilation[2],
256
- padding=get_padding(kernel_size, dilation[2]),
257
- )
258
- ),
259
- ]
260
- )
261
- self.convs1.apply(init_weights)
262
-
263
- self.convs2 = nn.ModuleList(
264
- [
265
- weight_norm(
266
- Conv1d(
267
- channels,
268
- channels,
269
- kernel_size,
270
- 1,
271
- dilation=1,
272
- padding=get_padding(kernel_size, 1),
273
- )
274
- ),
275
- weight_norm(
276
- Conv1d(
277
- channels,
278
- channels,
279
- kernel_size,
280
- 1,
281
- dilation=1,
282
- padding=get_padding(kernel_size, 1),
283
- )
284
- ),
285
- weight_norm(
286
- Conv1d(
287
- channels,
288
- channels,
289
- kernel_size,
290
- 1,
291
- dilation=1,
292
- padding=get_padding(kernel_size, 1),
293
- )
294
- ),
295
- ]
296
- )
297
- self.convs2.apply(init_weights)
298
-
299
- def forward(self, x, x_mask=None):
300
- for c1, c2 in zip(self.convs1, self.convs2):
301
- xt = F.leaky_relu(x, LRELU_SLOPE)
302
- if x_mask is not None:
303
- xt = xt * x_mask
304
- xt = c1(xt)
305
- xt = F.leaky_relu(xt, LRELU_SLOPE)
306
- if x_mask is not None:
307
- xt = xt * x_mask
308
- xt = c2(xt)
309
- x = xt + x
310
- if x_mask is not None:
311
- x = x * x_mask
312
- return x
313
-
314
- def remove_weight_norm(self):
315
- for l in self.convs1:
316
- remove_weight_norm(l)
317
- for l in self.convs2:
318
- remove_weight_norm(l)
319
-
320
-
321
- class ResBlock2(torch.nn.Module):
322
- def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
- super(ResBlock2, self).__init__()
324
- self.convs = nn.ModuleList(
325
- [
326
- weight_norm(
327
- Conv1d(
328
- channels,
329
- channels,
330
- kernel_size,
331
- 1,
332
- dilation=dilation[0],
333
- padding=get_padding(kernel_size, dilation[0]),
334
- )
335
- ),
336
- weight_norm(
337
- Conv1d(
338
- channels,
339
- channels,
340
- kernel_size,
341
- 1,
342
- dilation=dilation[1],
343
- padding=get_padding(kernel_size, dilation[1]),
344
- )
345
- ),
346
- ]
347
- )
348
- self.convs.apply(init_weights)
349
-
350
- def forward(self, x, x_mask=None):
351
- for c in self.convs:
352
- xt = F.leaky_relu(x, LRELU_SLOPE)
353
- if x_mask is not None:
354
- xt = xt * x_mask
355
- xt = c(xt)
356
- x = xt + x
357
- if x_mask is not None:
358
- x = x * x_mask
359
- return x
360
-
361
- def remove_weight_norm(self):
362
- for l in self.convs:
363
- remove_weight_norm(l)
364
-
365
-
366
- class Log(nn.Module):
367
- def forward(self, x, x_mask, reverse=False, **kwargs):
368
- if not reverse:
369
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
- logdet = torch.sum(-y, [1, 2])
371
- return y, logdet
372
- else:
373
- x = torch.exp(x) * x_mask
374
- return x
375
-
376
-
377
- class Flip(nn.Module):
378
- def forward(self, x, *args, reverse=False, **kwargs):
379
- x = torch.flip(x, [1])
380
- if not reverse:
381
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
- return x, logdet
383
- else:
384
- return x
385
-
386
-
387
- class ElementwiseAffine(nn.Module):
388
- def __init__(self, channels):
389
- super().__init__()
390
- self.channels = channels
391
- self.m = nn.Parameter(torch.zeros(channels, 1))
392
- self.logs = nn.Parameter(torch.zeros(channels, 1))
393
-
394
- def forward(self, x, x_mask, reverse=False, **kwargs):
395
- if not reverse:
396
- y = self.m + torch.exp(self.logs) * x
397
- y = y * x_mask
398
- logdet = torch.sum(self.logs * x_mask, [1, 2])
399
- return y, logdet
400
- else:
401
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
- return x
403
-
404
-
405
- class ResidualCouplingLayer(nn.Module):
406
- def __init__(
407
- self,
408
- channels,
409
- hidden_channels,
410
- kernel_size,
411
- dilation_rate,
412
- n_layers,
413
- p_dropout=0,
414
- gin_channels=0,
415
- mean_only=False,
416
- ):
417
- assert channels % 2 == 0, "channels should be divisible by 2"
418
- super().__init__()
419
- self.channels = channels
420
- self.hidden_channels = hidden_channels
421
- self.kernel_size = kernel_size
422
- self.dilation_rate = dilation_rate
423
- self.n_layers = n_layers
424
- self.half_channels = channels // 2
425
- self.mean_only = mean_only
426
-
427
- self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
- self.enc = WN(
429
- hidden_channels,
430
- kernel_size,
431
- dilation_rate,
432
- n_layers,
433
- p_dropout=p_dropout,
434
- gin_channels=gin_channels,
435
- )
436
- self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
- self.post.weight.data.zero_()
438
- self.post.bias.data.zero_()
439
-
440
- def forward(self, x, x_mask, g=None, reverse=False):
441
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
- h = self.pre(x0) * x_mask
443
- h = self.enc(h, x_mask, g=g)
444
- stats = self.post(h) * x_mask
445
- if not self.mean_only:
446
- m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
- else:
448
- m = stats
449
- logs = torch.zeros_like(m)
450
-
451
- if not reverse:
452
- x1 = m + x1 * torch.exp(logs) * x_mask
453
- x = torch.cat([x0, x1], 1)
454
- logdet = torch.sum(logs, [1, 2])
455
- return x, logdet
456
- else:
457
- x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
- x = torch.cat([x0, x1], 1)
459
- return x
460
-
461
- def remove_weight_norm(self):
462
- self.enc.remove_weight_norm()
463
-
464
-
465
- class ConvFlow(nn.Module):
466
- def __init__(
467
- self,
468
- in_channels,
469
- filter_channels,
470
- kernel_size,
471
- n_layers,
472
- num_bins=10,
473
- tail_bound=5.0,
474
- ):
475
- super().__init__()
476
- self.in_channels = in_channels
477
- self.filter_channels = filter_channels
478
- self.kernel_size = kernel_size
479
- self.n_layers = n_layers
480
- self.num_bins = num_bins
481
- self.tail_bound = tail_bound
482
- self.half_channels = in_channels // 2
483
-
484
- self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485
- self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486
- self.proj = nn.Conv1d(
487
- filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488
- )
489
- self.proj.weight.data.zero_()
490
- self.proj.bias.data.zero_()
491
-
492
- def forward(self, x, x_mask, g=None, reverse=False):
493
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494
- h = self.pre(x0)
495
- h = self.convs(h, x_mask, g=g)
496
- h = self.proj(h) * x_mask
497
-
498
- b, c, t = x0.shape
499
- h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
500
-
501
- unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502
- unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503
- self.filter_channels
504
- )
505
- unnormalized_derivatives = h[..., 2 * self.num_bins :]
506
-
507
- x1, logabsdet = piecewise_rational_quadratic_transform(
508
- x1,
509
- unnormalized_widths,
510
- unnormalized_heights,
511
- unnormalized_derivatives,
512
- inverse=reverse,
513
- tails="linear",
514
- tail_bound=self.tail_bound,
515
- )
516
-
517
- x = torch.cat([x0, x1], 1) * x_mask
518
- logdet = torch.sum(logabsdet * x_mask, [1, 2])
519
- if not reverse:
520
- return x, logdet
521
- else:
522
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules/F0Predictor/DioF0Predictor.py DELETED
@@ -1,90 +0,0 @@
1
- from infer_pack.modules.F0Predictor.F0Predictor import F0Predictor
2
- import pyworld
3
- import numpy as np
4
-
5
-
6
- class DioF0Predictor(F0Predictor):
7
- def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
8
- self.hop_length = hop_length
9
- self.f0_min = f0_min
10
- self.f0_max = f0_max
11
- self.sampling_rate = sampling_rate
12
-
13
- def interpolate_f0(self, f0):
14
- """
15
- 对F0进行插值处理
16
- """
17
-
18
- data = np.reshape(f0, (f0.size, 1))
19
-
20
- vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
21
- vuv_vector[data > 0.0] = 1.0
22
- vuv_vector[data <= 0.0] = 0.0
23
-
24
- ip_data = data
25
-
26
- frame_number = data.size
27
- last_value = 0.0
28
- for i in range(frame_number):
29
- if data[i] <= 0.0:
30
- j = i + 1
31
- for j in range(i + 1, frame_number):
32
- if data[j] > 0.0:
33
- break
34
- if j < frame_number - 1:
35
- if last_value > 0.0:
36
- step = (data[j] - data[i - 1]) / float(j - i)
37
- for k in range(i, j):
38
- ip_data[k] = data[i - 1] + step * (k - i + 1)
39
- else:
40
- for k in range(i, j):
41
- ip_data[k] = data[j]
42
- else:
43
- for k in range(i, frame_number):
44
- ip_data[k] = last_value
45
- else:
46
- ip_data[i] = data[i] # 这里可能存在一个没有必要的拷贝
47
- last_value = data[i]
48
-
49
- return ip_data[:, 0], vuv_vector[:, 0]
50
-
51
- def resize_f0(self, x, target_len):
52
- source = np.array(x)
53
- source[source < 0.001] = np.nan
54
- target = np.interp(
55
- np.arange(0, len(source) * target_len, len(source)) / target_len,
56
- np.arange(0, len(source)),
57
- source,
58
- )
59
- res = np.nan_to_num(target)
60
- return res
61
-
62
- def compute_f0(self, wav, p_len=None):
63
- if p_len is None:
64
- p_len = wav.shape[0] // self.hop_length
65
- f0, t = pyworld.dio(
66
- wav.astype(np.double),
67
- fs=self.sampling_rate,
68
- f0_floor=self.f0_min,
69
- f0_ceil=self.f0_max,
70
- frame_period=1000 * self.hop_length / self.sampling_rate,
71
- )
72
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
73
- for index, pitch in enumerate(f0):
74
- f0[index] = round(pitch, 1)
75
- return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
76
-
77
- def compute_f0_uv(self, wav, p_len=None):
78
- if p_len is None:
79
- p_len = wav.shape[0] // self.hop_length
80
- f0, t = pyworld.dio(
81
- wav.astype(np.double),
82
- fs=self.sampling_rate,
83
- f0_floor=self.f0_min,
84
- f0_ceil=self.f0_max,
85
- frame_period=1000 * self.hop_length / self.sampling_rate,
86
- )
87
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
88
- for index, pitch in enumerate(f0):
89
- f0[index] = round(pitch, 1)
90
- return self.interpolate_f0(self.resize_f0(f0, p_len))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules/F0Predictor/F0Predictor.py DELETED
@@ -1,16 +0,0 @@
1
- class F0Predictor(object):
2
- def compute_f0(self, wav, p_len):
3
- """
4
- input: wav:[signal_length]
5
- p_len:int
6
- output: f0:[signal_length//hop_length]
7
- """
8
- pass
9
-
10
- def compute_f0_uv(self, wav, p_len):
11
- """
12
- input: wav:[signal_length]
13
- p_len:int
14
- output: f0:[signal_length//hop_length],uv:[signal_length//hop_length]
15
- """
16
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules/F0Predictor/HarvestF0Predictor.py DELETED
@@ -1,86 +0,0 @@
1
- from infer_pack.modules.F0Predictor.F0Predictor import F0Predictor
2
- import pyworld
3
- import numpy as np
4
-
5
-
6
- class HarvestF0Predictor(F0Predictor):
7
- def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
8
- self.hop_length = hop_length
9
- self.f0_min = f0_min
10
- self.f0_max = f0_max
11
- self.sampling_rate = sampling_rate
12
-
13
- def interpolate_f0(self, f0):
14
- """
15
- 对F0进行插值处理
16
- """
17
-
18
- data = np.reshape(f0, (f0.size, 1))
19
-
20
- vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
21
- vuv_vector[data > 0.0] = 1.0
22
- vuv_vector[data <= 0.0] = 0.0
23
-
24
- ip_data = data
25
-
26
- frame_number = data.size
27
- last_value = 0.0
28
- for i in range(frame_number):
29
- if data[i] <= 0.0:
30
- j = i + 1
31
- for j in range(i + 1, frame_number):
32
- if data[j] > 0.0:
33
- break
34
- if j < frame_number - 1:
35
- if last_value > 0.0:
36
- step = (data[j] - data[i - 1]) / float(j - i)
37
- for k in range(i, j):
38
- ip_data[k] = data[i - 1] + step * (k - i + 1)
39
- else:
40
- for k in range(i, j):
41
- ip_data[k] = data[j]
42
- else:
43
- for k in range(i, frame_number):
44
- ip_data[k] = last_value
45
- else:
46
- ip_data[i] = data[i] # 这里可能存在一个没有必要的拷贝
47
- last_value = data[i]
48
-
49
- return ip_data[:, 0], vuv_vector[:, 0]
50
-
51
- def resize_f0(self, x, target_len):
52
- source = np.array(x)
53
- source[source < 0.001] = np.nan
54
- target = np.interp(
55
- np.arange(0, len(source) * target_len, len(source)) / target_len,
56
- np.arange(0, len(source)),
57
- source,
58
- )
59
- res = np.nan_to_num(target)
60
- return res
61
-
62
- def compute_f0(self, wav, p_len=None):
63
- if p_len is None:
64
- p_len = wav.shape[0] // self.hop_length
65
- f0, t = pyworld.harvest(
66
- wav.astype(np.double),
67
- fs=self.hop_length,
68
- f0_ceil=self.f0_max,
69
- f0_floor=self.f0_min,
70
- frame_period=1000 * self.hop_length / self.sampling_rate,
71
- )
72
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.fs)
73
- return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
74
-
75
- def compute_f0_uv(self, wav, p_len=None):
76
- if p_len is None:
77
- p_len = wav.shape[0] // self.hop_length
78
- f0, t = pyworld.harvest(
79
- wav.astype(np.double),
80
- fs=self.sampling_rate,
81
- f0_floor=self.f0_min,
82
- f0_ceil=self.f0_max,
83
- frame_period=1000 * self.hop_length / self.sampling_rate,
84
- )
85
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
86
- return self.interpolate_f0(self.resize_f0(f0, p_len))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules/F0Predictor/PMF0Predictor.py DELETED
@@ -1,97 +0,0 @@
1
- from infer_pack.modules.F0Predictor.F0Predictor import F0Predictor
2
- import parselmouth
3
- import numpy as np
4
-
5
-
6
- class PMF0Predictor(F0Predictor):
7
- def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
8
- self.hop_length = hop_length
9
- self.f0_min = f0_min
10
- self.f0_max = f0_max
11
- self.sampling_rate = sampling_rate
12
-
13
- def interpolate_f0(self, f0):
14
- """
15
- 对F0进行插值处理
16
- """
17
-
18
- data = np.reshape(f0, (f0.size, 1))
19
-
20
- vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
21
- vuv_vector[data > 0.0] = 1.0
22
- vuv_vector[data <= 0.0] = 0.0
23
-
24
- ip_data = data
25
-
26
- frame_number = data.size
27
- last_value = 0.0
28
- for i in range(frame_number):
29
- if data[i] <= 0.0:
30
- j = i + 1
31
- for j in range(i + 1, frame_number):
32
- if data[j] > 0.0:
33
- break
34
- if j < frame_number - 1:
35
- if last_value > 0.0:
36
- step = (data[j] - data[i - 1]) / float(j - i)
37
- for k in range(i, j):
38
- ip_data[k] = data[i - 1] + step * (k - i + 1)
39
- else:
40
- for k in range(i, j):
41
- ip_data[k] = data[j]
42
- else:
43
- for k in range(i, frame_number):
44
- ip_data[k] = last_value
45
- else:
46
- ip_data[i] = data[i] # 这里可能存在一个没有必要的拷贝
47
- last_value = data[i]
48
-
49
- return ip_data[:, 0], vuv_vector[:, 0]
50
-
51
- def compute_f0(self, wav, p_len=None):
52
- x = wav
53
- if p_len is None:
54
- p_len = x.shape[0] // self.hop_length
55
- else:
56
- assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
57
- time_step = self.hop_length / self.sampling_rate * 1000
58
- f0 = (
59
- parselmouth.Sound(x, self.sampling_rate)
60
- .to_pitch_ac(
61
- time_step=time_step / 1000,
62
- voicing_threshold=0.6,
63
- pitch_floor=self.f0_min,
64
- pitch_ceiling=self.f0_max,
65
- )
66
- .selected_array["frequency"]
67
- )
68
-
69
- pad_size = (p_len - len(f0) + 1) // 2
70
- if pad_size > 0 or p_len - len(f0) - pad_size > 0:
71
- f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
72
- f0, uv = self.interpolate_f0(f0)
73
- return f0
74
-
75
- def compute_f0_uv(self, wav, p_len=None):
76
- x = wav
77
- if p_len is None:
78
- p_len = x.shape[0] // self.hop_length
79
- else:
80
- assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
81
- time_step = self.hop_length / self.sampling_rate * 1000
82
- f0 = (
83
- parselmouth.Sound(x, self.sampling_rate)
84
- .to_pitch_ac(
85
- time_step=time_step / 1000,
86
- voicing_threshold=0.6,
87
- pitch_floor=self.f0_min,
88
- pitch_ceiling=self.f0_max,
89
- )
90
- .selected_array["frequency"]
91
- )
92
-
93
- pad_size = (p_len - len(f0) + 1) // 2
94
- if pad_size > 0 or p_len - len(f0) - pad_size > 0:
95
- f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
96
- f0, uv = self.interpolate_f0(f0)
97
- return f0, uv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules/F0Predictor/__init__.py DELETED
File without changes
infer_pack/onnx_inference.py DELETED
@@ -1,142 +0,0 @@
1
- import onnxruntime
2
- import librosa
3
- import numpy as np
4
- import soundfile
5
-
6
- class ContentVec:
7
- def __init__(self, vec_path="pretrained/vec-768-layer-12.onnx", device=None):
8
- print("load model(s) from {}".format(vec_path))
9
- if device == "cpu" or device is None:
10
- providers = ["CPUExecutionProvider"]
11
- elif device == "cuda":
12
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
13
- elif device == "dml":
14
- providers = ["DmlExecutionProvider"]
15
- else:
16
- raise RuntimeError("Unsportted Device")
17
- self.model = onnxruntime.InferenceSession(vec_path, providers=providers)
18
-
19
- def __call__(self, wav):
20
- return self.forward(wav)
21
-
22
- def forward(self, wav):
23
- feats = wav
24
- if feats.ndim == 2: # double channels
25
- feats = feats.mean(-1)
26
- assert feats.ndim == 1, feats.ndim
27
- feats = np.expand_dims(np.expand_dims(feats, 0), 0)
28
- onnx_input = {self.model.get_inputs()[0].name: feats}
29
- logits = self.model.run(None, onnx_input)[0]
30
- return logits.transpose(0, 2, 1)
31
-
32
-
33
- def get_f0_predictor(f0_predictor, hop_length, sampling_rate, **kargs):
34
- if f0_predictor == "pm":
35
- from infer_pack.modules.F0Predictor.PMF0Predictor import PMF0Predictor
36
-
37
- f0_predictor_object = PMF0Predictor(
38
- hop_length=hop_length, sampling_rate=sampling_rate
39
- )
40
- elif f0_predictor == "harvest":
41
- from infer_pack.modules.F0Predictor.HarvestF0Predictor import HarvestF0Predictor
42
-
43
- f0_predictor_object = HarvestF0Predictor(
44
- hop_length=hop_length, sampling_rate=sampling_rate
45
- )
46
- elif f0_predictor == "dio":
47
- from infer_pack.modules.F0Predictor.DioF0Predictor import DioF0Predictor
48
-
49
- f0_predictor_object = DioF0Predictor(
50
- hop_length=hop_length, sampling_rate=sampling_rate
51
- )
52
- else:
53
- raise Exception("Unknown f0 predictor")
54
- return f0_predictor_object
55
-
56
-
57
- class OnnxRVC:
58
- def __init__(
59
- self,
60
- model_path,
61
- sr=40000,
62
- hop_size=512,
63
- vec_path="vec-768-layer-12",
64
- device="cpu",
65
- ):
66
- vec_path = f"pretrained/{vec_path}.onnx"
67
- self.vec_model = ContentVec(vec_path, device)
68
- if device == "cpu" or device is None:
69
- providers = ["CPUExecutionProvider"]
70
- elif device == "cuda":
71
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
72
- elif device == "dml":
73
- providers = ["DmlExecutionProvider"]
74
- else:
75
- raise RuntimeError("Unsportted Device")
76
- self.model = onnxruntime.InferenceSession(model_path, providers=providers)
77
- self.sampling_rate = sr
78
- self.hop_size = hop_size
79
-
80
- def forward(self, hubert, hubert_length, pitch, pitchf, ds, rnd):
81
- onnx_input = {
82
- self.model.get_inputs()[0].name: hubert,
83
- self.model.get_inputs()[1].name: hubert_length,
84
- self.model.get_inputs()[2].name: pitch,
85
- self.model.get_inputs()[3].name: pitchf,
86
- self.model.get_inputs()[4].name: ds,
87
- self.model.get_inputs()[5].name: rnd,
88
- }
89
- return (self.model.run(None, onnx_input)[0] * 32767).astype(np.int16)
90
-
91
- def inference(
92
- self,
93
- raw_path,
94
- sid,
95
- f0_method="dio",
96
- f0_up_key=0,
97
- pad_time=0.5,
98
- cr_threshold=0.02,
99
- ):
100
- f0_min = 50
101
- f0_max = 1100
102
- f0_mel_min = 1127 * np.log(1 + f0_min / 700)
103
- f0_mel_max = 1127 * np.log(1 + f0_max / 700)
104
- f0_predictor = get_f0_predictor(
105
- f0_method,
106
- hop_length=self.hop_size,
107
- sampling_rate=self.sampling_rate,
108
- threshold=cr_threshold,
109
- )
110
- wav, sr = librosa.load(raw_path, sr=self.sampling_rate)
111
- org_length = len(wav)
112
- if org_length / sr > 50.0:
113
- raise RuntimeError("Reached Max Length")
114
-
115
- wav16k = librosa.resample(wav, orig_sr=self.sampling_rate, target_sr=16000)
116
- wav16k = wav16k
117
-
118
- hubert = self.vec_model(wav16k)
119
- hubert = np.repeat(hubert, 2, axis=2).transpose(0, 2, 1).astype(np.float32)
120
- hubert_length = hubert.shape[1]
121
-
122
- pitchf = f0_predictor.compute_f0(wav, hubert_length)
123
- pitchf = pitchf * 2 ** (f0_up_key / 12)
124
- pitch = pitchf.copy()
125
- f0_mel = 1127 * np.log(1 + pitch / 700)
126
- f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
127
- f0_mel_max - f0_mel_min
128
- ) + 1
129
- f0_mel[f0_mel <= 1] = 1
130
- f0_mel[f0_mel > 255] = 255
131
- pitch = np.rint(f0_mel).astype(np.int64)
132
-
133
- pitchf = pitchf.reshape(1, len(pitchf)).astype(np.float32)
134
- pitch = pitch.reshape(1, len(pitch))
135
- ds = np.array([sid]).astype(np.int64)
136
-
137
- rnd = np.random.randn(1, 192, hubert_length).astype(np.float32)
138
- hubert_length = np.array([hubert_length]).astype(np.int64)
139
-
140
- out_wav = self.forward(hubert, hubert_length, pitch, pitchf, ds, rnd).squeeze()
141
- out_wav = np.pad(out_wav, (0, 2 * self.hop_size), "constant")
142
- return out_wav[0:org_length]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/transforms.py DELETED
@@ -1,209 +0,0 @@
1
- import torch
2
- from torch.nn import functional as F
3
-
4
- import numpy as np
5
-
6
-
7
- DEFAULT_MIN_BIN_WIDTH = 1e-3
8
- DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
- DEFAULT_MIN_DERIVATIVE = 1e-3
10
-
11
-
12
- def piecewise_rational_quadratic_transform(
13
- inputs,
14
- unnormalized_widths,
15
- unnormalized_heights,
16
- unnormalized_derivatives,
17
- inverse=False,
18
- tails=None,
19
- tail_bound=1.0,
20
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
- min_derivative=DEFAULT_MIN_DERIVATIVE,
23
- ):
24
- if tails is None:
25
- spline_fn = rational_quadratic_spline
26
- spline_kwargs = {}
27
- else:
28
- spline_fn = unconstrained_rational_quadratic_spline
29
- spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
-
31
- outputs, logabsdet = spline_fn(
32
- inputs=inputs,
33
- unnormalized_widths=unnormalized_widths,
34
- unnormalized_heights=unnormalized_heights,
35
- unnormalized_derivatives=unnormalized_derivatives,
36
- inverse=inverse,
37
- min_bin_width=min_bin_width,
38
- min_bin_height=min_bin_height,
39
- min_derivative=min_derivative,
40
- **spline_kwargs
41
- )
42
- return outputs, logabsdet
43
-
44
-
45
- def searchsorted(bin_locations, inputs, eps=1e-6):
46
- bin_locations[..., -1] += eps
47
- return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
-
49
-
50
- def unconstrained_rational_quadratic_spline(
51
- inputs,
52
- unnormalized_widths,
53
- unnormalized_heights,
54
- unnormalized_derivatives,
55
- inverse=False,
56
- tails="linear",
57
- tail_bound=1.0,
58
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
- min_derivative=DEFAULT_MIN_DERIVATIVE,
61
- ):
62
- inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
- outside_interval_mask = ~inside_interval_mask
64
-
65
- outputs = torch.zeros_like(inputs)
66
- logabsdet = torch.zeros_like(inputs)
67
-
68
- if tails == "linear":
69
- unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
- constant = np.log(np.exp(1 - min_derivative) - 1)
71
- unnormalized_derivatives[..., 0] = constant
72
- unnormalized_derivatives[..., -1] = constant
73
-
74
- outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
- logabsdet[outside_interval_mask] = 0
76
- else:
77
- raise RuntimeError("{} tails are not implemented.".format(tails))
78
-
79
- (
80
- outputs[inside_interval_mask],
81
- logabsdet[inside_interval_mask],
82
- ) = rational_quadratic_spline(
83
- inputs=inputs[inside_interval_mask],
84
- unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
- unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
- unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
- inverse=inverse,
88
- left=-tail_bound,
89
- right=tail_bound,
90
- bottom=-tail_bound,
91
- top=tail_bound,
92
- min_bin_width=min_bin_width,
93
- min_bin_height=min_bin_height,
94
- min_derivative=min_derivative,
95
- )
96
-
97
- return outputs, logabsdet
98
-
99
-
100
- def rational_quadratic_spline(
101
- inputs,
102
- unnormalized_widths,
103
- unnormalized_heights,
104
- unnormalized_derivatives,
105
- inverse=False,
106
- left=0.0,
107
- right=1.0,
108
- bottom=0.0,
109
- top=1.0,
110
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
- min_derivative=DEFAULT_MIN_DERIVATIVE,
113
- ):
114
- if torch.min(inputs) < left or torch.max(inputs) > right:
115
- raise ValueError("Input to a transform is not within its domain")
116
-
117
- num_bins = unnormalized_widths.shape[-1]
118
-
119
- if min_bin_width * num_bins > 1.0:
120
- raise ValueError("Minimal bin width too large for the number of bins")
121
- if min_bin_height * num_bins > 1.0:
122
- raise ValueError("Minimal bin height too large for the number of bins")
123
-
124
- widths = F.softmax(unnormalized_widths, dim=-1)
125
- widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
- cumwidths = torch.cumsum(widths, dim=-1)
127
- cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
- cumwidths = (right - left) * cumwidths + left
129
- cumwidths[..., 0] = left
130
- cumwidths[..., -1] = right
131
- widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
-
133
- derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
-
135
- heights = F.softmax(unnormalized_heights, dim=-1)
136
- heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
- cumheights = torch.cumsum(heights, dim=-1)
138
- cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
- cumheights = (top - bottom) * cumheights + bottom
140
- cumheights[..., 0] = bottom
141
- cumheights[..., -1] = top
142
- heights = cumheights[..., 1:] - cumheights[..., :-1]
143
-
144
- if inverse:
145
- bin_idx = searchsorted(cumheights, inputs)[..., None]
146
- else:
147
- bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
-
149
- input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
- input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
-
152
- input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
- delta = heights / widths
154
- input_delta = delta.gather(-1, bin_idx)[..., 0]
155
-
156
- input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
- input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
-
159
- input_heights = heights.gather(-1, bin_idx)[..., 0]
160
-
161
- if inverse:
162
- a = (inputs - input_cumheights) * (
163
- input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
- ) + input_heights * (input_delta - input_derivatives)
165
- b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
- input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
- )
168
- c = -input_delta * (inputs - input_cumheights)
169
-
170
- discriminant = b.pow(2) - 4 * a * c
171
- assert (discriminant >= 0).all()
172
-
173
- root = (2 * c) / (-b - torch.sqrt(discriminant))
174
- outputs = root * input_bin_widths + input_cumwidths
175
-
176
- theta_one_minus_theta = root * (1 - root)
177
- denominator = input_delta + (
178
- (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
- * theta_one_minus_theta
180
- )
181
- derivative_numerator = input_delta.pow(2) * (
182
- input_derivatives_plus_one * root.pow(2)
183
- + 2 * input_delta * theta_one_minus_theta
184
- + input_derivatives * (1 - root).pow(2)
185
- )
186
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
-
188
- return outputs, -logabsdet
189
- else:
190
- theta = (inputs - input_cumwidths) / input_bin_widths
191
- theta_one_minus_theta = theta * (1 - theta)
192
-
193
- numerator = input_heights * (
194
- input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
- )
196
- denominator = input_delta + (
197
- (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
- * theta_one_minus_theta
199
- )
200
- outputs = input_cumheights + numerator / denominator
201
-
202
- derivative_numerator = input_delta.pow(2) * (
203
- input_derivatives_plus_one * theta.pow(2)
204
- + 2 * input_delta * theta_one_minus_theta
205
- + input_derivatives * (1 - theta).pow(2)
206
- )
207
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
-
209
- return outputs, logabsdet