Spaces:
Running
Running
from fastapi import FastAPI, Form, Request,Response | |
from fastapi.responses import HTMLResponse | |
from jinja2 import Template | |
import markdown | |
import time | |
from datetime import datetime, timedelta | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from agents import DeepResearchAgent, get_llms | |
import threading | |
import logging | |
from queue import Queue | |
import uuid | |
# Web application instance.
app = FastAPI()
# Held by form_post while an idea is being generated, so only one generation runs at a time.
lock = threading.Lock()

# Maximum number of replies served per day.
MAX_REPLIES_PER_DAY: int = 100
# Replies served so far today; zeroed by the scheduled reset_counter job.
reply_count: int = 0
# Timestamp taken at startup for the counter reset.
# NOTE(review): not updated anywhere in the visible code — confirm whether it is still needed.
last_reset_time: datetime = datetime.now()
# HTML模板 | |
html_template = """ | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>CoI Agent online demo 😊</title> | |
<style> | |
body { | |
font-family: 'Arial', sans-serif; | |
background-color: #f4f4f9; | |
margin: 0; | |
padding: 0; | |
display: flex; | |
justify-content: center; | |
align-items: center; | |
min-height: 100vh; | |
} | |
.container { | |
width: 95%; | |
max-width: 1200px; | |
background-color: #fff; | |
padding: 2rem; | |
border-radius: 10px; | |
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); | |
} | |
h1 { | |
font-size: 2rem; | |
margin-bottom: 1.5rem; | |
color: #333; | |
text-align: center; | |
} | |
form { | |
margin-bottom: 1.5rem; | |
} | |
.form-group { | |
display: flex; | |
justify-content: space-between; | |
align-items: center; | |
margin-bottom: 1.5rem; | |
} | |
.form-group label { | |
flex: 0; | |
font-size: 1 rem; /* 增大字体 */ | |
color: #333; | |
margin-right: 0.5rem; | |
background-color: #f0f8ff; /* 气泡背景颜色 */ | |
padding: 0.5rem 1rem; /* 气泡内边距 */ | |
border-radius: 10px; /* 气泡圆角 */ | |
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.1); /* 艺术字效果 */ | |
font-family: 'Times new roman', cursive, sans-serif; /* 艺术字体 */ | |
box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); /* 气泡阴影 */ | |
} | |
.form-group input { | |
flex: 4; | |
padding: 0.6rem; | |
font-size: 1rem; | |
border: 1px solid #ccc; | |
border-radius: 5px; | |
margin-left: 1rem; | |
} | |
.form-group button { | |
flex: 0; | |
padding: 0.6rem 1rem; | |
font-size: 1rem; | |
background-color: #F2A582; | |
color: #fff; | |
border: none; | |
border-radius: 5px; | |
cursor: pointer; | |
transition: background-color 0.3s ease; | |
margin-left: 1rem; | |
} | |
.form-group button:hover { | |
background-color: #0056b3; | |
} | |
.loading, | |
.time-box, | |
.counter-box, | |
.result, | |
.error { | |
margin-top: 1.5rem; | |
} | |
.loading { | |
font-size: 1.2rem; | |
color: #007bff; | |
animation: fadeIn 0.5s ease-in-out; | |
text-align: center; | |
} | |
.time-counter-container { | |
display: flex; | |
justify-content: space-between; | |
} | |
.time-box, | |
.counter-box { | |
display: inline-block; | |
padding: 0.5rem 1rem; | |
background-color: #e9ecef; | |
border-radius: 10px; | |
box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); | |
font-size: 0.9rem; | |
margin: 0.5rem; | |
flex: 1; | |
text-align: center; | |
} | |
.result { | |
display: flex; | |
justify-content: space-between; | |
flex-wrap: wrap; | |
} | |
.result .box { | |
flex: 1; | |
margin: 0.5rem; | |
padding: 1rem; | |
background-color: #e9ecef; | |
border-radius: 10px; | |
box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); | |
word-wrap: break-word; | |
height: 400px; | |
overflow-y: auto; | |
font-size: 1rem; | |
font-family: "Times New Roman", Times, serif; | |
line-height: 1.5; | |
} | |
.error .box { | |
width: 100%; | |
padding: 1rem; | |
background-color: #f8d7da; | |
color: #721c24; | |
border-radius: 10px; | |
box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); | |
word-wrap: break-word; | |
} | |
h2 { | |
font-size: 1.3rem; | |
margin-bottom: 1rem; | |
color: #333; | |
} | |
@keyframes fadeIn { | |
from { opacity: 0; } | |
to { opacity: 1; } | |
} | |
.progress-bar-container { | |
width: 100%; | |
background-color: #e9ecef; | |
border-radius: 10px; | |
overflow: hidden; | |
margin-top: 1.5rem; | |
box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); | |
} | |
.progress-bar { | |
height: 20px; | |
background-color: #727372; | |
width: 0%; | |
transition: width 0.1s ease; | |
} | |
.example-container { | |
display: flex; | |
justify-content: space-between; | |
align-items: center; | |
margin-bottom: 1.5rem; | |
} | |
.example-label { | |
flex: 0.7; | |
font-size: 1 rem; | |
color: #333; | |
text-align: center; | |
margin-right: 0rem; | |
padding: 0.5rem 0.2rem; | |
background-color: #f0f8ff; | |
border-radius: 10px; | |
box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); | |
font-family: 'Times new roman', cursive, sans-serif; | |
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.1); | |
box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); | |
} | |
.example-topics { | |
flex: 6; | |
display: flex; | |
justify-content: space-around; | |
} | |
.example-topics button { | |
padding: 0.5rem 1rem; | |
font-size: 1rem; | |
background-color: #ffa07a; /* 浅橙色 */ | |
color: #fff; | |
border: none; | |
border-radius: 5px; | |
cursor: pointer; | |
margin: 0.3rem; | |
transition: background-color 0.3s ease; | |
} | |
.example-topics button:hover { | |
background-color: #ff4500; /* 深橙色 */ | |
} | |
</style> | |
<script> | |
let startTime = 0; | |
let intervalId = null; | |
let progressIntervalId = null; | |
let maxTime = 180; // 最大时间180秒 | |
function showLoading() { | |
document.getElementById("loading").style.display = "block"; | |
document.getElementById("submit-btn").disabled = true; | |
startTime = Date.now(); | |
intervalId = setInterval(updateTime, 100); | |
progressIntervalId = setInterval(updateProgressBar, 100); | |
// 隐藏错误消息 | |
const errorBox = document.querySelector(".error"); | |
if (errorBox) { | |
errorBox.style.display = "none"; | |
} | |
} | |
function hideLoading() { | |
document.getElementById("loading").style.display = "none"; | |
document.getElementById("submit-btn").disabled = false; | |
if (intervalId) { | |
clearInterval(intervalId); | |
intervalId = null; | |
} | |
if (progressIntervalId) { | |
clearInterval(progressIntervalId); | |
progressIntervalId = null; | |
} | |
updateProgressBar(100); // 立即更新进度条至100% | |
} | |
function updateTime() { | |
const now = Date.now(); | |
const elapsed = ((now - startTime) / 1000).toFixed(2); | |
document.getElementById("time-taken").innerText = `Time Taken: ${elapsed} s`; | |
} | |
function updateProgressBar(percentage = null) { | |
const progressBar = document.getElementById("progress-bar"); | |
if (percentage !== null) { | |
progressBar.style.width = `${percentage}%`; | |
} else { | |
const now = Date.now(); | |
const elapsed = (now - startTime) / 1000; | |
const progress = Math.min((elapsed / maxTime) * 60, 97); | |
progressBar.style.width = `${progress}%`; | |
} | |
} | |
function fillTopic(topic) { | |
document.getElementById("topic").value = topic; | |
} | |
</script> | |
</head> | |
<body> | |
<div class="container"> | |
<h1>CoI Agent online demo 😊</h1> | |
<div class="time-counter-container"> | |
<div id="time-taken" class="time-box">Time Taken: {{ time_taken }} seconds</div> | |
<div class="counter-box">Today's Replies: {{ reply_count }}/100 </div> | |
</div> | |
<div class="progress-bar-container"> | |
<div id="progress-bar" class="progress-bar"></div> | |
</div> | |
<div class="result"> | |
<div class="box"> | |
<h2>Idea</h2> | |
<div>{{ idea | safe }}</div> | |
</div> | |
</div> | |
<form action="/" method="post" onsubmit="showLoading()"> | |
<div class="form-group"> | |
<label for="topic">Topic:</label> | |
<input type="text" id="topic" name="topic"> | |
<button type="submit" id="submit-btn">{{ button_text }}</button> | |
</div> | |
</form> | |
<div class="example-container"> | |
<div class="example-label">Example Input:</div> | |
<div class="example-topics"> | |
<button onclick="fillTopic('Realistic Image Synthesis in Medical Imaging')">Realistic Image Synthesis in Medical Imaging</button> | |
<button onclick="fillTopic('Using diffusion to generate road layout')">Using diffusion to generate road layout</button> | |
<button onclick="fillTopic('Using LLM-based agent to generate idea')">Using LLM-based agent to generate idea</button> | |
</div> | |
</div> | |
<div id="loading" class="loading" display="none">{{ loading_text }}</div> | |
{% if error %} | |
<div class="error"> | |
<div class="box"> | |
<h2>Error</h2> | |
<div>{{ error }}</div> | |
</div> | |
</div> | |
{% endif %} | |
</div> | |
<script> | |
async function getUserId() { | |
// 检查 sessionStorage 中是否有 user_id | |
let userId = sessionStorage.getItem("user_id"); | |
if (!userId) { | |
// 请求新的 user_id | |
const response = await fetch("/user_id"); | |
const data = await response.json(); | |
userId = data.user_id; | |
sessionStorage.setItem("user_id", userId); | |
} | |
console.log("User ID:", userId); | |
} | |
window.onload = getUserId; | |
</script> | |
</body> | |
</html> | |
""" | |
# Daily counter reset (run by the scheduler job registered below).
def reset_counter():
    """Set the module-level daily reply counter back to zero."""
    global reply_count
    # Rebinding the module global; no other state is touched.
    reply_count = 0
# Background job: zero the daily reply counter every day at 00:00
# (in the scheduler's default timezone).
scheduler = BackgroundScheduler()
scheduler.add_job(reset_counter, trigger="cron", hour=0, minute=0)
scheduler.start()

# FIFO of pending generation requests; form_post enqueues an entry and waits
# until it reaches the head before generating.
queue = Queue()
def fix_markdown(text, separator="<br>"):
    """Insert a blank line before numbered list items, then join the lines.

    A line counts as a numbered item when, after stripping leading whitespace,
    it starts with a positive integer followed by a period ("1.", "2.", ...,
    "42."). Such a line gets an empty line in front of it unless it is the
    first line or the previous line is already blank.

    Args:
        text: Raw text to post-process (form_post passes the agent's idea).
        separator: String used to join the processed lines. Defaults to
            ``"<br>"``, so the result contains HTML line breaks rather than
            newlines.

    Returns:
        The processed lines joined with ``separator``.
    """
    import re

    # Compiled once per call instead of rebuilding a 1..30 prefix list on every
    # loop iteration; also handles items numbered above 30. A leading zero
    # ("0.", "01.") is not treated as a list item.
    numbered_item = re.compile(r"[1-9][0-9]*\.")

    lines = text.split("\n")
    result = []
    for i, line in enumerate(lines):
        needs_gap = i > 0 and lines[i - 1].strip()
        if needs_gap and numbered_item.match(line.lstrip()):
            result.append("")
        result.append(line)

    # NOTE(review): the caller feeds this into markdown.markdown(); with the
    # default "<br>" separator Markdown sees a single line, so line-start
    # syntax (headings, lists) after the first line is likely not recognised.
    # Confirm whether separator="\n" was intended.
    return separator.join(result)
async def add_user_id_and_state_cookie(request: Request, call_next):
    """Make sure the browser ends up with ``user_id`` and ``state`` cookies.

    The downstream handler runs first; cookies are attached to its response:

    * no ``user_id`` cookie: issue a fresh UUID and set ``state`` to "generate";
    * ``user_id`` present but no ``state``: set ``state`` to "generate";
    * both present: the response is returned unchanged.

    Because the cookies are only added to the response, the handler serving
    this same request still sees them as missing.
    """
    cookies = request.cookies
    has_user = bool(cookies.get("user_id"))
    has_state = bool(cookies.get("state"))
    # Generate the id before calling the handler, as before.
    fresh_user_id = None if has_user else str(uuid.uuid4())

    response = await call_next(request)

    if fresh_user_id is not None:
        response.set_cookie(key="user_id", value=fresh_user_id)
        response.set_cookie(key="state", value="generate")
    elif not has_state:
        response.set_cookie(key="state", value="generate")
    return response
async def get_user_id():
    """Return a new random UUID4 string under the ``user_id`` key."""
    return dict(user_id=str(uuid.uuid4()))
def form_get():
    """Render the initial page with a placeholder idea and today's reply count.

    Returns:
        The rendered HTML page (str).
    """
    return Template(html_template).render(
        # Fixed typos in the user-visible placeholder ("a example", "geneartion").
        idea="This is an example of the idea generation",
        error=None,
        reply_count=reply_count,
        # Passed explicitly so the "Time Taken" box does not render blank.
        time_taken=0,
        button_text="Generate",
        loading_text="Generating content, Usually takes 2-3 minutes, please wait...",
    )
def form_post(request: Request, response: Response, topic: str = Form(...)):
    """Handle a submitted topic: queue the request, generate an idea, render the page.

    Flow:

    * State cookie "generate" while other requests are already queued: nothing
      is queued. The user is told how many requests are ahead and the state
      cookie switches to "continue". If the daily budget would be exhausted,
      a limit error is shown instead.
    * Otherwise ("generate" on an idle queue, "continue", or an unrecognised
      state value): the request joins the queue, waits for its turn, and
      unless the daily limit has been reached runs the CoI agent on ``topic``.

    Args:
        request: Incoming request; the ``user_id`` and ``state`` cookies are read.
        response: Response whose ``state`` cookie is updated.
        topic: Research topic entered in the form.

    Returns:
        The rendered HTML page (str).
    """
    global reply_count
    # Cookie id of this browser; only used for the debug print below.
    user_id = request.cookies.get("user_id")
    state = request.cookies.get("state", "generate")
    print(user_id, state)
    loading_text = "Generating content, Usually takes 2-3 minutes, please wait..."

    if state == "generate" and not queue.empty():
        queue_len = queue.qsize()
        if queue_len + reply_count >= MAX_REPLIES_PER_DAY:
            error_message = "Today's maximum number of replies has been reached. Please try again tomorrow."
            return Template(html_template).render(
                idea="",
                error=error_message,
                reply_count=reply_count,
                time_taken=0,
                button_text="Generate",
                loading_text=loading_text,
            )
        error_message = "There are currently {} requests being processed. If you want to queue, please click the Continue button and you will enter the queue.".format(queue_len)
        response.set_cookie(key="state", value="continue")
        return Template(html_template).render(
            idea="",
            error=error_message,
            reply_count=reply_count,
            time_taken=0,
            button_text="Continue",
            loading_text=f"Generating content, Usually takes {queue_len*2}-{queue_len*3} minutes, please wait...",
        )

    # Every other case joins the queue. A per-request ticket is used instead of
    # the user id so that two tabs sharing a cookie, or several clients without
    # one (user_id is None), never mistake another request's entry for their own.
    ticket = str(uuid.uuid4())
    queue.put(ticket)
    # The next submission from this browser starts over in the "generate" state.
    response.set_cookie(key="state", value="generate")
    button_text = "Generate"

    _wait_for_queue_turn(ticket)
    try:
        with lock:
            logging.info(f"Processing request for topic: {topic}")
            start_time = time.time()
            idea = ""
            error_message = None
            if reply_count >= MAX_REPLIES_PER_DAY:
                # Enforce the daily limit: skip generation entirely.
                error_message = "Today's maximum number of replies has been reached. Please try again tomorrow."
                logging.info(error_message)
            else:
                try:
                    idea = _generate_idea_html(topic)
                    reply_count += 1
                    logging.info(f"Successfully generated idea for topic: {topic}")
                except Exception as e:
                    # Show agent/LLM failures on the page instead of a server error.
                    logging.exception(f"Failed to generate idea for topic: {topic}, Error: {str(e)}")
                    error_message = str(e)
            time_taken = round(time.time() - start_time, 2)
    finally:
        # Our ticket is at the head (we waited for it), so this removes exactly
        # our entry and lets the next request proceed, even if something raised.
        queue.get()

    return Template(html_template).render(
        idea=idea,
        error=error_message,
        reply_count=reply_count,
        time_taken=time_taken,
        button_text=button_text,
        loading_text=loading_text,
    )


def _wait_for_queue_turn(ticket, poll_interval=0.1):
    """Block until ``ticket`` is at the head of the shared request queue.

    Sleeps between checks rather than busy-waiting. Peeks at the Queue's
    internal deque; ``ticket`` must already have been put on the queue.
    """
    while queue.queue[0] != ticket:
        time.sleep(poll_interval)


def _generate_idea_html(topic):
    """Run the CoI DeepResearchAgent on ``topic`` and return the idea as HTML."""
    main_llm, cheap_llm = get_llms()
    deep_research_agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1, max_chain_length=5, min_chain_length=3, max_chain_numbers=1)
    print(f"begin to generate idea of topic {topic}")
    # The first element of the returned tuple is the idea text; the remaining
    # outputs (experiments, entities, chain, trend, ...) are not used here.
    idea, *_ = deep_research_agent.generate_idea_with_chain(topic)
    return markdown.markdown(fix_markdown(idea))