yeq6x commited on
Commit
90ca682
1 Parent(s): 64567a7

dos taisaku

Browse files
Files changed (1) hide show
  1. app.py +78 -10
app.py CHANGED
@@ -1,6 +1,11 @@
1
  from flask import Flask, request, render_template, send_file, jsonify, send_from_directory
2
  from flask_socketio import SocketIO, emit
3
  from flask_cors import CORS
 
 
 
 
 
4
  import io
5
  import os
6
  import argparse
@@ -20,8 +25,19 @@ app = Flask(__name__)
20
  CORS(app)
21
  socketio = SocketIO(app, cors_allowed_origins="*")
22
 
23
- # タスクキューの作成
24
- task_queue = queue.Queue()
 
 
 
 
 
 
 
 
 
 
 
25
  active_tasks = {}
26
  task_futures = {}
27
 
@@ -29,13 +45,14 @@ task_futures = {}
29
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
30
 
31
  class Task:
32
- def __init__(self, task_id, mode, weight1, weight2, file_data):
33
  self.task_id = task_id
34
  self.mode = mode
35
  self.weight1 = weight1
36
  self.weight2 = weight2
37
  self.file_data = file_data
38
  self.cancel_flag = False
 
39
 
40
  def update_queue_status(message=None):
41
  socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})
@@ -70,6 +87,11 @@ def process_task(task):
70
  del active_tasks[task.task_id]
71
  if task.task_id in task_futures:
72
  del task_futures[task.task_id]
 
 
 
 
 
73
  update_queue_status('Task completed or cancelled')
74
 
75
  def worker():
@@ -90,20 +112,40 @@ def worker():
90
  threading.Thread(target=worker, daemon=True).start()
91
 
92
  @app.route('/submit_task', methods=['POST'])
 
93
  def submit_task():
 
 
 
 
 
 
 
 
 
 
94
  task_id = str(uuid.uuid4())
95
  file = request.files['file']
96
  mode = request.form.get('mode', 'refine')
97
  weight1 = float(request.form.get('weight1', 0.4))
98
  weight2 = float(request.form.get('weight2', 0.3))
99
 
 
 
 
 
 
100
  # ファイルデータをバイト列として保存
101
  file_data = file.read()
102
 
103
- task = Task(task_id, mode, weight1, weight2, file_data)
104
  task_queue.put(task)
105
  active_tasks[task_id] = task
106
 
 
 
 
 
107
  update_queue_status(f'Task submitted: {task_id}')
108
 
109
  queue_size = task_queue.qsize()
@@ -111,17 +153,31 @@ def submit_task():
111
 
112
  @app.route('/cancel_task/<task_id>', methods=['POST'])
113
  def cancel_task(task_id):
 
 
 
114
  if task_id in active_tasks:
115
  task = active_tasks[task_id]
 
 
 
116
  task.cancel_flag = True
117
  if task_id in task_futures:
118
  task_futures[task_id].cancel()
119
  del task_futures[task_id]
120
  del active_tasks[task_id]
 
 
121
  update_queue_status('Task cancelled')
122
  return jsonify({'message': 'Task cancellation requested'})
123
  else:
124
- return jsonify({'message': 'Task not found or already completed'}), 404
 
 
 
 
 
 
125
 
126
  def get_active_task_order(task_id):
127
  return list(active_tasks.keys()).index(task_id) if task_id in active_tasks else None
@@ -134,8 +190,16 @@ def handle_get_task_order(task_id):
134
 
135
  @socketio.on('connect')
136
  def handle_connect():
 
 
 
 
137
  emit('queue_update', {'active_tasks': len(active_tasks), 'active_task_order': None})
138
 
 
 
 
 
139
  # Flaskルート
140
  @app.route('/', methods=['GET', 'POST'])
141
  def process_refined():
@@ -178,17 +242,21 @@ def process_sketch():
178
  'sketch_image': sketch_image
179
  })
180
 
181
- # エラーハンドラー
182
- @app.errorhandler(500)
183
- def server_error(e):
184
- return jsonify(error=str(e)), 500
 
 
185
 
186
  if __name__ == '__main__':
187
  parser = argparse.ArgumentParser(description='Server options.')
188
  parser.add_argument('--use_local', action='store_true', help='Use local model')
189
  parser.add_argument('--use_gpu', action='store_true', help='Set to True to use GPU but if not available, it will use CPU')
190
  parser.add_argument('--use_dotenv', action='store_true', help='Use .env file for environment variables')
 
 
191
  args = parser.parse_args()
192
 
193
  initialize(args.use_local, args.use_gpu, args.use_dotenv)
194
- socketio.run(app, debug=True, host='0.0.0.0', port=5000)
 
1
  from flask import Flask, request, render_template, send_file, jsonify, send_from_directory
2
  from flask_socketio import SocketIO, emit
3
  from flask_cors import CORS
4
+ from flask_limiter import Limiter
5
+ from flask_limiter.util import get_remote_address
6
+ import concurrent.futures
7
+ import redis
8
+
9
  import io
10
  import os
11
  import argparse
 
25
  CORS(app)
26
  socketio = SocketIO(app, cors_allowed_origins="*")
27
 
28
+ # Redisクライアントの初期化(レート制限とキャッシュのため)
29
+ redis_client = redis.Redis(host='localhost', port=6379, db=0)
30
+
31
+ # レート制限の設定
32
+ limiter = Limiter(
33
+ app,
34
+ key_func=get_remote_address,
35
+ default_limits=["200 per day", "50 per hour"]
36
+ )
37
+
38
+ # タスクキューの作成とサイズ制限
39
+ MAX_QUEUE_SIZE = 100
40
+ task_queue = queue.Queue(maxsize=MAX_QUEUE_SIZE)
41
  active_tasks = {}
42
  task_futures = {}
43
 
 
45
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
46
 
47
  class Task:
48
+ def __init__(self, task_id, mode, weight1, weight2, file_data, client_ip):
49
  self.task_id = task_id
50
  self.mode = mode
51
  self.weight1 = weight1
52
  self.weight2 = weight2
53
  self.file_data = file_data
54
  self.cancel_flag = False
55
+ self.client_ip = client_ip
56
 
57
  def update_queue_status(message=None):
58
  socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})
 
87
  del active_tasks[task.task_id]
88
  if task.task_id in task_futures:
89
  del task_futures[task.task_id]
90
+
91
+ # タスク数をデクリメント
92
+ client_ip = task.client_ip # この行は Task クラスに client_ip 属性を追加する必要があります
93
+ redis_client.decr(f'tasks:{client_ip}')
94
+
95
  update_queue_status('Task completed or cancelled')
96
 
97
  def worker():
 
112
  threading.Thread(target=worker, daemon=True).start()
113
 
114
  @app.route('/submit_task', methods=['POST'])
115
+ @limiter.limit("10 per minute") # 1分間に10回までのリクエストに制限
116
  def submit_task():
117
+ if task_queue.full():
118
+ return jsonify({'error': 'Task queue is full. Please try again later.'}), 503
119
+
120
+ # クライアントIPアドレスを取得
121
+ client_ip = get_remote_address()
122
+
123
+ # 同一IPからの同時タスク数を制限
124
+ if redis_client.get(f'tasks:{client_ip}') and int(redis_client.get(f'tasks:{client_ip}')) >= 2:
125
+ return jsonify({'error': 'Maximum number of concurrent tasks reached'}), 429
126
+
127
  task_id = str(uuid.uuid4())
128
  file = request.files['file']
129
  mode = request.form.get('mode', 'refine')
130
  weight1 = float(request.form.get('weight1', 0.4))
131
  weight2 = float(request.form.get('weight2', 0.3))
132
 
133
+ # ファイルタイプの制限
134
+ allowed_extensions = {'png', 'jpg', 'jpeg', 'gif'}
135
+ if '.' not in file.filename or file.filename.rsplit('.', 1)[1].lower() not in allowed_extensions:
136
+ return jsonify({'error': 'Invalid file type'}), 415
137
+
138
  # ファイルデータをバイト列として保存
139
  file_data = file.read()
140
 
141
+ task = Task(task_id, mode, weight1, weight2, file_data, client_ip)
142
  task_queue.put(task)
143
  active_tasks[task_id] = task
144
 
145
+ # 同一IPからのタスク数をインクリメント
146
+ redis_client.incr(f'tasks:{client_ip}')
147
+ redis_client.expire(f'tasks:{client_ip}', 3600) # 1時間後に期限切れ
148
+
149
  update_queue_status(f'Task submitted: {task_id}')
150
 
151
  queue_size = task_queue.qsize()
 
153
 
154
  @app.route('/cancel_task/<task_id>', methods=['POST'])
155
  def cancel_task(task_id):
156
+ # クライアントIPアドレスを取得
157
+ client_ip = get_remote_address()
158
+
159
  if task_id in active_tasks:
160
  task = active_tasks[task_id]
161
+ # タスクの所有者を確認(IPアドレスで簡易的に判断)
162
+ if task.client_ip != client_ip:
163
+ return jsonify({'error': 'Unauthorized to cancel this task'}), 403
164
  task.cancel_flag = True
165
  if task_id in task_futures:
166
  task_futures[task_id].cancel()
167
  del task_futures[task_id]
168
  del active_tasks[task_id]
169
+ # タスク数をデクリメント
170
+ redis_client.decr(f'tasks:{client_ip}')
171
  update_queue_status('Task cancelled')
172
  return jsonify({'message': 'Task cancellation requested'})
173
  else:
174
+ for task in list(task_queue.queue):
175
+ if task.task_id == task_id and task.client_ip == client_ip:
176
+ task.cancel_flag = True
177
+ # タスク数をデクリメント
178
+ redis_client.decr(f'tasks:{client_ip}')
179
+ return jsonify({'message': 'Task cancellation requested for queued task'})
180
+ return jsonify({'error': 'Task not found'}), 404
181
 
182
  def get_active_task_order(task_id):
183
  return list(active_tasks.keys()).index(task_id) if task_id in active_tasks else None
 
190
 
191
  @socketio.on('connect')
192
  def handle_connect():
193
+ # クライアント接続数の制限
194
+ if redis_client.get('connected_clients') and int(redis_client.get('connected_clients')) > 100:
195
+ return False # 接続を拒否
196
+ redis_client.incr('connected_clients')
197
  emit('queue_update', {'active_tasks': len(active_tasks), 'active_task_order': None})
198
 
199
+ @socketio.on('disconnect')
200
+ def handle_disconnect():
201
+ redis_client.decr('connected_clients')
202
+
203
  # Flaskルート
204
  @app.route('/', methods=['GET', 'POST'])
205
  def process_refined():
 
242
  'sketch_image': sketch_image
243
  })
244
 
245
+ # グローバルエラーハンドラー
246
+ @app.errorhandler(Exception)
247
+ def handle_exception(e):
248
+ # ログにエラーを記録
249
+ app.logger.error(f"Unhandled exception: {str(e)}")
250
+ return jsonify({'error': 'An unexpected error occurred'}), 500
251
 
252
  if __name__ == '__main__':
253
  parser = argparse.ArgumentParser(description='Server options.')
254
  parser.add_argument('--use_local', action='store_true', help='Use local model')
255
  parser.add_argument('--use_gpu', action='store_true', help='Set to True to use GPU but if not available, it will use CPU')
256
  parser.add_argument('--use_dotenv', action='store_true', help='Use .env file for environment variables')
257
+ parser.add_argument('--debug', action='store_true', help='Run in debug mode')
258
+
259
  args = parser.parse_args()
260
 
261
  initialize(args.use_local, args.use_gpu, args.use_dotenv)
262
+ socketio.run(app, debug=args.debug, host='0.0.0.0', port=os.environ['PORT'])