Fix bug with k>1
Browse files- tortoise/api.py +2 -1
tortoise/api.py
CHANGED
@@ -416,7 +416,8 @@ class TextToSpeech:
|
|
416 |
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
417 |
# results, but will increase memory usage.
|
418 |
self.autoregressive = self.autoregressive.cuda()
|
419 |
-
best_latents = self.autoregressive(auto_conditioning,
|
|
|
420 |
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
421 |
return_latent=True, clip_inputs=False)
|
422 |
self.autoregressive = self.autoregressive.cpu()
|
|
|
416 |
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
417 |
# results, but will increase memory usage.
|
418 |
self.autoregressive = self.autoregressive.cuda()
|
419 |
+
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
420 |
+
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
421 |
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
422 |
return_latent=True, clip_inputs=False)
|
423 |
self.autoregressive = self.autoregressive.cpu()
|