File size: 6,278 Bytes
5bbd2a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Report which backend build is starting.
backend_version = "2.2.3 240316"
print(f"Backend version: {backend_version}")

# Extend the import path first: the working directory and its GPT_SoVITS
# subfolder must be importable before the project modules below are loaded.
import os, sys
now_dir = os.getcwd()
sys.path.extend([now_dir, os.path.join(now_dir, "GPT_SoVITS")])

import soundfile as sf
from flask import Flask, request, Response, jsonify, stream_with_context,send_file
from flask_httpauth import HTTPBasicAuth
from flask_cors import CORS
import io
import urllib.parse
import tempfile
import hashlib, json

# Make the directory that contains this file importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Settings are read from config.json, located by applying dirname() twice to
# __file__. Every setting gets its default up front so the rest of the module
# still works when the config file is absent.
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
tts_port = 5000
default_batch_size = 1
default_word_count = 50
enable_auth = False
is_classic = False
USERS = {}

if os.path.exists(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        _config = json.load(f)
    tts_port = _config.get("tts_port", 5000)
    default_batch_size = _config.get("batch_size", 1)
    default_word_count = _config.get("max_word_count", 50)
    # str() lets the flags be written either as JSON booleans (true/false)
    # or as the strings "true"/"false" without crashing on .lower().
    enable_auth = str(_config.get("enable_auth", "false")).lower() == "true"
    is_classic = str(_config.get("classic_inference", "false")).lower() == "true"
    if enable_auth:
        print("启用了身份验证")
        # Credentials are only loaded when authentication is switched on.
        USERS = _config.get("user", {})

# Probe for the TTS inference package; when it cannot be imported the
# classic inference code path is forced.
try:
    from TTS_infer_pack.TTS import TTS
except ImportError:
    is_classic = True

# Both backends export the same five names, so the rest of the module does
# not care which one was picked.
if is_classic:
    from classic_inference.classic_load_infer_info import load_character, character_name, get_wav_from_text_api, models_path, update_character_info
else:
    from load_infer_info import load_character, character_name, get_wav_from_text_api, models_path, update_character_info

# Flask application object; CORS is enabled for every route and any origin.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})


# Maps a request hash to the path of the temporary audio file saved for it.
temp_files = dict()

# Used to recognise repeated requests.
def generate_file_hash(*args):
    """Return an MD5 hex digest that identifies a request by its parameters.

    Each argument is converted with ``str()`` and encoded, then fed to the
    hash prefixed with its byte length. The length prefix makes argument
    boundaries part of the digest, so ``("ab", "c")`` and ``("a", "bc")``
    no longer collide the way plain concatenation would.

    Args:
        *args: Any values; only their ``str()`` form is hashed.

    Returns:
        A 32-character hexadecimal digest string.
    """
    hash_object = hashlib.md5()
    for arg in args:
        encoded = str(arg).encode()
        hash_object.update(f"{len(encoded)}:".encode())
        hash_object.update(encoded)
    return hash_object.hexdigest()



# HTTP Basic authentication handler used by the @auth.login_required routes.
# CORS is configured once, right where the Flask app is created, so it is not
# repeated here.
auth = HTTPBasicAuth()

@auth.verify_password
def verify_password(username, password):
    """Check HTTP Basic credentials against the configured users.

    Args:
        username: User name supplied by the client.
        password: Password supplied by the client.

    Returns:
        True when authentication is disabled, or when ``password`` matches
        the entry stored for ``username`` in ``USERS``; False otherwise.
    """
    import hmac

    if not enable_auth:
        return True  # Authentication is switched off: allow every request.

    expected = USERS.get(username)
    # An unknown user, or a non-string on either side, can never match.
    if not isinstance(expected, str) or not isinstance(password, str):
        return False
    # Constant-time comparison so response timing does not leak how much of
    # the password was correct. Bytes are used because compare_digest()
    # rejects non-ASCII str arguments.
    return hmac.compare_digest(expected.encode("utf-8"), password.encode("utf-8"))


@app.route('/character_list', methods=['GET'])
@auth.login_required
def character_list():
    """Return the 'characters_and_emotions' entry of the refreshed character info as JSON."""
    character_info = update_character_info()
    return jsonify(character_info['characters_and_emotions'])


@app.route('/tts', methods=['GET', 'POST'])
@auth.login_required
def tts():
    """Synthesize speech for the requested text and return it as audio.

    Parameters are read from the JSON body when the request is JSON,
    otherwise from the query string:

    - ``text``: text to synthesize (URL-unquoted before use).
    - ``cha_name``: character to use. It is loaded when it differs from the
      current one; a name with no directory under ``models_path`` gives 400.
    - ``text_language``, ``character_emotion``, ``cut_method``: forwarded to
      ``get_wav_from_text_api``. ``cut_method == "auto_cut"`` is expanded
      with the configured word count.
    - ``batch_size``, ``speed``, ``top_k``, ``top_p``, ``temperature``,
      ``seed``: numeric; anything unparsable gives 400. ``batch_size``,
      ``speed`` and ``seed`` are forwarded only outside classic mode.
    - ``stream``: when truthy, the generator output is streamed back.
    - ``save_temp``: when truthy (non-streaming only), the audio is written
      to a temporary file that is reused for identical later requests.
    - ``format``: ``wav``, ``mp3`` or ``ogg``; anything else gives 400.

    Returns:
        A Flask response carrying the audio, or a ``(json, 400)`` pair
        describing the invalid parameter.
    """
    global character_name
    global models_path

    # Prefer the JSON body; fall back to the query string.
    data = request.json if request.is_json else request.args

    text = urllib.parse.unquote(data.get('text', ''))
    cha_name = data.get('cha_name', None)
    expected_path = os.path.join(models_path, cha_name) if cha_name else None

    # NOTE(review): cha_name is joined onto models_path unchecked, so a value
    # containing ".." can point outside it — confirm whether that is acceptable.
    if expected_path:
        if not os.path.exists(expected_path):
            return jsonify({"error": f"Directory {expected_path} does not exist. Using the current character."}), 400
        if cha_name != character_name:
            print(f"Loading character {cha_name}")
            load_character(cha_name)
            # Record the new name only after loading succeeded, so a failed
            # load is retried on the next request instead of being skipped.
            character_name = cha_name

    text_language = str(data.get('text_language', '多语种混合')).lower()
    try:
        batch_size = int(data.get('batch_size', default_batch_size))
        speed_factor = float(data.get('speed', 1.0))
        top_k = int(data.get('top_k', 6))
        top_p = float(data.get('top_p', 0.8))
        temperature = float(data.get('temperature', 0.8))
        seed = int(data.get('seed', -1))
    except (TypeError, ValueError):
        # TypeError covers a JSON null (int(None)); ValueError covers bad text.
        return jsonify({"error": "Invalid parameters. They must be numbers."}), 400

    truthy = ('true', '1', 't', 'y', 'yes')
    stream = str(data.get('stream', 'False')).lower() in truthy
    save_temp = str(data.get('save_temp', 'False')).lower() in truthy
    cut_method = str(data.get('cut_method', 'auto_cut')).lower()
    character_emotion = data.get('character_emotion', 'default')

    if cut_method == "auto_cut":
        cut_method = f"auto_cut_{default_word_count}"

    audio_format = data.get('format', 'wav')
    if audio_format not in ('wav', 'mp3', 'ogg'):
        return jsonify({"error": "Invalid format. It must be one of 'wav', 'mp3', or 'ogg'."}), 400

    params = {
        "text": text,
        "text_language": text_language,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "character_emotion": character_emotion,
        "cut_method": cut_method,
        "stream": stream,
    }
    # These parameters are only passed outside classic mode.
    if not is_classic:
        params["batch_size"] = batch_size
        params["speed_factor"] = speed_factor
        params["seed"] = seed

    if stream:
        gen = get_wav_from_text_api(**params)
        return Response(stream_with_context(gen),  mimetype='audio/wav')

    mimetype = f'audio/{audio_format}'

    if not save_temp:
        gen = get_wav_from_text_api(**params)
        sampling_rate, audio_data = next(gen)
        wav = io.BytesIO()
        sf.write(wav, audio_data, sampling_rate, format=audio_format)
        wav.seek(0)
        return Response(wav, mimetype=mimetype)

    # The cache key covers every parameter that changes the produced file,
    # including the output format, so a cached wav is never served as mp3.
    request_hash = generate_file_hash(
        text, text_language, top_k, top_p, temperature, character_emotion,
        character_name, seed, cut_method, batch_size, speed_factor, audio_format,
    )

    cached_path = temp_files.get(request_hash)
    # The temporary file may have been removed since it was cached.
    if cached_path and os.path.exists(cached_path):
        return send_file(cached_path, mimetype=mimetype)

    gen = get_wav_from_text_api(**params)
    sampling_rate, audio_data = next(gen)
    # mkstemp creates the file atomically; mktemp only returns a name and is
    # deprecated because another process can claim that name first.
    fd, temp_file_path = tempfile.mkstemp(suffix=f'.{audio_format}')
    try:
        with os.fdopen(fd, 'wb') as temp_file:
            sf.write(temp_file, audio_data, sampling_rate, format=audio_format)
    except Exception:
        # Do not leave a partial file behind when encoding fails.
        os.remove(temp_file_path)
        raise
    temp_files[request_hash] = temp_file_path
    return send_file(temp_file_path, mimetype=mimetype)


if __name__ == '__main__':
    # Serve on every network interface at the configured TTS port.
    app.run(port=tts_port, host='0.0.0.0')