Skip to content

Commit 5326170

Browse files
author
yinsu.zs
committed
refactor
Change-Id: Ic3ba660f6bf816a56e3c33b1cb1df021121366fd
1 parent a3f7962 commit 5326170

File tree

21 files changed

+2532
-2585
lines changed

21 files changed

+2532
-2585
lines changed

src/code/agent/main.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,3 @@
1-
# 第一步:全局替换 print 函数(必须在所有其他导入之前)
2-
import builtins
3-
from datetime import datetime
4-
5-
_original_print = builtins.print
6-
7-
def timestamped_print(*args, **kwargs):
8-
"""带时间戳的 print 函数"""
9-
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
10-
message = ' '.join(str(arg) for arg in args)
11-
timestamped_message = f"{timestamp} {message}"
12-
_original_print(timestamped_message, **kwargs)
13-
14-
builtins.print = timestamped_print
15-
16-
# 第二步:初始化日志系统
17-
from utils.logger import init_logging
18-
init_logging()
19-
20-
# 第三步:导入其他模块
211
from routes.routes import Routes
222

233
r = Routes()
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
import json
2+
import os
3+
import time
4+
import traceback
5+
6+
from flask import Blueprint, Flask, jsonify, request
7+
from flask_sock import Sock
8+
9+
from services.management_service import ManagementService, BackendStatus
10+
from utils.logger import log
11+
from services.gateway import (
12+
CpuGatewayService,
13+
HistoryGatewayService,
14+
get_task_queue
15+
)
16+
from services.process.websocket.websocket_manager import ws_manager
17+
from services.serverlessapi.serverless_api_service import ServerlessApiService
18+
19+
20+
class CpuRoutes:
21+
"""CPU模式路由:处理任务队列和异步转发"""
22+
23+
def __init__(self):
24+
# HTTP 路由使用 /api 前缀
25+
self.bp = Blueprint("cpu_routes", __name__, url_prefix="/api")
26+
# WebSocket 路由使用根路径(保持 ComfyUI 兼容性)
27+
self.ws_bp = Blueprint("cpu_ws", __name__)
28+
self.service = ManagementService() # 单例模式,直接创建实例
29+
self.sock = Sock()
30+
self.sock.bp = self.ws_bp # 将 WebSocket 绑定到单独的 Blueprint
31+
self.setup_routes()
32+
33+
def register(self, app: Flask):
34+
app.register_blueprint(self.bp)
35+
app.register_blueprint(self.ws_bp)
36+
37+
def setup_routes(self):
38+
"""设置所有路由"""
39+
self._register_websocket()
40+
self._register_queue_handler()
41+
self._register_prompt_handler()
42+
self._register_serverless_run_handler()
43+
self._register_history_handler()
44+
# 仅在 prod 模式下注册 userdata 拦截
45+
if os.environ.get('AUTO_LAUNCH_SNAPSHOT_NAME'):
46+
self._register_userdata_block()
47+
48+
def _check_backend_status(self):
49+
"""
50+
检查后端服务状态
51+
52+
Returns:
53+
tuple: (is_valid, error_response)
54+
is_valid为True时error_response为None
55+
is_valid为False时error_response为错误响应
56+
"""
57+
backend_status = self.service.status
58+
if backend_status not in (BackendStatus.RUNNING, BackendStatus.SAVING):
59+
return False, (jsonify({
60+
"status": "failed",
61+
"message": "Please start your comfyui/sd service first"
62+
}), 500)
63+
return True, None
64+
65+
def _register_websocket(self):
66+
@self.sock.route("/ws")
67+
def comfyui_compatible_ws(ws):
68+
"""
69+
CPU函数接收ComfyUI原生的WebSocket连接
70+
保持与ComfyUI前端完全兼容,但推送的是基于任务队列和状态轮询的真实状态
71+
"""
72+
try:
73+
ws_manager.add_connection(ws)
74+
log("INFO", f"New ComfyUI WebSocket connection established")
75+
76+
# 发送初始状态消息(模拟ComfyUI原生行为)
77+
client_id = f"cpu_client_{int(time.time() * 1000)}"
78+
try:
79+
ws.send(json.dumps({
80+
"type": "status",
81+
"data": {
82+
"sid": client_id,
83+
"status": {
84+
"exec_info": {
85+
"queue_remaining": get_task_queue()._get_pending_task_count()
86+
}
87+
}
88+
}
89+
}))
90+
except Exception as e:
91+
log("ERROR", f"Failed to send initial status: {e}")
92+
return
93+
94+
# 设置客户端ID,用于后续关联任务
95+
setattr(ws, '_comfyui_client_id', client_id)
96+
97+
# 将客户端ID与连接关联在WebSocketManager中
98+
ws_manager.associate_client_id_with_connection(ws, client_id)
99+
100+
# TODO 可能是多余的
101+
while True:
102+
try:
103+
message = ws.receive()
104+
log("DEBUG", f"Received message from ComfyUI frontend: {message[:100]}...")
105+
106+
except Exception as e:
107+
error_str = str(e)
108+
if "Connection closed" in error_str or "closed" in error_str.lower():
109+
log("INFO", f"Connection closed by client")
110+
break
111+
log("ERROR", f"Error receiving message: {e}\n{traceback.format_exc()}")
112+
break
113+
114+
except Exception as e:
115+
log("ERROR", f"Connection error: {e}\n{traceback.format_exc()}")
116+
finally:
117+
try:
118+
ws_manager.remove_connection(ws)
119+
log("INFO", f"ComfyUI WebSocket connection closed")
120+
except Exception as e:
121+
log("ERROR", f"Error removing connection: {e}")
122+
123+
def _register_queue_handler(self):
124+
@self.bp.route("/queue", methods=["GET", "POST"])
125+
def handle_queue():
126+
is_valid, error_response = self._check_backend_status()
127+
if not is_valid:
128+
return error_response
129+
130+
try:
131+
gateway_service = CpuGatewayService()
132+
133+
if request.method == "GET":
134+
return gateway_service.handle_queue_get_request()
135+
elif request.method == "POST":
136+
return gateway_service.handle_queue_post_request()
137+
else:
138+
return jsonify({
139+
"error": {
140+
"type": "method_not_allowed",
141+
"message": f"Method {request.method} not allowed"
142+
}
143+
}), 405
144+
145+
except Exception as e:
146+
error_msg = f"Failed to handle queue request: {str(e)}"
147+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
148+
149+
return jsonify({
150+
"error": {
151+
"type": "queue_operation_error",
152+
"message": error_msg
153+
}
154+
}), 500
155+
156+
def _register_prompt_handler(self):
157+
@self.bp.route("/prompt", methods=["POST"])
158+
def handle_prompt():
159+
is_valid, error_response = self._check_backend_status()
160+
if not is_valid:
161+
return error_response
162+
163+
try:
164+
gateway_service = CpuGatewayService()
165+
return gateway_service.handle_prompt_request_async()
166+
except Exception as e:
167+
error_msg = f"Failed to handle prompt request: {str(e)}"
168+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
169+
170+
return jsonify({
171+
"error": {
172+
"type": "prompt_operation_error",
173+
"message": error_msg
174+
}
175+
}), 500
176+
177+
def _register_serverless_run_handler(self):
178+
@self.bp.route("/serverless/run", methods=["POST"])
179+
def handle_serverless_run():
180+
"""
181+
处理 /api/serverless/run 请求,支持同步和异步两种模式
182+
183+
调用方式:
184+
- 默认: 异步调用(与 /api/prompt 处理一致)
185+
- Header X-Art-Invocation-Type: Sync 时: 同步调用,等待GPU返回结果
186+
187+
异步模式:
188+
- 将请求转发到GPU函数(异步调用)
189+
- 返回任务ID,前端通过任务ID轮询获取结果
190+
- 使用任务队列跟踪任务状态
191+
192+
同步模式:
193+
- 将请求转发到GPU函数(同步调用)
194+
- 等待GPU处理完成并返回结果
195+
- 直接返回结果给客户端
196+
"""
197+
is_valid, error_response = self._check_backend_status()
198+
if not is_valid:
199+
return error_response
200+
201+
try:
202+
gateway_service = CpuGatewayService()
203+
204+
# 检查调用类型:Header X-Art-Invocation-Type: Sync 表示同步调用
205+
invocation_type = request.headers.get("X-Art-Invocation-Type", "").strip()
206+
is_sync = invocation_type.lower() == "sync"
207+
208+
if is_sync:
209+
log("DEBUG", f"Processing /serverless/run in SYNC mode (X-Art-Invocation-Type: Sync)")
210+
return gateway_service.handle_serverless_run_sync()
211+
else:
212+
log("DEBUG", f"Processing /serverless/run in ASYNC mode (default)")
213+
return gateway_service.handle_serverless_run_async()
214+
215+
except Exception as e:
216+
error_msg = f"Failed to handle serverless run request: {str(e)}"
217+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
218+
219+
return jsonify({
220+
"error": {
221+
"type": "serverless_run_error",
222+
"message": error_msg
223+
}
224+
}), 500
225+
226+
def _register_history_handler(self):
227+
@self.bp.route("/history", methods=["GET", "POST", "DELETE"])
228+
@self.bp.route("/history/<path:subpath>", methods=["GET", "POST", "DELETE"])
229+
def handle_history(subpath=""):
230+
is_valid, error_response = self._check_backend_status()
231+
if not is_valid:
232+
return error_response
233+
234+
try:
235+
history_gateway = HistoryGatewayService()
236+
path = f"api/history/{subpath}" if subpath else "api/history"
237+
return history_gateway.handle_history_request(path)
238+
except Exception as e:
239+
error_msg = f"Failed to handle history request: {str(e)}"
240+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
241+
242+
return jsonify({
243+
"error": {
244+
"type": "history_operation_error",
245+
"message": error_msg
246+
}
247+
}), 500
248+
249+
def _register_userdata_block(self):
250+
"""在 prod 模式下,阻止保存 userdata 文件"""
251+
@self.bp.route("/userdata/<path:file>", methods=["POST"])
252+
def block_userdata_save(file):
253+
log("WARN", f"Attempt to save userdata blocked in prod mode: {file}")
254+
return jsonify({
255+
"error": {
256+
"type": "forbidden",
257+
"message": "Saving workflow is disabled in prod mode"
258+
}
259+
}), 403

0 commit comments

Comments
 (0)