Spaces:
Runtime error
Runtime error
File size: 6,278 Bytes
5bbd2a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
backend_version = "2.2.3 240316"
# Announce which backend build is starting.
print("Backend version:", backend_version)
# Put the working directory and its GPT_SoVITS sub-directory on the import
# path so the project modules imported later in this file can be found.
import os
import sys

now_dir = os.getcwd()
sys.path.extend([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
import soundfile as sf
from flask import Flask, request, Response, jsonify, stream_with_context,send_file
from flask_httpauth import HTTPBasicAuth
from flask_cors import CORS
import io
import urllib.parse
import tempfile
import hashlib, json
# Also make modules that live next to this file importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Settings are read from config.json in the parent of this file's directory.
# abspath() keeps that correct when __file__ is a bare relative name, where
# dirname(dirname(...)) would otherwise collapse to "" (i.e. the CWD).
config_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.json"
)

# Bind every setting up front so the rest of the module never hits a
# NameError when config.json is absent; these match the .get() fallbacks.
tts_port = 5000
default_batch_size = 1
default_word_count = 50
enable_auth = False
is_classic = False
USERS = {}


def config_flag(config, key):
    """Read a boolean option stored either as a JSON bool or a "true"/"false" string.

    Missing keys count as False. str() is applied first because a JSON
    ``true`` loads as the Python bool True, which has no ``.lower()``.
    """
    return str(config.get(key, "false")).lower() == "true"


if os.path.exists(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        _config = json.load(f)
    tts_port = _config.get("tts_port", tts_port)
    default_batch_size = _config.get("batch_size", default_batch_size)
    default_word_count = _config.get("max_word_count", default_word_count)
    enable_auth = config_flag(_config, "enable_auth")
    is_classic = config_flag(_config, "classic_inference")
    if enable_auth:
        print("启用了身份验证")
        USERS = _config.get("user", {})
# Use the batched TTS backend when its package is importable; otherwise
# force the classic inference path.
try:
    from TTS_infer_pack.TTS import TTS
except ImportError:
    is_classic = True

if is_classic:
    from classic_inference.classic_load_infer_info import (
        load_character,
        character_name,
        get_wav_from_text_api,
        models_path,
        update_character_info,
    )
else:
    from load_infer_info import (
        load_character,
        character_name,
        get_wav_from_text_api,
        models_path,
        update_character_info,
    )
app = Flask(__name__)
# Accept cross-origin requests from any origin on every route.
CORS(app, resources={"/*": {"origins": "*"}})
# Cache of generated audio: request hash -> path of the temporary file.
temp_files = dict()
# Used to avoid regenerating audio for repeated identical requests.
def generate_file_hash(*args):
    """Return an MD5 hex digest that uniquely identifies a request's parameters.

    Each argument is converted with ``str()`` and fed to the hash as
    ``<length>:<text>``. The length prefix keeps argument boundaries
    unambiguous; plain concatenation made e.g. ``(1, 10.8)`` and ``(11, 0.8)``
    hash identically, so a different request could hit the same cache entry.
    MD5 serves only as a cache key here, not as a security measure.

    Returns:
        A 32-character lowercase hexadecimal string.
    """
    hash_object = hashlib.md5()
    for arg in args:
        text = str(arg)
        hash_object.update(f"{len(text)}:{text}".encode("utf-8"))
    return hash_object.hexdigest()
# HTTP Basic authentication guard; credentials are checked by the callback
# registered with @auth.verify_password. CORS is configured once, right after
# the app is created, so it is not registered a second time here (a second
# CORS(app, ...) call would add a duplicate after-request hook).
auth = HTTPBasicAuth()
@auth.verify_password
def verify_password(username, password):
    """Validate HTTP Basic credentials for protected routes.

    Args:
        username: user name sent by the client.
        password: password sent by the client.

    Returns:
        True when authentication is disabled (``enable_auth`` is false), or when
        ``username`` exists in ``USERS`` and its stored password matches;
        False otherwise.
    """
    import hmac

    if not enable_auth:
        return True  # authentication disabled: allow every request
    expected = USERS.get(username)
    if expected is None:
        return False
    # compare_digest runs in time independent of where the strings differ,
    # so response timing does not reveal how much of the password matched.
    # Compare as UTF-8 bytes: compare_digest rejects non-ASCII str inputs.
    return hmac.compare_digest(
        str(expected).encode("utf-8"), str(password).encode("utf-8")
    )
@app.route('/character_list', methods=['GET'])
@auth.login_required
def character_list():
    """Respond with the 'characters_and_emotions' entry of update_character_info() as JSON."""
    info = update_character_info()
    return jsonify(info['characters_and_emotions'])
def _parse_flag(value):
    """Interpret a request parameter as a boolean.

    'true', '1', 't', 'y' and 'yes' (case-insensitive) are True; anything else is False.
    """
    return str(value).lower() in ('true', '1', 't', 'y', 'yes')


@app.route('/tts', methods=['GET', 'POST'])
@auth.login_required
def tts():
    """Synthesize ``text`` with the selected character and return the audio.

    Parameters come from the JSON body when the request is JSON, otherwise
    from the query string: text, cha_name, text_language, batch_size, speed,
    top_k, top_p, temperature, seed, stream, save_temp, cut_method,
    character_emotion and format ('wav', 'mp3' or 'ogg', case-insensitive).

    Every parameter is validated before a different character is loaded, so a
    malformed request never switches the active character.

    Returns:
        A 400 JSON error for invalid input. Otherwise, when ``stream`` is false,
        the first (sampling_rate, audio) pair yielded by get_wav_from_text_api
        encoded as ``format`` (cached in a temporary file per request when
        ``save_temp`` is set); when ``stream`` is true, the generator output
        itself as a streamed response.
    """
    global character_name

    data = request.json if request.is_json else request.args
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    text = data.get('text', '')
    if not isinstance(text, str):
        return jsonify({"error": "Invalid parameter 'text'. It must be a string."}), 400
    # NOTE(review): request.args values are already URL-decoded by Flask, so
    # this decodes a second time; kept unchanged for existing clients.
    text = urllib.parse.unquote(text)

    text_language = str(data.get('text_language', '多语种混合')).lower()
    try:
        batch_size = int(data.get('batch_size', default_batch_size))
        speed_factor = float(data.get('speed', 1.0))
        top_k = int(data.get('top_k', 6))
        top_p = float(data.get('top_p', 0.8))
        temperature = float(data.get('temperature', 0.8))
        seed = int(data.get('seed', -1))
    except (TypeError, ValueError):
        # TypeError covers JSON null values, which int()/float() reject.
        return jsonify({"error": "Invalid parameters. They must be numbers."}), 400

    stream = _parse_flag(data.get('stream', 'False'))
    save_temp = _parse_flag(data.get('save_temp', 'False'))
    cut_method = str(data.get('cut_method', 'auto_cut')).lower()
    if cut_method == "auto_cut":
        # Expand to auto_cut_<N>, N being the configured max_word_count.
        cut_method = f"auto_cut_{default_word_count}"
    character_emotion = data.get('character_emotion', 'default')

    audio_format = str(data.get('format', 'wav')).lower()
    if audio_format not in ('wav', 'mp3', 'ogg'):
        return jsonify({"error": "Invalid format. It must be one of 'wav', 'mp3', or 'ogg'."}), 400

    cha_name = data.get('cha_name', None)
    if cha_name:
        cha_name = str(cha_name)
        # Only a single directory name is allowed, so a request cannot point
        # the loader outside models_path (e.g. "../.." or an absolute path).
        if os.path.basename(cha_name) != cha_name or cha_name in ('.', '..'):
            return jsonify({"error": f"Invalid character name: {cha_name}"}), 400
        expected_path = os.path.join(models_path, cha_name)
        if not os.path.exists(expected_path):
            return jsonify({"error": f"Directory {expected_path} does not exist. Using the current character."}), 400
        if cha_name != character_name:
            character_name = cha_name
            print(f"Loading character {character_name}")
            load_character(character_name)

    params = {
        "text": text,
        "text_language": text_language,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "character_emotion": character_emotion,
        "cut_method": cut_method,
        "stream": stream,
    }
    # Only the non-classic backend accepts these extra parameters.
    if not is_classic:
        params["batch_size"] = batch_size
        params["speed_factor"] = speed_factor
        params["seed"] = seed

    if stream:
        # NOTE(review): the stream is always labelled audio/wav regardless of
        # the requested format — confirm what the generator yields here.
        gen = get_wav_from_text_api(**params)
        return Response(stream_with_context(gen), mimetype='audio/wav')

    mimetype = f'audio/{audio_format}'
    # The cache key must include every option that changes the produced file;
    # format, speed and cut_method were previously missing, so e.g. a cached
    # wav could be served for an mp3 request.
    request_hash = generate_file_hash(
        text, text_language, top_k, top_p, temperature, character_emotion,
        character_name, seed, cut_method, speed_factor, audio_format,
    )
    if save_temp:
        cached_path = temp_files.get(request_hash)
        # Regenerate if the cached file was removed from disk meanwhile.
        if cached_path and os.path.exists(cached_path):
            return send_file(cached_path, mimetype=mimetype)

    sampling_rate, audio_data = next(get_wav_from_text_api(**params))

    if not save_temp:
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sampling_rate, format=audio_format)
        # Send the bytes directly (gives a Content-Length) instead of letting
        # the response iterate the BytesIO line by line.
        return Response(buffer.getvalue(), mimetype=mimetype)

    # mkstemp creates the file atomically; tempfile.mktemp only returns a
    # name and is deprecated because another process could claim it first.
    fd, temp_file_path = tempfile.mkstemp(suffix=f'.{audio_format}')
    try:
        with os.fdopen(fd, 'wb') as temp_file:
            sf.write(temp_file, audio_data, sampling_rate, format=audio_format)
    except Exception:
        # Never leave (or cache) a half-written file behind.
        os.remove(temp_file_path)
        raise
    temp_files[request_hash] = temp_file_path
    return send_file(temp_file_path, mimetype=mimetype)
if __name__ == '__main__':
    # Serve on every network interface, on the port from the configuration.
    app.run(port=tts_port, host='0.0.0.0')
|