is this from tortoise?
Browse files- api.py +20 -1
- is_this_from_tortoise.py +14 -0
- models/classifier.py +14 -9
api.py
CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|
8 |
import progressbar
|
9 |
import torchaudio
|
10 |
|
|
|
11 |
from models.cvvp import CVVP
|
12 |
from models.diffusion_decoder import DiffusionTts
|
13 |
from models.autoregressive import UnifiedVoice
|
@@ -24,7 +25,7 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance
|
|
24 |
pbar = None
|
25 |
|
26 |
|
27 |
-
def download_models():
|
28 |
"""
|
29 |
Call to download all the models that Tortoise uses.
|
30 |
"""
|
@@ -49,6 +50,8 @@ def download_models():
|
|
49 |
pbar.finish()
|
50 |
pbar = None
|
51 |
for model_name, url in MODELS.items():
|
|
|
|
|
52 |
if os.path.exists(f'.models/{model_name}'):
|
53 |
continue
|
54 |
print(f'Downloading {model_name} from {url}...')
|
@@ -144,6 +147,22 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_sa
|
|
144 |
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
145 |
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
class TextToSpeech:
|
148 |
"""
|
149 |
Main entry point into Tortoise.
|
|
|
8 |
import progressbar
|
9 |
import torchaudio
|
10 |
|
11 |
+
from models.classifier import AudioMiniEncoderWithClassifierHead
|
12 |
from models.cvvp import CVVP
|
13 |
from models.diffusion_decoder import DiffusionTts
|
14 |
from models.autoregressive import UnifiedVoice
|
|
|
25 |
pbar = None
|
26 |
|
27 |
|
28 |
+
def download_models(specific_models=None):
|
29 |
"""
|
30 |
Call to download all the models that Tortoise uses.
|
31 |
"""
|
|
|
50 |
pbar.finish()
|
51 |
pbar = None
|
52 |
for model_name, url in MODELS.items():
|
53 |
+
if specific_models is not None and model_name not in specific_models:
|
54 |
+
continue
|
55 |
if os.path.exists(f'.models/{model_name}'):
|
56 |
continue
|
57 |
print(f'Downloading {model_name} from {url}...')
|
|
|
147 |
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
148 |
|
149 |
|
150 |
+
def classify_audio_clip(clip):
|
151 |
+
"""
|
152 |
+
Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
|
153 |
+
:param clip: torch tensor containing audio waveform data (get it from load_audio)
|
154 |
+
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
|
155 |
+
"""
|
156 |
+
download_models(['classifier'])
|
157 |
+
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
158 |
+
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
159 |
+
dropout=0, kernel_size=5, distribute_zero_label=False)
|
160 |
+
classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
|
161 |
+
clip = clip.cpu().unsqueeze(0)
|
162 |
+
results = F.softmax(classifier(clip), dim=-1)
|
163 |
+
return results[0][0]
|
164 |
+
|
165 |
+
|
166 |
class TextToSpeech:
|
167 |
"""
|
168 |
Main entry point into Tortoise.
|
is_this_from_tortoise.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
from api import classify_audio_clip
|
4 |
+
from utils.audio import load_audio
|
5 |
+
|
6 |
+
if __name__ == '__main__':
|
7 |
+
parser = argparse.ArgumentParser()
|
8 |
+
parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="results/favorite_riding_hood.mp3")
|
9 |
+
args = parser.parse_args()
|
10 |
+
|
11 |
+
clip = load_audio(args.clip, 24000)
|
12 |
+
clip = clip[:, :220000]
|
13 |
+
prob = classify_audio_clip(clip)
|
14 |
+
print(f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise.")
|
models/classifier.py
CHANGED
@@ -1,4 +1,9 @@
|
|
1 |
import torch
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
class ResBlock(nn.Module):
|
@@ -27,7 +32,7 @@ class ResBlock(nn.Module):
|
|
27 |
self.in_layers = nn.Sequential(
|
28 |
normalization(channels),
|
29 |
nn.SiLU(),
|
30 |
-
|
31 |
)
|
32 |
|
33 |
self.updown = up or down
|
@@ -46,18 +51,18 @@ class ResBlock(nn.Module):
|
|
46 |
nn.SiLU(),
|
47 |
nn.Dropout(p=dropout),
|
48 |
zero_module(
|
49 |
-
|
50 |
),
|
51 |
)
|
52 |
|
53 |
if self.out_channels == channels:
|
54 |
self.skip_connection = nn.Identity()
|
55 |
elif use_conv:
|
56 |
-
self.skip_connection =
|
57 |
dims, channels, self.out_channels, kernel_size, padding=padding
|
58 |
)
|
59 |
else:
|
60 |
-
self.skip_connection =
|
61 |
|
62 |
def forward(self, x):
|
63 |
if self.do_checkpoint:
|
@@ -94,21 +99,21 @@ class AudioMiniEncoder(nn.Module):
|
|
94 |
kernel_size=3):
|
95 |
super().__init__()
|
96 |
self.init = nn.Sequential(
|
97 |
-
|
98 |
)
|
99 |
ch = base_channels
|
100 |
res = []
|
101 |
self.layers = depth
|
102 |
for l in range(depth):
|
103 |
for r in range(resnet_blocks):
|
104 |
-
res.append(ResBlock(ch, dropout,
|
105 |
-
res.append(Downsample(ch, use_conv=True,
|
106 |
ch *= 2
|
107 |
self.res = nn.Sequential(*res)
|
108 |
self.final = nn.Sequential(
|
109 |
normalization(ch),
|
110 |
nn.SiLU(),
|
111 |
-
|
112 |
)
|
113 |
attn = []
|
114 |
for a in range(attn_blocks):
|
@@ -118,7 +123,7 @@ class AudioMiniEncoder(nn.Module):
|
|
118 |
|
119 |
def forward(self, x):
|
120 |
h = self.init(x)
|
121 |
-
h =
|
122 |
h = self.final(h)
|
123 |
for blk in self.attn:
|
124 |
h = checkpoint(blk, h)
|
|
|
1 |
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.utils.checkpoint import checkpoint
|
5 |
+
|
6 |
+
from models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
|
7 |
|
8 |
|
9 |
class ResBlock(nn.Module):
|
|
|
32 |
self.in_layers = nn.Sequential(
|
33 |
normalization(channels),
|
34 |
nn.SiLU(),
|
35 |
+
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
|
36 |
)
|
37 |
|
38 |
self.updown = up or down
|
|
|
51 |
nn.SiLU(),
|
52 |
nn.Dropout(p=dropout),
|
53 |
zero_module(
|
54 |
+
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
|
55 |
),
|
56 |
)
|
57 |
|
58 |
if self.out_channels == channels:
|
59 |
self.skip_connection = nn.Identity()
|
60 |
elif use_conv:
|
61 |
+
self.skip_connection = nn.Conv1d(
|
62 |
dims, channels, self.out_channels, kernel_size, padding=padding
|
63 |
)
|
64 |
else:
|
65 |
+
self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
|
66 |
|
67 |
def forward(self, x):
|
68 |
if self.do_checkpoint:
|
|
|
99 |
kernel_size=3):
|
100 |
super().__init__()
|
101 |
self.init = nn.Sequential(
|
102 |
+
nn.Conv1d(spec_dim, base_channels, 3, padding=1)
|
103 |
)
|
104 |
ch = base_channels
|
105 |
res = []
|
106 |
self.layers = depth
|
107 |
for l in range(depth):
|
108 |
for r in range(resnet_blocks):
|
109 |
+
res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
|
110 |
+
res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
|
111 |
ch *= 2
|
112 |
self.res = nn.Sequential(*res)
|
113 |
self.final = nn.Sequential(
|
114 |
normalization(ch),
|
115 |
nn.SiLU(),
|
116 |
+
nn.Conv1d(ch, embedding_dim, 1)
|
117 |
)
|
118 |
attn = []
|
119 |
for a in range(attn_blocks):
|
|
|
123 |
|
124 |
def forward(self, x):
|
125 |
h = self.init(x)
|
126 |
+
h = self.res(h)
|
127 |
h = self.final(h)
|
128 |
for blk in self.attn:
|
129 |
h = checkpoint(blk, h)
|