#!/usr/bin/env python3
# tts_cli.py
"""
Example CLI for generating audio with Kokoro-StyleTTS2.

Usage:
    python tts_cli.py \
        --model /path/to/kokoro-v0_19.pth \
        --config /path/to/config.json \
        --text "Hello, my stinking friends from 1906! You stink." \
        --voicepack /path/to/af.pt \
        --output output.wav

Make sure:
  1. `models.py` is in the same folder (with `build_model`, `Decoder`, etc.).
  2. You have installed the needed libraries:
         pip install torch phonemizer munch soundfile pyyaml
  3. The model is a checkpoint that your `build_model` can load.

Adapt as needed!
"""

import argparse
import os
import re

import torch
import soundfile as sf
import numpy as np
from phonemizer import backend as phonemizer_backend

# If you use the eSpeak library:
try:
    from espeak_util import set_espeak_library

    set_espeak_library()
except ImportError:
    pass

# --------------------------------------------------------------------
# Import from your local `models.py` (requires that file to be present).
# This example assumes `build_model` loads the entire TTS submodules
# (bert, bert_encoder, predictor, decoder, text_encoder).
# --------------------------------------------------------------------
from models import build_model


def resplit_strings(arr):
    """
    Given a list of string tokens (e.g. words, phrases), try to split them
    into two sub-lists whose total character lengths are as balanced as
    possible. The goal is to cut a long string roughly in half without
    splitting in the middle of a word.

    Example: resplit_strings(["Hello", "there", "world"]) returns
    ("Hello", "there world").
    """
    if not arr:
        return "", ""
    if len(arr) == 1:
        return arr[0], ""

    min_diff = float("inf")
    best_split = 0
    lengths = [len(s) for s in arr]
    spaces = len(arr) - 1
    left_len = 0
    right_len = sum(lengths) + spaces

    for i in range(1, len(arr)):
        # Move the current word (plus a joining space) to the left side...
        left_len += lengths[i - 1] + (1 if i > 1 else 0)
        # ...and remove it from the right side
        right_len -= lengths[i - 1] + 1
        diff = abs(left_len - right_len)
        if diff < min_diff:
            min_diff = diff
            best_split = i

    return " ".join(arr[:best_split]), " ".join(arr[best_split:])


def recursive_split(text, lang="a"):
    """
    Split a piece of text into smaller segments so that each segment's
    phoneme sequence stays under the model's ~512-token limit.
    """
    # Phonemize and tokenize first; if the result already fits, return it
    # as a single (text, phonemes) chunk. Otherwise split on punctuation
    # (or whitespace as a fallback) and recurse.
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    tokens = tokenize(ps)
    if len(tokens) < 512:
        return [(text, ps)]

    # Prefer punctuation that usually marks sentence or clause boundaries;
    # if none is found, fall back to a plain whitespace split.
    for punctuation in [r"[.?!…]", r"[:,;—]"]:
        pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
        # Attempt to split on that punctuation
        splits = re.split(pattern, text)
        if len(splits) > 1:
            break
    else:
        # If we didn't break out, just do a whitespace split
        splits = text.split(" ")

    # Use resplit_strings to cut the pieces roughly in half, then recurse
    left, right = resplit_strings(splits)
    return recursive_split(left, lang=lang) + recursive_split(right, lang=lang)


def segment_and_tokenize(long_text, lang="a"):
    """
    Take a large text, optionally normalize or clean it, then break it into
    a list of (segment_text, segment_phonemes) tuples.
    """
    # Additional cleaning if you want:
    # long_text = normalize_text(long_text)  # your existing function

    # Chunk it up using recursive_split
    segments = recursive_split(long_text, lang=lang)
    return segments


# -------------- Normalization & Phonemization Routines -------------- #


def parens_to_angles(s):
    return s.replace("(", "«").replace(")", "»")


def split_num(num):
    num = num.group()
    if "." in num:
        return num
    elif ":" in num:
        h, m = [int(n) for n in num.split(":")]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f"{h} oh {m}"
        return f"{h} {m}"
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = "s" if num.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f"{left} hundred{s}"
        elif right < 10:
            return f"{left} oh {right}{s}"
    return f"{left} {right}{s}"


def flip_money(m):
    m = m.group()
    bill = "dollar" if m[0] == "$" else "pound"
    if m[-1].isalpha():
        return f"{m[1:]} {bill}s"
    elif "." not in m:
        s = "" if m[1:] == "1" else "s"
        return f"{m[1:]} {bill}{s}"
    b, c = m[1:].split(".")
    s = "" if b == "1" else "s"
    c = int(c.ljust(2, "0"))
    coins = (
        f"cent{'' if c == 1 else 's'}"
        if m[0] == "$"
        else ("penny" if c == 1 else "pence")
    )
    return f"{b} {bill}{s} and {c} {coins}"


def point_num(num):
    a, b = num.group().split(".")
    return " point ".join([a, " ".join(b)])


def normalize_text(text):
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = parens_to_angles(text)
    # Replace some common full-width punctuation in CJK:
    for a, b in zip("、。！，：；？", ",.!,:;?"):
        text = text.replace(a, b + " ")
    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
    # Numbers, years, and times (following the reference Kokoro normalization):
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
        split_num,
        text,
    )
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    # Money amounts:
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b"
        r"|[$£]\d+\.\d\d?\b",
        flip_money,
        text,
    )
    # Decimals, ranges, possessives, initialisms:
    text = re.sub(r"\d*\.\d+", point_num, text)
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
    return text.strip()


# -------------- Vocabulary, Tokenizer, Phonemizer -------------- #


def get_vocab():
    # The symbol inventory must match the one the checkpoint was trained with;
    # this is the set used by the Kokoro reference code.
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    _letters_ipa = (
        "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    )
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    return {s: i for i, s in enumerate(symbols)}


VOCAB = get_vocab()


def tokenize(ps):
    """Map a phoneme string to integer token ids, dropping unknown symbols."""
    return [i for i in map(VOCAB.get, ps) if i is not None]


# Phonemizer backends: 'a' = American English, 'b' = British English.
phonemizers = dict(
    a=phonemizer_backend.EspeakBackend(
        language="en-us", preserve_punctuation=True, with_stress=True
    ),
    b=phonemizer_backend.EspeakBackend(
        language="en-gb", preserve_punctuation=True, with_stress=True
    ),
)


def phonemize_text(text, lang="a", do_normalize=True):
    if do_normalize:
        text = normalize_text(text)
    ps_list = phonemizers[lang].phonemize([text])
    ps = ps_list[0] if ps_list else ""

    # A few fix-ups so the phonemes match what the model saw in training
    ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
    # Make sure "hundred" is separated from the preceding word,
    # e.g. "900" -> "nˈaɪn hˈʌndɹɪd"
    ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
    # "z" at the end of a word -> remove the space before it (just your snippet)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
    # If lang is 'a', handle "ninety" => "ninedi"? Just from your snippet:
    if lang == "a":
        ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
    # Only keep valid symbols
    ps = "".join(p for p in ps if p in VOCAB)
    return ps.strip()


# -------------------------------------------------------------------
# Utility for generating text masks
# -------------------------------------------------------------------
def length_to_mask(lengths):
    # lengths is a Tensor of shape [B], containing the text length for each batch item.
    # Returns a [B, max_len] boolean mask where True marks padded positions, e.g.
    # length_to_mask(torch.LongTensor([3, 5]))[0] == [False, False, False, True, True].
    max_len = lengths.max()
    row_ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    mask = row_ids.expand(lengths.shape[0], -1)
    return (mask + 1) > lengths.unsqueeze(1)
# -------------------------------------------------------------------
# The forward pass for inference (from your snippet).
# This version references `model.predictor`, `model.decoder`, etc.
# -------------------------------------------------------------------
@torch.no_grad()
def forward_tts(model, tokens, ref_s, speed=1.0):
    """
    model:  Munch with submodels: bert, bert_encoder, predictor, decoder, text_encoder
    tokens: list[int], the tokenized input (without the [0, ..., 0] boundary tokens yet)
    ref_s:  reference embedding (torch.Tensor)
    speed:  float, speed factor
    """
    device = ref_s.device
    tokens_t = torch.LongTensor([[0, *tokens, 0]]).to(device)  # add boundary tokens
    input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)

    # 1. Encode with BERT
    bert_dur = model.bert(tokens_t, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    # 2. Prosody predictor
    # ref_s is split in half: the second 128 dims condition the prosody
    # predictor here, the first 128 dims go to the decoder in step 6.
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()

    # 3. Expand the alignment to frame level
    total_len = pred_dur.sum().item()
    pred_aln_trg = torch.zeros(input_lengths.item(), total_len, device=device)
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        n = pred_dur[0, i].item()
        pred_aln_trg[i, c_frame : c_frame + n] = 1
        c_frame += n

    # 4. Run F0 + Noise predictor
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

    # 5. Text encoder -> asr features
    t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0)

    # 6. Decode audio
    audio = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])  # B x audio_len
    return audio.squeeze().cpu().numpy()


def generate_tts(model, text, voicepack, lang="a", speed=1.0):
    """
    model:     the Munch returned by build_model(...)
    text:      the input text (string)
    voicepack: the torch Tensor reference embedding(s), or a dict of them
    lang:      'a' or 'b' etc., matching your phonemizers
    speed:     speech speed factor
    """
    # 1. Phonemize
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    tokens = tokenize(ps)
    if not tokens:
        return None, ps

    # 2. Retrieve the reference style.
    # If your voicepack is indexed by token count (e.g. `voicepack[len(tokens)]`),
    # this picks the matching embedding; if you have a single embedding for all
    # lengths or multiple voices, adapt as needed.
    try:
        ref_s = voicepack[len(tokens)]
    except Exception:
        # fallback if len(tokens) is out of range
        ref_s = voicepack[-1]
    ref_s = ref_s.to(next(model.bert.parameters()).device)

    # 3. Generate
    audio = forward_tts(model, tokens, ref_s, speed=speed)
    return audio, ps
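# Minimal usage sketch (illustrative only): assuming `model` and `voicepack`
# have been loaded the same way main() loads them below, a single utterance
# could be synthesized and saved like this:
#
#     audio, phonemes = generate_tts(model, "Hello there!", voicepack, lang="a")
#     sf.write("hello.wav", audio, 24000)  # Kokoro v0.19 outputs 24 kHz audio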
def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
    """
    Generate TTS for a long `full_text` by splitting it into smaller segments
    and concatenating the resulting audio.

    Returns: (final_audio, all_phonemes), where final_audio is a float32
    numpy array and all_phonemes lists the phonemes used for each segment.
    """
    # 1. Segment the text into (seg_text, seg_phonemes) pairs
    segments = segment_and_tokenize(full_text, lang=lang)

    # 2. Synthesize each segment with generate_tts(...)
    audio_chunks = []
    all_phonemes = []
    for i, (seg_text, seg_ps) in enumerate(segments, 1):
        print(f"[LongForm] Generating chunk {i}/{len(segments)}: {seg_text[:40]}...")
        audio, used_phonemes = generate_tts(
            model, seg_text, voicepack, lang=lang, speed=speed
        )
        if audio is not None:
            audio_chunks.append(audio)
            all_phonemes.append(used_phonemes)
        else:
            print(f"[LongForm] Skipped empty segment {i}...")

    if not audio_chunks:
        return None, []

    # 3. Concatenate the audio
    final_audio = np.concatenate(audio_chunks, axis=0)
    return final_audio, all_phonemes


# -------------------------------------------------------------------
# Main CLI
# -------------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
    parser.add_argument(
        "--model",
        type=str,
        default="pretrained_models/Kokoro/kokoro-v0_19.pth",
        help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="pretrained_models/Kokoro/config.json",
        help="Path to config.json (used by build_model).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=(
            "Hello world! This is Kokoro, a new text-to-speech model "
            "based on StyleTTS2 from 2024!"
        ),
        help="Text to be converted into speech.",
    )
    parser.add_argument(
        "--voicepack",
        type=str,
        default="pretrained_models/Kokoro/voices/af.pt",
        help="Path to a .pt file for your reference embedding(s).",
    )
    parser.add_argument(
        "--output", type=str, default="output.wav", help="Output WAV filename."
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed factor, e.g. 0.8 is slower, 1.2 is faster.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run inference on.",
    )
    args = parser.parse_args()

    # 1. Build the model using your local build_model function
    #    (which loads TextEncoder, Decoder, etc. and returns a Munch).
    if not os.path.isfile(args.config):
        raise FileNotFoundError(f"config.json not found: {args.config}")

    # Optionally load the config as a Munch here if your build_model needs it;
    # your snippet simply does `model = build_model(path, device)`, so we do the same.
    device = (
        args.device
        if (args.device == "cuda" and torch.cuda.is_available())
        else "cpu"
    )

    print(f"Loading model from: {args.model}")
    model = build_model(args.model, device)  # `args.model` is the checkpoint path

    # Because `build_model` returns a Munch (dict of submodules),
    # we can't just call `model.eval()`; set each submodule to eval instead:
    for k, subm in model.items():
        if isinstance(subm, torch.nn.Module):
            subm.eval()

    # 2. Load the voicepack
    if not os.path.isfile(args.voicepack):
        raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")
    print(f"Loading voicepack from: {args.voicepack}")
    vp = torch.load(args.voicepack, map_location=device)
    # If your voicepack is an nn.Module, set it to eval as well
    if isinstance(vp, torch.nn.Module):
        vp.eval()

    # 3. Generate audio
    print(f"Generating speech for text: {args.text}")
    audio, phonemes = generate_long_form_tts(
        model, args.text, vp, lang="a", speed=args.speed
    )
    if audio is None:
        print("No tokens were generated (maybe empty text?). Exiting.")
        return

    # 4. Write the WAV (Kokoro v0.19 generates audio at 24 kHz)
    print(f"Writing output to: {args.output}")
    sf.write(args.output, audio, 24000)
    print("Finished!")
    print(f"Phonemes used: {phonemes}")


if __name__ == "__main__":
    main()
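# Example invocation (paths are the argparse defaults above; adjust to your setup):
#
#     python tts_cli.py \
#         --model pretrained_models/Kokoro/kokoro-v0_19.pth \
#         --config pretrained_models/Kokoro/config.json \
#         --voicepack pretrained_models/Kokoro/voices/af.pt \
#         --text "Kokoro can read long passages by splitting them into chunks." \
#         --output long_form.wav --speed 1.0 --device cpu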