jianghuyihei commited on
Commit
12c54f7
1 Parent(s): f9b90e4

fix websocket

Browse files
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, Form, Request,Response
2
  from fastapi.responses import HTMLResponse
3
  from jinja2 import Template
4
  import markdown
@@ -6,9 +6,11 @@ import time
6
  from datetime import datetime, timedelta
7
  from apscheduler.schedulers.background import BackgroundScheduler
8
  from agents import DeepResearchAgent, get_llms
 
9
  import threading
10
  import logging
11
  from queue import Queue
 
12
  import uuid
13
 
14
 
@@ -346,10 +348,27 @@ html_template = """
346
  <script>
347
  {{ script}}
348
  </script>
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  </body>
350
  </html>
351
  """
352
 
 
 
 
 
353
  # 重置每日计数器
354
  def reset_counter():
355
  global reply_count
@@ -393,23 +412,50 @@ script_template = """
393
 
394
  queue = Queue()
395
 
396
- @app.get("/", response_class=HTMLResponse)
397
- def form_get():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
- script = script_template.format(user_id=str(uuid.uuid4()), state="generate")
 
 
 
 
 
400
  return Template(html_template).render(idea= "This is a example of the idea geneartion", error=None, reply_count=reply_count,button_text="Generate",loading_text="Generating content, Usually takes 3-4 minutes, please wait...",script=script)
401
 
402
  @app.post("/", response_class=HTMLResponse)
403
- def form_post(topic: str = Form(...), user_id: str = Form(...), state: str = Form(...)):
404
  global reply_count
405
  start_time = time.time()
406
-
407
  if user_id == "":
408
- user_id = str(uuid.uuid4())
409
  if state == "":
410
  state = "generate"
411
 
412
  script = script_template.format(user_id=user_id, state=state)
 
 
 
 
413
 
414
  print(f"current0 user_id={user_id}, state={state}")
415
  loading_text = "Generating content, Usually takes 3-4 minutes, please wait..."
@@ -437,7 +483,7 @@ def form_post(topic: str = Form(...), user_id: str = Form(...), state: str = For
437
 
438
  return Template(html_template).render(idea="", error=error_message, reply_count=reply_count, button_text=new_button_text,loading_text=f"Generating content, Usually takes {(queue_len+1)*3}-{(queue_len+1)*4} minutes, please wait...",script=script)
439
 
440
- queue.put([user_id,topic])
441
  new_state = "generate"
442
  new_button_text = "Generate"
443
  queue_len = queue.qsize()
@@ -446,13 +492,24 @@ def form_post(topic: str = Form(...), user_id: str = Form(...), state: str = For
446
  print(f"current2 user_id={user_id}, state={new_state}")
447
  # 判断当前是否轮到该用户,如果没轮到则一直等待到轮到为止
448
  print(queue.queue[0], [user_id,topic])
449
- while queue.queue[0] != [user_id,topic]:
 
 
 
 
 
 
 
 
 
 
450
  time.sleep(10)
451
  continue
452
 
453
  try:
454
  with lock:
455
  logging.info(f"Processing request for topic: {topic}")
 
456
  start_time = time.time()
457
  error_message = None
458
  idea = ""
 
1
+ from fastapi import FastAPI, Form, Request,Response, WebSocket, WebSocketDisconnect
2
  from fastapi.responses import HTMLResponse
3
  from jinja2 import Template
4
  import markdown
 
6
  from datetime import datetime, timedelta
7
  from apscheduler.schedulers.background import BackgroundScheduler
8
  from agents import DeepResearchAgent, get_llms
9
+ import hashlib
10
  import threading
11
  import logging
12
  from queue import Queue
13
+ import json
14
  import uuid
15
 
16
 
 
348
  <script>
349
  {{ script}}
350
  </script>
351
+ <script>
352
+ const socket = new WebSocket("ws://localhost:7860/ws");
353
+
354
+ socket.addEventListener('open', function (event) {
355
+ const userId = document.getElementById("user_id").value;
356
+ socket.send(JSON.stringify({ action: "connect", user_id: userId }));
357
+ });
358
+
359
+ window.addEventListener("beforeunload", function (event) {
360
+ const userId = document.getElementById("user_id").value;
361
+ socket.send(JSON.stringify({ action: "disconnect", user_id: userId }));
362
+ });
363
+ </script>
364
  </body>
365
  </html>
366
  """
367
 
368
+ def generate_user_id(ip: str) -> str:
369
+ # 使用哈希函数生成用户 ID
370
+ return hashlib.md5(ip.encode()).hexdigest()
371
+
372
  # 重置每日计数器
373
  def reset_counter():
374
  global reply_count
 
412
 
413
  queue = Queue()
414
 
415
+ @app.websocket("/ws")
416
+ async def websocket_endpoint(websocket: WebSocket):
417
+ await websocket.accept()
418
+ print("WebSocket connection established.")
419
+ try:
420
+ while True:
421
+ data = await websocket.receive_text()
422
+ message = json.loads(data)
423
+ user_id = message.get("user_id")
424
+ action = message.get("action")
425
+ print(action)
426
+ if action == "disconnect":
427
+ for item in list(queue.queue):
428
+ if item[0] == user_id:
429
+ queue.queue.remove(item)
430
+ break
431
+ print(f"User {user_id} disconnected.")
432
+
433
+ except WebSocketDisconnect:
434
+ print("WebSocket connection closed.")
435
 
436
+ @app.get("/", response_class=HTMLResponse)
437
+ def form_get(request: Request):
438
+ client_ip = request.client.host
439
+ user_id = generate_user_id(client_ip)
440
+ script = script_template.format(user_id=user_id, state="generate")
441
+ print(client_ip,user_id)
442
  return Template(html_template).render(idea= "This is a example of the idea geneartion", error=None, reply_count=reply_count,button_text="Generate",loading_text="Generating content, Usually takes 3-4 minutes, please wait...",script=script)
443
 
444
  @app.post("/", response_class=HTMLResponse)
445
+ def form_post(request: Request,topic: str = Form(...), user_id: str = Form(...), state: str = Form(...)):
446
  global reply_count
447
  start_time = time.time()
448
+ client_ip = request.client.host
449
  if user_id == "":
450
+ user_id = generate_user_id(client_ip)
451
  if state == "":
452
  state = "generate"
453
 
454
  script = script_template.format(user_id=user_id, state=state)
455
+ if user_id in queue.queue:
456
+ error_message = "Your request is being processed. Please wait for the result, or close the previous page and try again."
457
+ script = script_template.format(user_id=user_id, state=state)
458
+ return Template(html_template).render(idea="", error=error_message, reply_count=reply_count, button_text="Generate",loading_text="Generating content, Usually takes 3-4 minutes, please wait...",script=script)
459
 
460
  print(f"current0 user_id={user_id}, state={state}")
461
  loading_text = "Generating content, Usually takes 3-4 minutes, please wait..."
 
483
 
484
  return Template(html_template).render(idea="", error=error_message, reply_count=reply_count, button_text=new_button_text,loading_text=f"Generating content, Usually takes {(queue_len+1)*3}-{(queue_len+1)*4} minutes, please wait...",script=script)
485
 
486
+ queue.put(user_id)
487
  new_state = "generate"
488
  new_button_text = "Generate"
489
  queue_len = queue.qsize()
 
492
  print(f"current2 user_id={user_id}, state={new_state}")
493
  # 判断当前是否轮到该用户,如果没轮到则一直等待到轮到为止
494
  print(queue.queue[0], [user_id,topic])
495
+ while queue.queue[0] != user_id:
496
+ # 检查用户是否还在队列中
497
+ if not any(user_id == item[0] for item in queue.queue):
498
+ return Template(html_template).render(
499
+ idea="",
500
+ error="Request was cancelled.",
501
+ reply_count=reply_count,
502
+ button_text="Generate",
503
+ loading_text=loading_text,
504
+ script=script
505
+ )
506
  time.sleep(10)
507
  continue
508
 
509
  try:
510
  with lock:
511
  logging.info(f"Processing request for topic: {topic}")
512
+ time.sleep(1000)
513
  start_time = time.time()
514
  error_message = None
515
  idea = ""
prompts/__pycache__/deep_research_agent_promts.cpython-310.pyc CHANGED
Binary files a/prompts/__pycache__/deep_research_agent_promts.cpython-310.pyc and b/prompts/__pycache__/deep_research_agent_promts.cpython-310.pyc differ