Spaces:
Running
on
Zero
Running
on
Zero
dos taisaku
Browse files
app.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
from flask import Flask, request, render_template, send_file, jsonify, send_from_directory
|
2 |
from flask_socketio import SocketIO, emit
|
3 |
from flask_cors import CORS
|
|
|
|
|
|
|
|
|
|
|
4 |
import io
|
5 |
import os
|
6 |
import argparse
|
@@ -20,8 +25,19 @@ app = Flask(__name__)
|
|
20 |
CORS(app)
|
21 |
socketio = SocketIO(app, cors_allowed_origins="*")
|
22 |
|
23 |
-
#
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
active_tasks = {}
|
26 |
task_futures = {}
|
27 |
|
@@ -29,13 +45,14 @@ task_futures = {}
|
|
29 |
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
30 |
|
31 |
class Task:
|
32 |
-
def __init__(self, task_id, mode, weight1, weight2, file_data):
|
33 |
self.task_id = task_id
|
34 |
self.mode = mode
|
35 |
self.weight1 = weight1
|
36 |
self.weight2 = weight2
|
37 |
self.file_data = file_data
|
38 |
self.cancel_flag = False
|
|
|
39 |
|
40 |
def update_queue_status(message=None):
|
41 |
socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})
|
@@ -70,6 +87,11 @@ def process_task(task):
|
|
70 |
del active_tasks[task.task_id]
|
71 |
if task.task_id in task_futures:
|
72 |
del task_futures[task.task_id]
|
|
|
|
|
|
|
|
|
|
|
73 |
update_queue_status('Task completed or cancelled')
|
74 |
|
75 |
def worker():
|
@@ -90,20 +112,40 @@ def worker():
|
|
90 |
threading.Thread(target=worker, daemon=True).start()
|
91 |
|
92 |
@app.route('/submit_task', methods=['POST'])
|
|
|
93 |
def submit_task():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
task_id = str(uuid.uuid4())
|
95 |
file = request.files['file']
|
96 |
mode = request.form.get('mode', 'refine')
|
97 |
weight1 = float(request.form.get('weight1', 0.4))
|
98 |
weight2 = float(request.form.get('weight2', 0.3))
|
99 |
|
|
|
|
|
|
|
|
|
|
|
100 |
# ファイルデータをバイト列として保存
|
101 |
file_data = file.read()
|
102 |
|
103 |
-
task = Task(task_id, mode, weight1, weight2, file_data)
|
104 |
task_queue.put(task)
|
105 |
active_tasks[task_id] = task
|
106 |
|
|
|
|
|
|
|
|
|
107 |
update_queue_status(f'Task submitted: {task_id}')
|
108 |
|
109 |
queue_size = task_queue.qsize()
|
@@ -111,17 +153,31 @@ def submit_task():
|
|
111 |
|
112 |
@app.route('/cancel_task/<task_id>', methods=['POST'])
|
113 |
def cancel_task(task_id):
|
|
|
|
|
|
|
114 |
if task_id in active_tasks:
|
115 |
task = active_tasks[task_id]
|
|
|
|
|
|
|
116 |
task.cancel_flag = True
|
117 |
if task_id in task_futures:
|
118 |
task_futures[task_id].cancel()
|
119 |
del task_futures[task_id]
|
120 |
del active_tasks[task_id]
|
|
|
|
|
121 |
update_queue_status('Task cancelled')
|
122 |
return jsonify({'message': 'Task cancellation requested'})
|
123 |
else:
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
def get_active_task_order(task_id):
|
127 |
return list(active_tasks.keys()).index(task_id) if task_id in active_tasks else None
|
@@ -134,8 +190,16 @@ def handle_get_task_order(task_id):
|
|
134 |
|
135 |
@socketio.on('connect')
|
136 |
def handle_connect():
|
|
|
|
|
|
|
|
|
137 |
emit('queue_update', {'active_tasks': len(active_tasks), 'active_task_order': None})
|
138 |
|
|
|
|
|
|
|
|
|
139 |
# Flaskルート
|
140 |
@app.route('/', methods=['GET', 'POST'])
|
141 |
def process_refined():
|
@@ -178,17 +242,21 @@ def process_sketch():
|
|
178 |
'sketch_image': sketch_image
|
179 |
})
|
180 |
|
181 |
-
#
|
182 |
-
@app.errorhandler(
|
183 |
-
def
|
184 |
-
|
|
|
|
|
185 |
|
186 |
if __name__ == '__main__':
|
187 |
parser = argparse.ArgumentParser(description='Server options.')
|
188 |
parser.add_argument('--use_local', action='store_true', help='Use local model')
|
189 |
parser.add_argument('--use_gpu', action='store_true', help='Set to True to use GPU but if not available, it will use CPU')
|
190 |
parser.add_argument('--use_dotenv', action='store_true', help='Use .env file for environment variables')
|
|
|
|
|
191 |
args = parser.parse_args()
|
192 |
|
193 |
initialize(args.use_local, args.use_gpu, args.use_dotenv)
|
194 |
-
socketio.run(app, debug=
|
|
|
1 |
from flask import Flask, request, render_template, send_file, jsonify, send_from_directory
|
2 |
from flask_socketio import SocketIO, emit
|
3 |
from flask_cors import CORS
|
4 |
+
from flask_limiter import Limiter
|
5 |
+
from flask_limiter.util import get_remote_address
|
6 |
+
import concurrent.futures
|
7 |
+
import redis
|
8 |
+
|
9 |
import io
|
10 |
import os
|
11 |
import argparse
|
|
|
25 |
CORS(app)
|
26 |
socketio = SocketIO(app, cors_allowed_origins="*")
|
27 |
|
28 |
+
# Redisクライアントの初期化(レート制限とキャッシュのため)
|
29 |
+
redis_client = redis.Redis(host='localhost', port=6379, db=0)
|
30 |
+
|
31 |
+
# レート制限の設定
|
32 |
+
limiter = Limiter(
|
33 |
+
app,
|
34 |
+
key_func=get_remote_address,
|
35 |
+
default_limits=["200 per day", "50 per hour"]
|
36 |
+
)
|
37 |
+
|
38 |
+
# タスクキューの作成とサイズ制限
|
39 |
+
MAX_QUEUE_SIZE = 100
|
40 |
+
task_queue = queue.Queue(maxsize=MAX_QUEUE_SIZE)
|
41 |
active_tasks = {}
|
42 |
task_futures = {}
|
43 |
|
|
|
45 |
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
46 |
|
47 |
class Task:
|
48 |
+
def __init__(self, task_id, mode, weight1, weight2, file_data, client_ip):
|
49 |
self.task_id = task_id
|
50 |
self.mode = mode
|
51 |
self.weight1 = weight1
|
52 |
self.weight2 = weight2
|
53 |
self.file_data = file_data
|
54 |
self.cancel_flag = False
|
55 |
+
self.client_ip = client_ip
|
56 |
|
57 |
def update_queue_status(message=None):
|
58 |
socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})
|
|
|
87 |
del active_tasks[task.task_id]
|
88 |
if task.task_id in task_futures:
|
89 |
del task_futures[task.task_id]
|
90 |
+
|
91 |
+
# タスク数をデクリメント
|
92 |
+
client_ip = task.client_ip # この行は Task クラスに client_ip 属性を追加する必要があります
|
93 |
+
redis_client.decr(f'tasks:{client_ip}')
|
94 |
+
|
95 |
update_queue_status('Task completed or cancelled')
|
96 |
|
97 |
def worker():
|
|
|
112 |
threading.Thread(target=worker, daemon=True).start()
|
113 |
|
114 |
@app.route('/submit_task', methods=['POST'])
|
115 |
+
@limiter.limit("10 per minute") # 1分間に10回までのリクエストに制限
|
116 |
def submit_task():
|
117 |
+
if task_queue.full():
|
118 |
+
return jsonify({'error': 'Task queue is full. Please try again later.'}), 503
|
119 |
+
|
120 |
+
# クライアントIPアドレスを取得
|
121 |
+
client_ip = get_remote_address()
|
122 |
+
|
123 |
+
# 同一IPからの同時タスク数を制限
|
124 |
+
if redis_client.get(f'tasks:{client_ip}') and int(redis_client.get(f'tasks:{client_ip}')) >= 2:
|
125 |
+
return jsonify({'error': 'Maximum number of concurrent tasks reached'}), 429
|
126 |
+
|
127 |
task_id = str(uuid.uuid4())
|
128 |
file = request.files['file']
|
129 |
mode = request.form.get('mode', 'refine')
|
130 |
weight1 = float(request.form.get('weight1', 0.4))
|
131 |
weight2 = float(request.form.get('weight2', 0.3))
|
132 |
|
133 |
+
# ファイルタイプの制限
|
134 |
+
allowed_extensions = {'png', 'jpg', 'jpeg', 'gif'}
|
135 |
+
if '.' not in file.filename or file.filename.rsplit('.', 1)[1].lower() not in allowed_extensions:
|
136 |
+
return jsonify({'error': 'Invalid file type'}), 415
|
137 |
+
|
138 |
# ファイルデータをバイト列として保存
|
139 |
file_data = file.read()
|
140 |
|
141 |
+
task = Task(task_id, mode, weight1, weight2, file_data, client_ip)
|
142 |
task_queue.put(task)
|
143 |
active_tasks[task_id] = task
|
144 |
|
145 |
+
# 同一IPからのタスク数をインクリメント
|
146 |
+
redis_client.incr(f'tasks:{client_ip}')
|
147 |
+
redis_client.expire(f'tasks:{client_ip}', 3600) # 1時間後に期限切れ
|
148 |
+
|
149 |
update_queue_status(f'Task submitted: {task_id}')
|
150 |
|
151 |
queue_size = task_queue.qsize()
|
|
|
153 |
|
154 |
@app.route('/cancel_task/<task_id>', methods=['POST'])
|
155 |
def cancel_task(task_id):
|
156 |
+
# クライアントIPアドレスを取得
|
157 |
+
client_ip = get_remote_address()
|
158 |
+
|
159 |
if task_id in active_tasks:
|
160 |
task = active_tasks[task_id]
|
161 |
+
# タスクの所有者を確認(IPアドレスで簡易的に判断)
|
162 |
+
if task.client_ip != client_ip:
|
163 |
+
return jsonify({'error': 'Unauthorized to cancel this task'}), 403
|
164 |
task.cancel_flag = True
|
165 |
if task_id in task_futures:
|
166 |
task_futures[task_id].cancel()
|
167 |
del task_futures[task_id]
|
168 |
del active_tasks[task_id]
|
169 |
+
# タスク数をデクリメント
|
170 |
+
redis_client.decr(f'tasks:{client_ip}')
|
171 |
update_queue_status('Task cancelled')
|
172 |
return jsonify({'message': 'Task cancellation requested'})
|
173 |
else:
|
174 |
+
for task in list(task_queue.queue):
|
175 |
+
if task.task_id == task_id and task.client_ip == client_ip:
|
176 |
+
task.cancel_flag = True
|
177 |
+
# タスク数をデクリメント
|
178 |
+
redis_client.decr(f'tasks:{client_ip}')
|
179 |
+
return jsonify({'message': 'Task cancellation requested for queued task'})
|
180 |
+
return jsonify({'error': 'Task not found'}), 404
|
181 |
|
182 |
def get_active_task_order(task_id):
|
183 |
return list(active_tasks.keys()).index(task_id) if task_id in active_tasks else None
|
|
|
190 |
|
191 |
@socketio.on('connect')
|
192 |
def handle_connect():
|
193 |
+
# クライアント接続数の制限
|
194 |
+
if redis_client.get('connected_clients') and int(redis_client.get('connected_clients')) > 100:
|
195 |
+
return False # 接続を拒否
|
196 |
+
redis_client.incr('connected_clients')
|
197 |
emit('queue_update', {'active_tasks': len(active_tasks), 'active_task_order': None})
|
198 |
|
199 |
+
@socketio.on('disconnect')
|
200 |
+
def handle_disconnect():
|
201 |
+
redis_client.decr('connected_clients')
|
202 |
+
|
203 |
# Flaskルート
|
204 |
@app.route('/', methods=['GET', 'POST'])
|
205 |
def process_refined():
|
|
|
242 |
'sketch_image': sketch_image
|
243 |
})
|
244 |
|
245 |
+
# グローバルエラーハンドラー
|
246 |
+
@app.errorhandler(Exception)
|
247 |
+
def handle_exception(e):
|
248 |
+
# ログにエラーを記録
|
249 |
+
app.logger.error(f"Unhandled exception: {str(e)}")
|
250 |
+
return jsonify({'error': 'An unexpected error occurred'}), 500
|
251 |
|
252 |
if __name__ == '__main__':
|
253 |
parser = argparse.ArgumentParser(description='Server options.')
|
254 |
parser.add_argument('--use_local', action='store_true', help='Use local model')
|
255 |
parser.add_argument('--use_gpu', action='store_true', help='Set to True to use GPU but if not available, it will use CPU')
|
256 |
parser.add_argument('--use_dotenv', action='store_true', help='Use .env file for environment variables')
|
257 |
+
parser.add_argument('--debug', action='store_true', help='Run in debug mode')
|
258 |
+
|
259 |
args = parser.parse_args()
|
260 |
|
261 |
initialize(args.use_local, args.use_gpu, args.use_dotenv)
|
262 |
+
socketio.run(app, debug=args.debug, host='0.0.0.0', port=os.environ['PORT'])
|