Spaces:
Runtime error
Runtime error
# StyleTTS 2 HTTP Streaming API by @fakerybakery - Copyright (c) 2023 mrfakename. All rights reserved. | |
# Docs: API_DOCS.md | |
# To-Do: | |
# * Support voice cloning | |
# * Implement authentication, user "credits" system w/ SQLite3 | |
import io | |
import markdown | |
from tortoise.utils.text import split_and_recombine_text | |
from flask import Flask, Response, request, jsonify | |
import numpy as np | |
import ljinference | |
import torch | |
import hashlib | |
from scipy.io.wavfile import read, write | |
from flask_cors import CORS | |
import os | |
import torchaudio | |
def genHeader(sampleRate, bitsPerSample, channels, datasize=2000 * 10**6):
    """Build a 44-byte canonical PCM WAV (RIFF) header.

    By default the header advertises a large placeholder payload size
    (2,000,000,000 bytes). Presumably this lets the header be sent before the
    total audio length is known; the file header calls this a streaming API.
    Pass the real ``datasize`` when it is known so the header is exact.

    Args:
        sampleRate: Samples per second (e.g. 24000).
        bitsPerSample: Bits per sample per channel (e.g. 16).
        channels: Number of interleaved channels.
        datasize: Size in bytes of the ``data`` chunk payload.

    Returns:
        The header as ``bytes`` (always 44 bytes long).

    Raises:
        OverflowError: If a field is negative or does not fit its fixed-width
            little-endian slot (e.g. ``datasize + 36 >= 2**32``).
    """
    byte_rate = sampleRate * channels * bitsPerSample // 8
    block_align = channels * bitsPerSample // 8
    fields = (
        b"RIFF",
        # RIFF size = "WAVE"(4) + fmt chunk(8 + 16) + data chunk header(8) + payload.
        (datasize + 36).to_bytes(4, "little"),
        b"WAVE",
        b"fmt ",
        (16).to_bytes(4, "little"),  # fmt chunk body size for plain PCM
        (1).to_bytes(2, "little"),  # audio format 1 = uncompressed PCM
        channels.to_bytes(2, "little"),
        sampleRate.to_bytes(4, "little"),
        byte_rate.to_bytes(4, "little"),
        block_align.to_bytes(2, "little"),
        bitsPerSample.to_bytes(2, "little"),
        b"data",
        datasize.to_bytes(4, "little"),
    )
    return b"".join(fields)
# Import the backend submodule explicitly instead of relying on the package's
# __init__ happening to import it; this still binds the top-level name
# ``phonemizer`` exactly as ``import phonemizer`` did.
import phonemizer.backend

# eSpeak-backed phonemizer for US English, configured to keep punctuation and
# stress marks (per the keyword arguments below).
# NOTE(review): not referenced anywhere else in the code visible here —
# presumably used outside this view, or left over; confirm before removing.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us',
    preserve_punctuation=True,
    with_stress=True,
)

print("Starting Flask app")
app = Flask(__name__)
# Enable CORS on every route of the app using flask_cors's default settings.
cors = CORS(app)
# NOTE(review): no Flask route decorator is visible on this view in this copy
# of the file; it presumably needs an ``@app.route(...)`` line to be reachable.
# Confirm the intended URL.
def index():
    """Return ``API_DOCS.md`` (resolved against the current working directory) rendered as HTML."""
    # Explicit UTF-8: without it, open() uses the locale's preferred encoding,
    # which can fail on, or mis-decode, non-ASCII characters in the docs.
    with open('API_DOCS.md', 'r', encoding='utf-8') as f:
        return markdown.markdown(f.read())
# Directory (relative to the working directory) where per-chunk synthesized
# audio is stored as WAV files named by the SHA-256 of the chunk text.
cache_dir = 'cache'
# exist_ok=True already tolerates an existing directory, so no separate
# existence check is needed. If 'cache' exists as a regular file, this raises
# FileExistsError at startup rather than letting every cache write fail later.
os.makedirs(cache_dir, exist_ok=True)
# Sample rate (Hz) used when writing both cached chunks and the response WAV.
_SAMPLE_RATE = 24000


def _request_text():
    """Return the raw ``text`` parameter of the current request, or None if absent.

    GET requests are checked in the query string; other methods are checked in
    form fields. Either way, a JSON object body is consulted as a fallback.
    ``get_json(silent=True)`` returns None instead of raising when the body is
    missing or not JSON. As a result, a form or GET request without ``text``
    gets a 400 from the caller instead of an unhandled error.
    """
    params = request.args if request.method == 'GET' else request.form
    if 'text' in params:
        return params['text']
    payload = request.get_json(silent=True)
    if isinstance(payload, dict) and 'text' in payload:
        return payload['text']
    return None


def _synthesize_chunk(chunk, noise):
    """Return audio for one text chunk, using the on-disk WAV cache under ``cache_dir``."""
    # The cache key is the chunk text only: the random ``noise`` and the fixed
    # inference settings are not part of it, so a cached chunk is returned
    # exactly as it was first generated.
    digest = hashlib.sha256(chunk.encode()).hexdigest()
    cache_path = os.path.join(cache_dir, digest + '.wav')
    if os.path.exists(cache_path):
        return read(cache_path)[1]  # scipy's read() returns (rate, data)
    audio = ljinference.inference(chunk, noise, diffusion_steps=7, embedding_scale=1)
    write(cache_path, _SAMPLE_RATE, audio)
    return audio


# NOTE(review): no Flask route decorator is visible on this view in this copy
# of the file. The GET handling implies it should accept both GET and POST;
# confirm the intended URL and restore the ``@app.route(...)`` line.
def serve_wav():
    """Synthesize speech for the request's ``text`` and return it as a WAV file.

    The text is taken from the query string (GET) or form fields, falling back
    to a JSON object body. It is stripped and split into chunks with
    ``split_and_recombine_text``. Each chunk is synthesized (or loaded from the
    cache) with one shared random noise tensor, and the chunks are concatenated.

    Returns:
        A ``audio/wav`` response sampled at 24 kHz. Returns a JSON error with
        status 400 if ``text`` is missing, is not a string, or yields no chunks.
    """
    text = _request_text()
    if text is None:
        error_response = {'error': 'Missing required fields. Please include "text" in your request.'}
        return jsonify(error_response), 400
    if not isinstance(text, str):
        return jsonify({'error': 'The "text" field must be a string.'}), 400

    texts = split_and_recombine_text(text.strip())
    if not texts:
        # np.concatenate([]) would raise; report bad input instead of a 500.
        return jsonify({'error': 'The "text" field must contain speakable text.'}), 400

    noise = torch.randn(1, 1, 256).to('cuda' if torch.cuda.is_available() else 'cpu')
    audios = [_synthesize_chunk(chunk, noise) for chunk in texts]

    output_buffer = io.BytesIO()
    write(output_buffer, _SAMPLE_RATE, np.concatenate(audios))
    response = Response(output_buffer.getvalue())
    response.headers["Content-Type"] = "audio/wav"
    return response
if __name__ == "__main__":
    # Bind to every network interface. No port is passed, so Flask's own
    # default port is used.
    app.run(host="0.0.0.0")