diff --git a/.gitignore b/.gitignore index 8d87b4e..285facf 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ **/.pytest_cache **/__pycache__ **/*.pyc -**/dist \ No newline at end of file +**/dist +deploy.sh diff --git a/src/code/agent/constants.py b/src/code/agent/constants.py index d2b4f98..22cd171 100644 --- a/src/code/agent/constants.py +++ b/src/code/agent/constants.py @@ -6,6 +6,9 @@ TYPE_SD = 'sd' BACKEND_TYPE = os.getenv('BACKEND_TYPE', TYPE_COMFYUI) +# ComfyUI 模式配置:'cpu' 或 'gpu' +COMFYUI_MODE = os.getenv('COMFYUI_MODE', 'gpu').lower() + # API Mode USE_API_MODE = bool(os.getenv("AUTO_LAUNCH_SNAPSHOT_NAME")) AUTO_LAUNCH_SNAPSHOT_NAME = os.getenv("AUTO_LAUNCH_SNAPSHOT_NAME", "latest") @@ -22,7 +25,26 @@ SNAPSHOT_PATTERN = '%Y%m%d-%H%M%S' COMFYUI_DIR = os.getenv('COMFYUI_DIR', WORK_DIR + '/comfyui') COMFYUI_PROCESS_PORT = 8188 -COMFYUI_BOOT_CMD = [ +# CPU模式启动命令 +COMFYUI_CPU_BOOT_CMD = [ + f"{VENV_DIR}/bin/python", + f"{COMFYUI_DIR}/main.py", + "--cpu", + "--listen", + "0.0.0.0", + "--input-directory", + f"{MNT_DIR}/input", + "--output-directory", + f"{MNT_DIR}/output", + "--temp-directory", + f"{MNT_DIR}/output", + "--user-directory", + f"{MNT_DIR}/output", + "--disable-metadata" +] + +# GPU模式启动命令(不包含--cpu参数) +COMFYUI_GPU_BOOT_CMD = [ f"{VENV_DIR}/bin/python", f"{COMFYUI_DIR}/main.py", "--listen", @@ -37,6 +59,9 @@ f"{MNT_DIR}/output", "--disable-metadata" ] + +# 根据模式选择启动命令 +COMFYUI_BOOT_CMD = COMFYUI_CPU_BOOT_CMD if COMFYUI_MODE == 'cpu' else COMFYUI_GPU_BOOT_CMD SD_DIR = os.getenv('SD_DIR', WORK_DIR + '/stable-diffusion-webui') SD_PROCESS_PORT = 7860 SD_BOOT_CMD = [ @@ -83,6 +108,9 @@ PREWARM_PROMPT = os.getenv("PREWARM_PROMPT", "") +# GPU 函数的 URL,当 COMFYUI_MODE="cpu" 时使用 +GPU_FUNCTION_URL = os.getenv("GPU_FUNCTION_URL", "") + # 日志配置 LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") diff --git a/src/code/agent/routes/cpu_routes.py b/src/code/agent/routes/cpu_routes.py new file mode 100644 index 0000000..6140fe5 --- /dev/null +++ b/src/code/agent/routes/cpu_routes.py @@ -0,0 +1,279 @@ +import json +import os +import time +import traceback + +from flask import Blueprint, Flask, jsonify, request +from flask_sock import Sock + +from services.management_service import ManagementService, BackendStatus +from utils.logger import log +from services.gateway import ( + CpuGatewayService, + HistoryGatewayService, + get_task_queue +) +from services.process.websocket.websocket_manager import ws_manager +from services.serverlessapi.serverless_api_service import ServerlessApiService + + +class CpuRoutes: + """CPU模式路由:处理任务队列和异步转发""" + + def __init__(self): + # HTTP 路由使用 /api 前缀 + self.bp = Blueprint("cpu_routes", __name__, url_prefix="/api") + # WebSocket 路由使用根路径(保持 ComfyUI 兼容性) + self.ws_bp = Blueprint("cpu_ws", __name__) + self.service = ManagementService() # 单例模式,直接创建实例 + self.sock = Sock() + self.sock.bp = self.ws_bp # 将 WebSocket 绑定到单独的 Blueprint + self.setup_routes() + + def register(self, app: Flask): + app.register_blueprint(self.bp) + app.register_blueprint(self.ws_bp) + + def setup_routes(self): + """设置所有路由""" + self._register_websocket() + self._register_queue_handler() + self._register_prompt_handler() + self._register_serverless_run_handler() + self._register_history_handler() + # 通过环境变量控制是否禁用 userdata 保存 + # DISABLE_USERDATA_SAVE=true 时禁用 + disable_userdata = os.environ.get('DISABLE_USERDATA_SAVE', '').lower() in ('true', '1', 'yes') + if disable_userdata: + self._register_userdata_handler() + + def _check_backend_status(self): + """ + 检查后端服务状态 + + Returns: + tuple: (is_valid, error_response) + is_valid为True时error_response为None + is_valid为False时error_response为错误响应 + """ + backend_status = self.service.status + if backend_status not in (BackendStatus.RUNNING, BackendStatus.SAVING): + return False, (jsonify({ + "status": "failed", + "message": "Please start your comfyui/sd service first" + }), 500) + return True, None + + def _register_websocket(self): + @self.sock.route("/ws") + def comfyui_compatible_ws(ws): + """ + CPU函数接收ComfyUI原生的WebSocket连接 + 保持与ComfyUI前端完全兼容,但推送的是基于任务队列和状态轮询的真实状态 + + 支持重连机制: + - 客户端可通过 ?clientId=xxx 参数传递已有的 client_id + - 重连时会复用相同的 client_id,确保能接收到之前任务的状态更新 + """ + try: + # 从查询参数获取 clientId(ComfyUI 前端重连时会传递) + from flask import request as flask_request + client_id = flask_request.args.get('clientId', '') + + if client_id: + # 复用已有的 client_id(重连场景) + log("INFO", f"WebSocket reconnecting with existing client_id: {client_id}") + else: + # 生成新的 client_id(首次连接) + client_id = f"cpu_client_{int(time.time() * 1000)}" + log("INFO", f"New ComfyUI WebSocket connection with client_id: {client_id}") + + # 添加连接到管理器 + ws_manager.add_connection(ws) + + # 发送初始状态消息(模拟ComfyUI原生行为) + try: + ws.send(json.dumps({ + "type": "status", + "data": { + "sid": client_id, + "status": { + "exec_info": { + "queue_remaining": get_task_queue()._get_pending_task_count() + } + } + } + })) + except Exception as e: + log("ERROR", f"Failed to send initial status: {e}") + return + + # 设置客户端ID,用于后续关联任务 + setattr(ws, '_comfyui_client_id', client_id) + + # 将客户端ID与连接关联在WebSocketManager中 + ws_manager.associate_client_id_with_connection(ws, client_id) + + # 如果是重连,重新订阅该客户端的所有进行中的任务 + ws_manager.resubscribe_client_tasks(ws, client_id) + + # TODO 可能是多余的 + while True: + try: + message = ws.receive() + log("DEBUG", f"Received message from ComfyUI frontend: {message[:100]}...") + + except Exception as e: + error_str = str(e) + if "Connection closed" in error_str or "closed" in error_str.lower(): + log("INFO", f"Connection closed by client") + break + log("ERROR", f"Error receiving message: {e}\n{traceback.format_exc()}") + break + + except Exception as e: + log("ERROR", f"Connection error: {e}\n{traceback.format_exc()}") + finally: + try: + ws_manager.remove_connection(ws) + log("INFO", f"ComfyUI WebSocket connection closed") + except Exception as e: + log("ERROR", f"Error removing connection: {e}") + + def _register_queue_handler(self): + @self.bp.route("/queue", methods=["GET", "POST"]) + def handle_queue(): + is_valid, error_response = self._check_backend_status() + if not is_valid: + return error_response + + try: + gateway_service = CpuGatewayService() + + if request.method == "GET": + return gateway_service.handle_queue_get_request() + elif request.method == "POST": + return gateway_service.handle_queue_post_request() + else: + return jsonify({ + "error": { + "type": "method_not_allowed", + "message": f"Method {request.method} not allowed" + } + }), 405 + + except Exception as e: + error_msg = f"Failed to handle queue request: {str(e)}" + log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}") + + return jsonify({ + "error": { + "type": "queue_operation_error", + "message": error_msg + } + }), 500 + + def _register_prompt_handler(self): + @self.bp.route("/prompt", methods=["POST"]) + def handle_prompt(): + is_valid, error_response = self._check_backend_status() + if not is_valid: + return error_response + + try: + gateway_service = CpuGatewayService() + return gateway_service.handle_prompt_request_async() + except Exception as e: + error_msg = f"Failed to handle prompt request: {str(e)}" + log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}") + + return jsonify({ + "error": { + "type": "prompt_operation_error", + "message": error_msg + } + }), 500 + + def _register_serverless_run_handler(self): + @self.bp.route("/serverless/run", methods=["POST"]) + def handle_serverless_run(): + """ + 处理 /api/serverless/run 请求,支持同步和异步两种模式 + + 调用方式: + - 默认: 异步调用(与 /api/prompt 处理一致) + - Header X-Art-Invocation-Type: Sync 时: 同步调用,等待GPU返回结果 + + 异步模式: + - 将请求转发到GPU函数(异步调用) + - 返回任务ID,前端通过任务ID轮询获取结果 + - 使用任务队列跟踪任务状态 + + 同步模式: + - 将请求转发到GPU函数(同步调用) + - 等待GPU处理完成并返回结果 + - 直接返回结果给客户端 + """ + is_valid, error_response = self._check_backend_status() + if not is_valid: + return error_response + + try: + gateway_service = CpuGatewayService() + + # 检查调用类型:Header X-Art-Invocation-Type: Sync 表示同步调用 + invocation_type = request.headers.get("X-Art-Invocation-Type", "").strip() + is_sync = invocation_type.lower() == "sync" + + if is_sync: + log("DEBUG", f"Processing /serverless/run in SYNC mode (X-Art-Invocation-Type: Sync)") + return gateway_service.handle_serverless_run_sync() + else: + log("DEBUG", f"Processing /serverless/run in ASYNC mode (default)") + return gateway_service.handle_serverless_run_async() + + except Exception as e: + error_msg = f"Failed to handle serverless run request: {str(e)}" + log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}") + + return jsonify({ + "error": { + "type": "serverless_run_error", + "message": error_msg + } + }), 500 + + def _register_history_handler(self): + @self.bp.route("/history", methods=["GET", "POST", "DELETE"]) + @self.bp.route("/history/", methods=["GET", "POST", "DELETE"]) + def handle_history(subpath=""): + is_valid, error_response = self._check_backend_status() + if not is_valid: + return error_response + + try: + history_gateway = HistoryGatewayService() + path = f"api/history/{subpath}" if subpath else "api/history" + return history_gateway.handle_history_request(path) + except Exception as e: + error_msg = f"Failed to handle history request: {str(e)}" + log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}") + + return jsonify({ + "error": { + "type": "history_operation_error", + "message": error_msg + } + }), 500 + + def _register_userdata_handler(self): + """在 prod 模式下,阻止保存 userdata 文件""" + @self.bp.route("/userdata/", methods=["POST"]) + def block_userdata_save(file): + log("WARN", f"Attempt to save userdata blocked in prod mode: {file}") + return jsonify({ + "error": { + "type": "forbidden", + "message": "Saving workflow is disabled in prod mode" + } + }), 403 diff --git a/src/code/agent/routes/routes.py b/src/code/agent/routes/routes.py index b590a8d..924318d 100644 --- a/src/code/agent/routes/routes.py +++ b/src/code/agent/routes/routes.py @@ -1,18 +1,17 @@ import json -import logging -import threading +import os import traceback -import requests -import websocket -from flask import Flask, request, jsonify, Response +from flask import Flask, jsonify, request, Response from flask_sock import Sock +import requests import constants from exceptions.exceptions import CustomError -from services.management_service import BackendStatus, ManagementService, Action +from services.management_service import ManagementService, Action, BackendStatus from .management_routes import ManagementRoutes from .serverless_api_routes import ServerlessApiRoutes +from .cpu_routes import CpuRoutes from services.serverlessapi.serverless_api_service import ServerlessApiService @@ -24,16 +23,22 @@ def __init__(self): import logging log = logging.getLogger('werkzeug') log.setLevel(logging.ERROR) + def setup_routes(self): - + # 管控API management = ManagementRoutes() management.register(self.app) + # ServerlessAPI if constants.BACKEND_TYPE == constants.TYPE_COMFYUI: serverless_api = ServerlessApiRoutes() serverless_api.register(self.app) + if constants.COMFYUI_MODE == "cpu": + cpu_router = CpuRoutes() + cpu_router.register(self.app) + @self.app.route("/initialize", methods=["POST"]) def initialize(): # See FC docs for all the HTTP headers: https://www.alibabacloud.com/help/doc-detail/132044.htm#common-headers @@ -48,11 +53,16 @@ def initialize(): # API模式需要自动启动comfyui进程 # TODO 防止抛出5xx导致函数计算一直重试产生大量费用 service = ManagementService() - service.start(constants.AUTO_LAUNCH_SNAPSHOT_NAME) + + # 使用环境变量指定的snapshot,默认为latest-dev + snapshot_name = os.environ.get('AUTO_LAUNCH_SNAPSHOT_NAME', 'latest-dev') + print(f"Initializing function with ComfyUI mode: {constants.COMFYUI_MODE}, snapshot: {snapshot_name}") + service.start(snapshot_name, nodes_map={}) if ( constants.PREWARM_PROMPT and constants.BACKEND_TYPE == constants.TYPE_COMFYUI + and constants.COMFYUI_MODE == "gpu" ): try: print("prewarm models") @@ -87,53 +97,6 @@ def pre_stop(): print("FC PreStop End RequestId: " + request_id) return "OK" - @self._sock.route('/') - def proxy_ws(ws, path): - backend_status = management.service.status - if backend_status not in (BackendStatus.RUNNING, BackendStatus.SAVING): - return jsonify({ - "status": "failed", - "message": "Please start your comfyui/sd service first" - }), 500 - - # print(f"Forwarding websocket request for path: {path}") - target_url = f"ws://{constants.APP_HOST}/{path}" - - def on_message(_, message): - try: - ws.send(message) - except Exception as ex: - logging.error(f"Error sending message to client: {ex}") - - def on_error(_, error): - logging.error(f"WebSocket client error: {error}") - - def on_close(_, close_status_code, close_msg): - logging.info(f"WebSocket connection closed: {close_status_code} - {close_msg}") - - ws_client = websocket.WebSocketApp( - target_url, - on_message=on_message, - on_error=on_error, - on_close=on_close - ) - - ws_thread = threading.Thread(target=ws_client.run_forever) - ws_thread.daemon = True - ws_thread.start() - - from services.process.websocket.websocket_manager import ws_manager - try: - ws_manager.add_connection(ws) - while True: - message = ws.receive() - ws_client.send(message) - except Exception as e: - print(f"ws event occurs: {e}") - finally: - ws_manager.remove_connection(ws) - ws_client.close() - @self.app.route("/", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) @self.app.route("/", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) def proxy(path=""): diff --git a/src/code/agent/routes/serverless_api_routes.py b/src/code/agent/routes/serverless_api_routes.py index 4513cb8..1b2b118 100644 --- a/src/code/agent/routes/serverless_api_routes.py +++ b/src/code/agent/routes/serverless_api_routes.py @@ -1,8 +1,14 @@ import json import threading +import traceback from queue import Queue from traceback import print_exception +from flask import Blueprint, Flask, request, Response, copy_current_request_context +from flask_cors import cross_origin +from flask_sock import Sock +from simple_websocket import Server + import constants from utils.bool import is_true from services.serverlessapi.serverless_api_service import ( @@ -10,11 +16,6 @@ ServerlessApiService, ) -from flask_sock import Sock -from simple_websocket import Server -from flask import Blueprint, Flask, request, Response, copy_current_request_context -from flask_cors import cross_origin - class ServerlessApiRoutes: HEADER_KEY_TASK_ID_PRIMARY = "x-fc-async-task-id" @@ -22,16 +23,31 @@ class ServerlessApiRoutes: def __init__(self): self.bp = Blueprint("serverless_api", __name__, url_prefix="/api/serverless") - self.service = ServerlessApiService() self.sock = Sock() self.sock.bp = self.bp - self.setup_routes() def register(self, app: Flask): app.register_blueprint(self.bp) + def _extract_task_id(self): + """ + 从请求头中提取task_id + + Returns: + tuple: (task_id, task_id_source) + """ + async_task_id = request.headers.get(ServerlessApiRoutes.HEADER_KEY_TASK_ID_PRIMARY) + fc_request_id = request.headers.get(ServerlessApiRoutes.HEADER_KEY_TASK_ID_SECONDARY) + + if async_task_id: + return async_task_id, "x-fc-async-task-id" + elif fc_request_id: + return fc_request_id, "x-fc-request-id" + else: + return None, "none" + def setup_routes(self): @self.bp.get("/status") @@ -47,116 +63,119 @@ def get_status(): return self.service.get_status_from_store(task_id) - @self.bp.post("/run") - @cross_origin() - def run_http(): - """ - 出图接口,http 协议 - - HTTP POST `/api/serverelss/run` - - Query: - - `stream`: 流式响应(响应 ComfyUI 原生返回的状态信息,并在最后附加非流式的结果) - - `output_base64`: 最终输出结果中,将图片以 base64 形式返回 - - `output_oss`: 输出结果图片至 OSS,并在值中返回 OSS 的 path - - Header: - - `x-serverless-api-task-id`: 指定一个 task id,用于异步获取任务状态,不传输时不会持久化状态 - - Body: - JSON body,内容可参考 ComfyUI 原生 prompt 接口 - 针对如下部分进行优化 - - LoadImage 节点支持 base64 图片、http url 图片 - - KSampler seed 为 -1 时,支持自动生成随机数 + if constants.COMFYUI_MODE != "cpu": + @self.bp.post("/run") + @cross_origin() + def run_http(): + """ + 出图接口,http 协议 + + HTTP POST `/api/serverelss/run` + + Query: + - `stream`: 流式响应(响应 ComfyUI 原生返回的状态信息,并在最后附加非流式的结果) + - `output_base64`: 最终输出结果中,将图片以 base64 形式返回 + - `output_oss`: 输出结果图片至 OSS,并在值中返回 OSS 的 path + + Header: + - `x-serverless-api-task-id`: 指定一个 task id,用于异步获取任务状态,不传输时不会持久化状态 + + Body: + JSON body,内容可参考 ComfyUI 原生 prompt 接口 + 针对如下部分进行优化 + - LoadImage 节点支持 base64 图片、http url 图片 + - KSampler seed 为 -1 时,支持自动生成随机数 + + 返回值: + 输出的图片数组 + """ + body = request.get_json() + stream = is_true(request.args.get("stream")) + output_base64 = is_true(request.args.get("output_base64")) + output_oss = is_true(request.args.get("output_oss")) + + task_id, task_id_source = self._extract_task_id() + print(f"[GPU ServerlessApi] Task ID extracted: '{task_id}' from {task_id_source}") - 返回值: - 输出的图片数组 - """ - body = request.get_json() - stream = is_true(request.args.get("stream")) - output_base64 = is_true(request.args.get("output_base64")) - output_oss = is_true(request.args.get("output_oss")) - task_id = request.headers.get( - ServerlessApiRoutes.HEADER_KEY_TASK_ID_PRIMARY, - request.headers.get( - ServerlessApiRoutes.HEADER_KEY_TASK_ID_SECONDARY, "" - ), - ) - - if not stream: - try: - return self.service.run( - body, - output_base64=output_base64, - output_oss=output_oss, - task_id=task_id, - ) - except ComfyUIException as e: - print_exception(e) - return e.response(), 500 - except Exception as e: - print_exception(e) - return { - "type": "error", - "error_code": constants.ERROR_CODE.UNCLASSIFY.value, - "error_message": str(e), - }, 500 - - else: - q = Queue() - - def do_streaming(content): - """ - 将 callback 的数据推送到队列中 - """ - q.put(content) - - def output_stream(): - """ - 从队列将数据以流式返回给客户端 - """ - while True: - item = q.get(True) - yield f"data: {item if type(item) == str else json.dumps(item)}\n\n" - - if not type(item) == str: - return - - @copy_current_request_context - def run_prompt_task(): - """ - 单独线程需要执行的任务 - """ + if not stream: try: - result = self.service.run( + return self.service.run( body, output_base64=output_base64, output_oss=output_oss, - callback=do_streaming, task_id=task_id, ) except ComfyUIException as e: print_exception(e) return e.response(), 500 except Exception as e: + error_msg = f"Failed to execute prompt: {str(e)}" print_exception(e) + print(f"[ServerlessApi] {error_msg}\nStacktrace:\n{traceback.format_exc()}") return { "type": "error", "error_code": constants.ERROR_CODE.UNCLASSIFY.value, - "error_message": str(e), + "error_message": error_msg, }, 500 - # 推送最终结果 - q.put(result) + else: + q = Queue() + + def do_streaming(content): + """ + 将 callback 的数据推送到队列中 + """ + q.put(content) + + def output_stream(): + """ + 从队列将数据以流式返回给客户端 + """ + while True: + item = q.get(True) + yield f"data: {item if type(item) == str else json.dumps(item)}\n\n" + + if not type(item) == str: + return + + @copy_current_request_context + def run_prompt_task(): + """ + 单独线程需要执行的任务 + """ + try: + result = self.service.run( + body, + output_base64=output_base64, + output_oss=output_oss, + callback=do_streaming, + task_id=task_id, + ) + # 推送最终结果 + q.put(result) + except ComfyUIException as e: + print_exception(e) + error_response = e.response() + q.put(error_response) + except Exception as e: + error_msg = f"Failed to execute prompt in stream mode: {str(e)}" + print_exception(e) + print(f"[ServerlessApi] {error_msg}\nStacktrace:\n{traceback.format_exc()}") + error_response = { + "type": "error", + "error_code": constants.ERROR_CODE.UNCLASSIFY.value, + "error_message": error_msg, + } + q.put(error_response) - # 出图的流程是同步执行的,需要在单独线程执行,不阻塞 stream 的流程 - threading.Thread(target=run_prompt_task).start() + # 出图的流程是同步执行的,需要在单独线程执行,不阻塞 stream 的流程 + threading.Thread(target=run_prompt_task).start() - return Response( - output_stream(), - status=200, - content_type="text/event-stream", - ) + return Response( + output_stream(), + status=200, + content_type="text/event-stream", + ) @self.sock.route("/ws") def run_ws(ws: Server): @@ -181,12 +200,9 @@ def run_ws(ws: Server): try: output_base64 = is_true(request.args.get("output_base64")) output_oss = is_true(request.args.get("output_oss")) - task_id = request.headers.get( - ServerlessApiRoutes.HEADER_KEY_TASK_ID_PRIMARY, - request.headers.get( - ServerlessApiRoutes.HEADER_KEY_TASK_ID_SECONDARY, "" - ), - ) + + task_id, task_id_source = self._extract_task_id() + print(f"[GPU ServerlessApi WS] Task ID extracted: '{task_id}' from {task_id_source}") # 获取第一个 message 作为输入的 prompt data = ws.receive() @@ -208,21 +224,23 @@ def callback(msg): print_exception(e) try: ws.send(json.dumps(e.response())) - except: - pass + except Exception as send_error: + print(f"[GPU ServerlessApi WS] Failed to send error response: {send_error}") except Exception as e: print_exception(e) - + error_msg = f"Unexpected error in WebSocket handler: {str(e)}" + print(f"[GPU ServerlessApi WS] {error_msg}\nStacktrace:\n{traceback.format_exc()}") + try: ws.send( json.dumps( { "type": "error", "error_code": constants.ERROR_CODE.UNCLASSIFY.value, - "error_message": str(e), + "error_message": error_msg, } ) ) - except: - pass + except Exception as send_error: + print(f"[GPU ServerlessApi WS] Failed to send error response: {send_error}") return diff --git a/src/code/agent/services/gateway/__init__.py b/src/code/agent/services/gateway/__init__.py new file mode 100644 index 0000000..9bcacba --- /dev/null +++ b/src/code/agent/services/gateway/__init__.py @@ -0,0 +1,17 @@ +from .gateways import CpuGatewayService, HistoryGatewayService +from .status import StatusPoller +from .queue import ( + TaskStatus, + TaskRequest, + TaskQueue, + get_task_queue +) + +__all__ = [ + 'CpuGatewayService', + 'HistoryGatewayService', + 'StatusPoller', + 'TaskStatus', + 'TaskQueue', + 'get_task_queue' +] diff --git a/src/code/agent/services/gateway/gateways/__init__.py b/src/code/agent/services/gateway/gateways/__init__.py new file mode 100644 index 0000000..f63d051 --- /dev/null +++ b/src/code/agent/services/gateway/gateways/__init__.py @@ -0,0 +1,8 @@ +from .cpu_gateway import CpuGatewayService +from .history_gateway import HistoryGatewayService + +__all__ = [ + 'CpuGatewayService', + 'HistoryGatewayService' +] + diff --git a/src/code/agent/services/gateway/gateways/cpu_gateway.py b/src/code/agent/services/gateway/gateways/cpu_gateway.py new file mode 100644 index 0000000..5a8c553 --- /dev/null +++ b/src/code/agent/services/gateway/gateways/cpu_gateway.py @@ -0,0 +1,580 @@ +""" +CPU Gateway Service +处理 CPU 函数作为网关转发请求到 GPU 函数的逻辑 +""" +import json +import time +import requests +from datetime import datetime +from flask import request, jsonify, Response + +import constants +from utils.logger import log + + +class CpuGatewayService: + """CPU函数网关服务,负责转发请求到GPU函数""" + + def __init__(self): + self.gpu_function_url = constants.GPU_FUNCTION_URL + # 获取任务队列单例 + from services.gateway import get_task_queue + self.task_queue = get_task_queue() + + def handle_queue_get_request(self): + """ + 处理 GET /api/queue 请求 获取队列状态 + + Returns: + Flask response + """ + import time as queue_time + import threading + import queue as thread_queue + from services.gateway.queue.task_models import TaskStatus + + request_start = queue_time.time() + + log("DEBUG", f"Handling GET /api/queue request with timeout protection") + + try: + # 超时保护:如果获取任务列表超过1秒,返回空队列 + result_queue = thread_queue.Queue() + exception_queue = thread_queue.Queue() + + def fetch_tasks_with_timeout(): + try: + # 获取任务列表 + all_tasks = self.task_queue.get_all_tasks() + result_queue.put(all_tasks) + except Exception as e: + exception_queue.put(e) + + fetch_thread = threading.Thread(target=fetch_tasks_with_timeout, daemon=True) + fetch_thread.start() + fetch_thread.join(timeout=1.0) # 1秒超时 + + if not exception_queue.empty(): + raise exception_queue.get() + + if result_queue.empty(): + log("WARNING", f"Queue request timed out after 1 second, returning empty queue") + return jsonify({ + "queue_running": [], + "queue_pending": [], + "_warning": "Request timed out, showing empty queue to prevent blocking" + }) + + all_tasks = result_queue.get() + fetch_time = queue_time.time() - request_start + + except Exception as e: + log("ERROR", f"Error fetching tasks for queue request: {e}") + return jsonify({ + "queue_running": [], + "queue_pending": [], + "_error": "Failed to fetch queue status" + }) + + comfyui_queue_info = { + "queue_running": [], # 正在运行的任务列表 + "queue_pending": [] # 等待中的任务列表 + } + + convert_start = queue_time.time() + + for task in all_tasks: + # 超时保护:如果处理时间过长,停止处理 + if queue_time.time() - request_start > 2.0: # 总处理超过2秒 + log("WARNING", f"Queue processing timeout, stopping after {queue_time.time() - request_start:.2f}s") + comfyui_queue_info["_truncated"] = True + break + + # 根据任务状态分类 + if task.status == TaskStatus.PROCESSING: + # 构造ComfyUI兼容的任务信息格式(简化版) + task_info = [ + 1, # number - 任务优先级 + task.task_id, # prompt_id + task.prompt or {}, # prompt - 避免None导致序列化失败 + {"client_id": task.client_id} if task.client_id else {}, # extra_data + [] # outputs_to_execute + ] + comfyui_queue_info["queue_running"].append(task_info) + + elif task.status in [TaskStatus.PENDING, TaskStatus.SUBMITTED]: + task_info = [ + 1, # number - 任务优先级 + task.task_id, # prompt_id + task.prompt or {}, # prompt - 避免None导致序列化失败 + {"client_id": task.client_id} if task.client_id else {}, # extra_data + [] # outputs_to_execute + ] + comfyui_queue_info["queue_pending"].append(task_info) + + total_time = queue_time.time() - request_start + convert_time = queue_time.time() - convert_start + + log("DEBUG", f"Queue status: {len(comfyui_queue_info['queue_running'])} running, " + f"{len(comfyui_queue_info['queue_pending'])} pending " + f"(fetch: {fetch_time:.2f}s, convert: {convert_time:.2f}s, total: {total_time:.2f}s)") + + return jsonify(comfyui_queue_info) + + def handle_queue_post_request(self): + """ + 处理 POST /api/queue 请求 + 队列管理(清空/删除任务) + + Returns: + Flask response + """ + log("DEBUG", f"Handling POST /api/queue request") + + request_data = request.get_json() or {} + + if "clear" in request_data and request_data["clear"]: + # 清空队列 + log("INFO", f"Clearing task queue") + + cleared_count = self.task_queue.clear_queue() + log("INFO", f"Cleared {cleared_count} tasks from queue") + + return Response(status=200) + + elif "delete" in request_data: + # 删除指定任务 + # 注意:由于函数计算异步调用无法真正取消已在GPU执行的任务, + # 这里只能删除 PENDING/SUBMITTED 状态的任务, + # PROCESSING 状态的任务无法取消,会继续执行直到完成。 + to_delete = request_data.get("delete", []) + log("INFO", f"Deleting tasks: {to_delete}") + + deleted_count = 0 + failed_tasks = [] + for task_id in to_delete: + cancel_result = self.task_queue.cancel_task(task_id) + if cancel_result: + deleted_count += 1 + log("DEBUG", f"Deleted task: {task_id}") + else: + failed_tasks.append(task_id) + log("WARNING", f"Failed to delete task (not found or already running): {task_id}") + + log("INFO", f"Deleted {deleted_count} tasks from queue") + + if failed_tasks: + log("WARNING", f"{len(failed_tasks)} tasks could not be cancelled (already running on GPU): {failed_tasks}") + + return Response(status=200) + + else: + # 无效的队列操作请求 + return jsonify({ + "error": { + "type": "invalid_request_error", + "message": "Invalid queue operation. Supported operations: clear, delete" + } + }), 400 + + def _forward_to_gpu_async(self, + api_type: str, + prompt: dict, + client_id: str, + task_id_prefix: str, + task_id_headers: list, + forward_by_name: str, + trace_prefix: str): + """ + 通用的GPU异步转发逻辑 + + Args: + api_type: API类型 ('prompt' 或 'serverless') + prompt: 工作流定义 + client_id: 客户端ID + task_id_prefix: 任务ID前缀 ('prompt_' 或 'serverless_') + task_id_headers: 优先检查的header列表 + forward_by_name: 转发来源名称 + trace_prefix: trace ID前缀 + + Returns: + tuple: (task_id, response_or_error) + - task_id: 成功时返回任务ID,失败时返回None + - response_or_error: 成功时返回response对象,失败时返回(jsonify(...), status_code) + """ + import time as _ts + trace_id = f"{trace_prefix}_{int(_ts.time()*1000)}" + req_start = _ts.time() + + try: + # 检查GPU URL配置 + if not self.gpu_function_url: + log("ERROR", f"[{trace_prefix}][{trace_id}] GPU_FUNCTION_URL not configured") + return None, (500, "configuration_error", "GPU_FUNCTION_URL not configured for CPU mode") + + # 生成任务ID + task_id = None + task_id_source = "generated" + + for header_name in task_id_headers: + header_value = request.headers.get(header_name) + if header_value: + task_id = header_value + task_id_source = header_name + break + + if not task_id: + task_id = f"{task_id_prefix}{constants.INSTANCE_ID}_{int(_ts.time() * 1000)}" + + log("DEBUG", f"[{trace_prefix}][{trace_id}] Generated task_id: {task_id} from {task_id_source}, client_id={client_id}") + + # 生命周期日志:接收任务 + nodes_count = len(prompt) if isinstance(prompt, dict) else 'unknown' + log("INFO", f"[TaskLifecycle][{task_id}][RECEIVED] Task received, client_id={client_id}, nodes={nodes_count}") + + # 将任务添加到队列(仅用于跟踪) + try: + output_base64 = request.args.get("output_base64", "false").lower() == "true" + output_oss = request.args.get("output_oss", "false").lower() == "true" + + t_submit_start = _ts.time() + submitted_task_id = self.task_queue.submit_task( + prompt=prompt, + client_id=client_id, + task_id=task_id, + output_base64=output_base64, + output_oss=output_oss + ) + + log("DEBUG", f"[{trace_prefix}][{trace_id}] Task {submitted_task_id} added to queue (enqueue_cost_ms={(_ts.time()-t_submit_start)*1000:.1f})") + log("INFO", f"[TaskLifecycle][{task_id}][QUEUED] Task queued, enqueue_time_ms={(_ts.time()-t_submit_start)*1000:.1f}") + + if client_id: + self.task_queue.associate_task_with_client_id(submitted_task_id, client_id) + except Exception as e: + import traceback as _tb + log("WARNING", f"[{trace_prefix}][{trace_id}] Failed to add task to queue: {e}\n{_tb.format_exc()}") + + # 构造GPU URL和headers + gpu_url = f"{self.gpu_function_url.rstrip('/')}/api/serverless/run" + + forward_headers = { + k: v for k, v in request.headers.items() + if k.lower() not in ['host', 'content-length'] + } + forward_headers['X-Forwarded-By'] = forward_by_name + forward_headers['X-Task-ID'] = task_id + forward_headers['x-fc-async-task-id'] = task_id + forward_headers['x-fc-request-id'] = task_id + forward_headers['x-fc-trace-id'] = task_id + forward_headers['x-fc-invocation-type'] = 'Async' + + log("DEBUG", f"[{trace_prefix}][{trace_id}] Async forwarding to GPU: url={gpu_url}, task_id={task_id}") + log("INFO", f"[TaskLifecycle][{task_id}][FORWARDING] Forwarding to GPU function") + + # 转发请求到GPU + t_post_start = _ts.time() + resp = requests.post( + gpu_url, + json=prompt, + headers=forward_headers, + params=request.args, + timeout=30 + ) + t_post_cost = (_ts.time() - t_post_start) * 1000 + + log("DEBUG", f"[{trace_prefix}][{trace_id}] GPU responded: status={resp.status_code}, cost_ms={t_post_cost:.1f}") + log("INFO", f"[TaskLifecycle][{task_id}][FORWARDED] GPU forward completed, status={resp.status_code}, forward_time_ms={t_post_cost:.1f}") + + # 更新任务状态 + try: + from services.gateway.queue.task_models import TaskStatus + + if resp.status_code == 202: + self.task_queue.update_task_status(task_id, TaskStatus.PROCESSING) + log("DEBUG", f"[{trace_prefix}] Task {task_id} marked as processing") + else: + self.task_queue.update_task_status(task_id, TaskStatus.FAILED) + log("WARNING", f"[{trace_prefix}] Task {task_id} marked as failed") + except Exception as e: + import traceback as _tb + log("WARNING", f"[{trace_prefix}][{trace_id}] Failed to update task status: {e}\n{_tb.format_exc()}") + + total_cost = (_ts.time() - req_start) * 1000 + log("DEBUG", f"[{trace_prefix}][{trace_id}] Total cost: {total_cost:.1f}ms") + + # 返回response和task_id + if resp.status_code == 202: + return task_id, resp + else: + return None, (500, "async_invocation_error", f"Failed to invoke GPU function asynchronously: HTTP {resp.status_code}") + + except Exception as e: + import traceback + import requests as _rq + is_timeout = isinstance(e, _rq.exceptions.Timeout) + error_msg = f"Failed to forward request to GPU function: {str(e)} (timeout={is_timeout})" + total_cost = (_ts.time() - req_start) * 1000 if 'req_start' in locals() else -1 + log("ERROR", f"[{trace_prefix}][{trace_id}] {error_msg}\nStacktrace:\n{traceback.format_exc()}\nElapsed_ms={total_cost:.1f}") + return None, (500, "gpu_forward_error", error_msg) + + def handle_prompt_request_async(self): + """ + 处理 POST /api/prompt 请求(异步转发到GPU函数) + + Returns: + tuple: (response_data, status_code) + """ + # 获取请求数据 + request_data = request.get_json() + if not request_data: + return jsonify({ + "error": { + "type": "invalid_request_error", + "message": "Request body must be valid JSON" + } + }), 400 + + # 提取 prompt 数据和 client_id + prompt = request_data.get("prompt") + client_id = request_data.get("client_id", "") + + if not prompt: + return jsonify({ + "error": { + "type": "invalid_request_error", + "message": "Missing 'prompt' in request body" + } + }), 400 + + # 调用通用转发逻辑 + task_id, result = self._forward_to_gpu_async( + api_type="prompt", + prompt=prompt, + client_id=client_id, + task_id_prefix="prompt_", + task_id_headers=["x-fc-async-task-id", "x-fc-request-id"], + forward_by_name="CPU-Router-Async", + trace_prefix="cpu_prompt" + ) + + # 处理结果 + if task_id: + # 成功:返回ComfyUI格式 + return jsonify({ + "prompt_id": task_id, + "number": 1, + "node_errors": {} + }) + else: + # 失败:result 是 (status_code, error_type, error_message) + status_code, error_type, error_message = result + return jsonify({ + "error": { + "type": error_type, + "message": error_message + } + }), status_code + + def handle_serverless_run_async(self): + """ + 处理 POST /api/serverless/run 请求(异步转发到GPU函数) + + Returns: + tuple: (response_data, status_code) + """ + # 获取请求数据 - /serverless/run 的请求体直接是 prompt + prompt = request.get_json(force=True, silent=True) + if prompt is None: + return jsonify({ + "type": "error", + "error_code": "invalid_request_error", + "error_message": "Request body must be valid JSON containing ComfyUI workflow definition" + }), 400 + + # 检查 prompt 是否为空 + if not isinstance(prompt, dict) or len(prompt) == 0: + return jsonify({ + "type": "error", + "error_code": "invalid_request_error", + "error_message": "Prompt (workflow definition) cannot be empty. Please provide a valid ComfyUI workflow." + }), 400 + + # serverless API 不需要 client_id,GPU 端会从 WebSocket 自动获取 + client_id = "" + + # 调用通用转发逻辑 + task_id, result = self._forward_to_gpu_async( + api_type="serverless", + prompt=prompt, + client_id=client_id, + task_id_prefix="serverless_", + task_id_headers=["x-fc-async-task-id", "x-fc-request-id"], + forward_by_name="CPU-Router-Serverless-Async", + trace_prefix="cpu_serverless_async" + ) + + # 处理结果 + if task_id: + # 成功:返回Serverless格式 + return jsonify({ + "task_id": task_id, + "status": "pending" + }), 202 + else: + # 失败:result 是 (status_code, error_type, error_message) + status_code, error_type, error_message = result + return jsonify({ + "type": "error", + "error_code": error_type, + "error_message": error_message + }), status_code + + def handle_serverless_run_sync(self): + """ + 处理 POST /api/serverless/run 请求(同步转发到GPU函数) + + 与 handle_prompt_request_async 不同: + - 不使用任务队列 + - 同步等待 GPU 函数返回结果 + - 直接返回结果给客户端 + + Returns: + response: GPU函数的响应 + """ + import time as _ts + import requests + + trace_id = f"cpu_serverless_sync_{int(_ts.time()*1000)}" + req_start = _ts.time() + + log("DEBUG", f"[CPU Gateway Sync][{trace_id}] Processing sync serverless/run request - ENTER") + + try: + if not self.gpu_function_url: + log("ERROR", f"[CPU Gateway Sync][{trace_id}] GPU_FUNCTION_URL not configured") + return jsonify({ + "type": "error", + "error_code": "configuration_error", + "error_message": "GPU_FUNCTION_URL not configured for CPU mode" + }), 500 + + # 获取请求数据 - /serverless/run 的请求体直接是 prompt(ComfyUI 工作流定义) + request_body = request.get_json(force=True, silent=True) + if request_body is None: + log("ERROR", f"[CPU Gateway Sync][{trace_id}] Empty or invalid JSON body") + return jsonify({ + "type": "error", + "error_code": "invalid_request_error", + "error_message": "Request body must be valid JSON containing ComfyUI workflow definition" + }), 400 + + # 检查 prompt 是否为空对象或无效 + if not isinstance(request_body, dict) or len(request_body) == 0: + log("ERROR", f"[CPU Gateway Sync][{trace_id}] Empty prompt (workflow definition is required)") + return jsonify({ + "type": "error", + "error_code": "invalid_request_error", + "error_message": "Prompt (workflow definition) cannot be empty. Please provide a valid ComfyUI workflow." + }), 400 + + # 获取查询参数 + stream = request.args.get("stream", "false").lower() == "true" + output_base64 = request.args.get("output_base64", "false").lower() == "true" + output_oss = request.args.get("output_oss", "false").lower() == "true" + + log("DEBUG", f"[CPU Gateway Sync][{trace_id}] Request params: stream={stream}, output_base64={output_base64}, output_oss={output_oss}") + + # 构造 GPU 函数的完整 URL + gpu_url = f"{self.gpu_function_url.rstrip('/')}/api/serverless/run" + + # 转发请求头(同步调用) + forward_headers = { + k: v for k, v in request.headers.items() + if k.lower() not in ['host', 'content-length'] + } + forward_headers['X-Forwarded-By'] = 'CPU-Router-Sync' + + # 生成或提取 task_id + async_task_id = request.headers.get("x-fc-async-task-id") + fc_request_id = request.headers.get("x-fc-request-id") + x_task_id = request.headers.get("x-serverless-api-task-id") + + if x_task_id: + task_id = x_task_id + task_id_source = "x-serverless-api-task-id" + elif async_task_id: + task_id = async_task_id + task_id_source = "x-fc-async-task-id" + elif fc_request_id: + task_id = fc_request_id + task_id_source = "x-fc-request-id" + else: + task_id = f"sync_{constants.INSTANCE_ID}_{int(_ts.time() * 1000)}" + task_id_source = "generated" + + # 传递 task_id 给 GPU 函数 + if task_id: + forward_headers['x-serverless-api-task-id'] = task_id + + log("DEBUG", f"[CPU Gateway Sync][{trace_id}] Sync forwarding to GPU: url={gpu_url}, task_id={task_id} (from {task_id_source})") + + # 生命周期日志 + log("INFO", f"[TaskLifecycle][{task_id}][SYNC_FORWARDING] Synchronously forwarding to GPU, url={gpu_url}") + + # 同步转发请求到 GPU 函数 + t_post_start = _ts.time() + resp = requests.post( + gpu_url, + json=request_body, + headers=forward_headers, + params=request.args, + timeout=300 # 同步调用设置较长超时时间(5分钟) + ) + t_post_cost = (_ts.time() - t_post_start) * 1000 + + log("DEBUG", f"[CPU Gateway Sync][{trace_id}] GPU sync response received: " + f"status={resp.status_code}, cost_ms={t_post_cost:.1f}, resp_len={len(resp.content)}") + + # 生命周期日志 + log("INFO", f"[TaskLifecycle][{task_id}][SYNC_COMPLETED] Sync call completed, " + f"status={resp.status_code}, total_time_ms={t_post_cost:.1f}") + + total_cost = (_ts.time() - req_start) * 1000 + log("DEBUG", f"[CPU Gateway Sync][{trace_id}] RETURN to client: status={resp.status_code}, total_cost_ms={total_cost:.1f}") + + # 直接返回 GPU 函数的响应 + try: + response_data = resp.json() + return jsonify(response_data), resp.status_code + except: + # 如果响应不是 JSON,直接返回原始内容 + return Response( + resp.content, + status=resp.status_code, + content_type=resp.headers.get('Content-Type', 'application/octet-stream') + ) + + except requests.exceptions.Timeout: + total_cost = (_ts.time() - req_start) * 1000 + error_msg = f"GPU function request timed out after {total_cost:.1f}ms" + log("ERROR", f"[CPU Gateway Sync][{trace_id}] {error_msg}") + + return jsonify({ + "type": "error", + "error_code": "timeout_error", + "error_message": error_msg + }), 504 + + except Exception as e: + import traceback + total_cost = (_ts.time() - req_start) * 1000 if 'req_start' in locals() else -1 + error_msg = f"Failed to forward sync request to GPU function: {str(e)}" + log("ERROR", f"[CPU Gateway Sync][{trace_id}] {error_msg}\nStacktrace:\n{traceback.format_exc()}\nElapsed_ms={total_cost:.1f}") + + return jsonify({ + "type": "error", + "error_code": "gpu_forward_error", + "error_message": error_msg + }), 500 + diff --git a/src/code/agent/services/gateway/gateways/history_gateway.py b/src/code/agent/services/gateway/gateways/history_gateway.py new file mode 100644 index 0000000..c5b5b89 --- /dev/null +++ b/src/code/agent/services/gateway/gateways/history_gateway.py @@ -0,0 +1,373 @@ +""" +History Gateway Service +处理 ComfyUI history API 相关的逻辑 +""" +import os +import json +import glob +import traceback +from collections import OrderedDict +from flask import request, jsonify, Response + +import constants +from utils.logger import log + + +class HistoryGatewayService: + + def __init__(self): + # 直接访问存储路径,不需要依赖 ServerlessApiService + self.storage_path = f"{constants.MNT_DIR}/output/serverless_api" + + def handle_history_request(self, path): + """ + 处理 history 相关请求 + + Args: + path: 请求路径 + + Returns: + Flask response + """ + try: + log("DEBUG", f"Processing history request from persistent storage") + + if path == "api/history" and request.method == "GET": + return self._handle_get_all_history() + + elif path.startswith("api/history/") and request.method == "GET": + # GET /api/history/{prompt_id} - 获取特定任务的历史记录 + prompt_id = path.split("/")[-1] + return self._handle_get_history_by_id(prompt_id) + + elif path == "api/history" and request.method == "POST": + # POST /api/history - 清理历史记录 + return self._handle_clear_history() + + # 其他情况返回空结果 + return jsonify({}) + + except Exception as e: + error_msg = f"Enhanced history processing failed: {str(e)}" + log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}") + + # 出错时返回空历史记录而不是代理到 ComfyUI + return jsonify({}) + + def _handle_get_all_history(self): + """ + 处理获取所有历史记录的请求 + """ + log("DEBUG", f"Retrieving history from persistent storage") + + # 获取limit参数 + limit_param = request.args.get('limit') + limit = None + if limit_param: + try: + limit = int(limit_param) + if limit <= 0: + limit = None # 无效值时不限制 + except ValueError: + limit = None # 无效值时不限制 + + # 从持久化存储获取历史记录 + all_history = self._get_all_persisted_history(limit=limit) + + if limit: + log("DEBUG", f"Found {len(all_history)} history items (limited to {limit} most recent)") + else: + log("DEBUG", f"Found {len(all_history)} history items (no limit)") + + return jsonify(all_history) + + def _handle_get_history_by_id(self, prompt_id): + """ + 处理获取特定任务历史记录的请求 + """ + log("DEBUG", f"Retrieving history for prompt_id: {prompt_id}") + + history_data = self._get_persisted_history_by_prompt_id(prompt_id) + + if history_data: + log("DEBUG", f"Found persisted history for prompt_id: {prompt_id}") + return jsonify(history_data) + else: + log("DEBUG", f"No persisted history found for prompt_id: {prompt_id}") + return jsonify({}) + + def _handle_clear_history(self): + """ + 处理清理历史记录的请求 + + 注意:此方法目前不清理持久化存储的历史, + 只返回成功响应。如需清理持久化数据, + 需要单独实现。 + """ + request_data = request.get_json() or {} + + if request_data.get("clear"): + log("INFO", f"Clear history requested (persistent storage not affected)") + + # 持久化存储的历史不清理 + # 如果需要清理,可以删除 self.storage_path 下的文件 + + return jsonify({"status": "success", "message": "History clear acknowledged (persistent storage preserved)"}) + + return jsonify({}) + + def _get_all_persisted_history(self, limit=None): + """ + 从持久化存储获取所有历史记录,按生成时间倒序排序(最新的在前) + + Args: + limit: 限制返回的记录数量,None表示返回所有记录 + """ + try: + all_history = OrderedDict() # 使用有序字典保持排序 + history_with_timestamps = [] # 用于排序的临时列表 + + if os.path.exists(self.storage_path): + # 获取所有存储文件 + task_files = glob.glob(os.path.join(self.storage_path, "*")) + log("DEBUG", f"Found {len(task_files)} task files in storage") + + for task_file in task_files: + if os.path.isfile(task_file): + task_id = os.path.basename(task_file) + try: + status_data = self._get_status_from_file(task_id) + file_mtime = os.path.getmtime(task_file) + history_item = self._convert_status_to_history_item(status_data, task_id, file_mtime) + if history_item: + prompt_id = list(history_item.keys())[0] + + # 使用prompt中的序号作为排序依据(ComfyUI原生的排序方式) + # prompt格式:[number, prompt_id, {...}, {...}, [...]] + sort_number = None + if history_item[prompt_id].get('prompt'): + prompt_array = history_item[prompt_id]['prompt'] + if isinstance(prompt_array, list) and len(prompt_array) > 0: + sort_number = prompt_array[0] # 第一个元素是序号 + + # 降级方案:使用文件修改时间 + if sort_number is None: + sort_number = file_mtime + + timestamp = sort_number + + history_with_timestamps.append({ + 'prompt_id': prompt_id, + 'history_item': history_item, # 保持完整结构(包含prompt_id key) + 'timestamp': timestamp + }) + + except Exception as e: + log("WARNING", f"Error processing task_id {task_id}: {e}") + continue + + # 按时间戳倒序排序(最新的在前) + history_with_timestamps.sort(key=lambda x: x['timestamp'], reverse=True) + log("DEBUG", f"Sorted {len(history_with_timestamps)} history items by timestamp (newest first)") + + # 调试信息:显示排序结果的前几条 + if history_with_timestamps: + from datetime import datetime + log("DEBUG", f"Sort order preview:") + for i, item in enumerate(history_with_timestamps[:5]): # 显示前5条 + timestamp_str = datetime.fromtimestamp(item['timestamp']).strftime('%Y-%m-%d %H:%M:%S') + log("DEBUG", f" {i+1}. {item['prompt_id'][:12]}... - {timestamp_str}") + + # 应用limit限制 + if limit is not None and limit > 0: + history_with_timestamps = history_with_timestamps[:limit] + log("DEBUG", f"Limited results to {limit} most recent items") + + # 构建最终的历史记录字典,按倒序赋值序号(最新的序号最大) + total_count = len(history_with_timestamps) + + for idx, item in enumerate(history_with_timestamps): + prompt_id = item['prompt_id'] + history_item = item['history_item'] # 完整结构: {prompt_id: {prompt, outputs, status}} + + # 修改prompt数组的第一个元素为正确的序号(最新的最大) + # 前端按 queueIndex 降序排序: sort((a, b) => b.queueIndex - a.queueIndex) + sequence_number = total_count - idx + if prompt_id in history_item and 'prompt' in history_item[prompt_id]: + prompt_array = history_item[prompt_id]['prompt'] + if isinstance(prompt_array, list) and len(prompt_array) > 0: + prompt_array[0] = sequence_number # 更新序号 + + # 提取内层字典存入all_history + all_history[prompt_id] = history_item[prompt_id] + + return all_history + + except Exception as e: + log("ERROR", f"Error getting all persisted history: {e}") + return {} + + def _get_persisted_history_by_prompt_id(self, prompt_id): + """ + 根据 prompt_id 从持久化存储获取历史记录 + + 首先尝试可能的 task_id 路径,如果失败则遍历所有文件查找 + """ + try: + # 首先尝试可能的 task_id 路径(性能优化) + possible_task_ids = [ + prompt_id, + f"prompt_{constants.INSTANCE_ID}_{prompt_id}", + ] + + for task_id in possible_task_ids: + file_path = os.path.join(self.storage_path, task_id) + if os.path.exists(file_path) and os.path.isfile(file_path): + try: + status_data = self._get_status_from_file(task_id) + # 验证是否包含该 prompt_id + for status in status_data: + if (status.get("type") == "serverless_api" and + status.get("data", {}).get("prompt_id") == prompt_id): + return self._convert_status_to_history_item(status_data, task_id) + except Exception as e: + continue + + # 如果直接路径查找失败,遍历所有文件查找(降级方案) + if os.path.exists(self.storage_path): + task_files = glob.glob(os.path.join(self.storage_path, "*")) + for task_file in task_files: + if os.path.isfile(task_file): + task_id = os.path.basename(task_file) + # 跳过已经尝试过的 task_id + if task_id in possible_task_ids: + continue + try: + status_data = self._get_status_from_file(task_id) + for status in status_data: + if (status.get("type") == "serverless_api" and + status.get("data", {}).get("prompt_id") == prompt_id): + return self._convert_status_to_history_item(status_data, task_id) + except Exception as e: + continue + + return None + + except Exception as e: + log("ERROR", f"Error getting persisted history for prompt_id {prompt_id}: {e}") + return None + + def _convert_status_to_history_item(self, status_data, task_id, file_mtime=None): + """ + 将 ServerlessApiService 的状态数据转换为 ComfyUI 完全兼容的历史记录格式 + 根据 ComfyUI 原生格式: + { + "prompt_id": { + "prompt": [number, "prompt_id", {prompt_data}, {extra_data}, [outputs_to_execute]], + "outputs": {"node_id": {"output_type": [output_data, ...]}}, + "status": {"status_str": "success", "completed": true, "messages": []} + } + } + """ + try: + if not status_data: + return None + + # 查找最终结果 + final_result = None + for status in reversed(status_data): # 从最新的开始查找 + if status.get("type") == "serverless_api": + final_result = status + break + + if not final_result: + return None + + data = final_result.get("data", {}) + prompt_id = data.get("prompt_id", task_id) + results = data.get("results", []) + execution_time = data.get("execution_time") # 获取执行时间 + + # 构造 ComfyUI 兼容的输出格式 + outputs = {} + for result in results: + node_id = result.get("node_id", "unknown") + if node_id not in outputs: + outputs[node_id] = {} + + output_data = result.get("output", {}) + output_type = output_data.get("type", "images") + + if output_type not in outputs[node_id]: + outputs[node_id][output_type] = [] + + # 构造符合 ComfyUI 原生格式的输出数据 + raw_data = output_data.get("raw", {}) + comfy_output = { + "filename": raw_data.get("filename", ""), + "subfolder": raw_data.get("subfolder", ""), + "type": raw_data.get("type", "output") + } + + # 只在 debug 模式下添加增强信息,保持最大兼容性 + # 除非用户明确需要增强信息。在正常情况下不添加任何非标准字段 + outputs[node_id][output_type].append(comfy_output) + + # 构造符合 ComfyUI 格式的 prompt 数据结构 + # ComfyUI 原生格式: [number, prompt_id, {prompt_data}, {extra_data}, [outputs_to_execute]] + # 使用文件修改时间作为初始序号,后续会被正确的序号覆盖 + # 注意:保留浮点数精度,后续在排序时会使用 + initial_number = file_mtime if file_mtime else 1.0 + prompt_structure = [ + initial_number, # number - 使用时间戳作为初始值 + prompt_id, # prompt_id + {}, # prompt_data - 工作流定义(从持久化中可能不完整) + {}, # extra_data - 额外数据 + [] # outputs_to_execute - 要执行的输出节点 + ] + + # 构造 status 对象,包含执行时间(如果有) + status_obj = { + "status_str": "success", + "completed": True, + "messages": [] + } + + # 添加执行时间到 status 中(与原生 ComfyUI 兼容) + if execution_time is not None: + status_obj["execution_time"] = execution_time + + return { + prompt_id: { + "prompt": prompt_structure, + "outputs": outputs, + "status": status_obj + } + } + + except Exception as e: + log("ERROR", f"Error converting status to history: {e}") + return None + + def _get_status_from_file(self, task_id: str): + """ + 从文件读取任务状态 + + Args: + task_id: 任务ID + + Returns: + list: 状态历史列表 + """ + try: + file_path = os.path.join(self.storage_path, task_id) + if not os.path.exists(file_path): + return [] + + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + return [json.loads(line) for line in content.split("\n") if line.strip()] + except Exception as e: + log("ERROR", f"Error reading status from file {task_id}: {e}") + return [] + diff --git a/src/code/agent/services/gateway/queue/__init__.py b/src/code/agent/services/gateway/queue/__init__.py new file mode 100644 index 0000000..a9f6062 --- /dev/null +++ b/src/code/agent/services/gateway/queue/__init__.py @@ -0,0 +1,10 @@ +from .task_models import TaskStatus, TaskRequest +from .task_queue import TaskQueue, get_task_queue + +__all__ = [ + 'TaskStatus', + 'TaskRequest', + 'TaskQueue', + 'get_task_queue' +] + diff --git a/src/code/agent/services/gateway/queue/task_models.py b/src/code/agent/services/gateway/queue/task_models.py new file mode 100644 index 0000000..3bd1ebe --- /dev/null +++ b/src/code/agent/services/gateway/queue/task_models.py @@ -0,0 +1,99 @@ +""" +任务模型定义 +定义任务的状态和数据结构 +""" +import time +from typing import Optional, Callable +from dataclasses import dataclass, field +from enum import Enum + + +class TaskStatus(Enum): + """任务状态枚举""" + PENDING = "pending" # 等待执行 + SUBMITTED = "submitted" # 已提交 + PROCESSING = "processing" # 正在执行 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 执行失败 + + def is_terminal(self) -> bool: + """判断是否为终态""" + return self in (TaskStatus.COMPLETED, TaskStatus.FAILED) + + def is_active(self) -> bool: + """判断是否为活跃状态(未完成)""" + return not self.is_terminal() + + +@dataclass +class TaskRequest: + """任务请求数据模型""" + task_id: str + client_id: str + prompt: dict + output_base64: bool = False + output_oss: bool = False + callback: Optional[Callable] = None + status: TaskStatus = TaskStatus.PENDING + + # 时间戳 + create_at: float = field(default_factory=time.time) + started_at: Optional[float] = None + completed_at: Optional[float] = None + + def update_status(self, new_status: TaskStatus) -> bool: + """ + 更新任务状态,包含状态转换校验 + + Returns: + bool: 状态是否成功更新 + """ + # 终态不允许再次更新 + if self.status.is_terminal(): + return False + + # 状态转换逻辑 + valid_transitions = { + TaskStatus.PENDING: {TaskStatus.SUBMITTED, TaskStatus.PROCESSING, TaskStatus.FAILED}, + TaskStatus.SUBMITTED: {TaskStatus.PROCESSING, TaskStatus.FAILED}, + TaskStatus.PROCESSING: {TaskStatus.COMPLETED, TaskStatus.FAILED}, + } + + if new_status not in valid_transitions.get(self.status, set()): + return False + + # 更新状态和时间戳 + self.status = new_status + + if new_status == TaskStatus.PROCESSING and self.started_at is None: + self.started_at = time.time() + elif new_status.is_terminal() and self.completed_at is None: + self.completed_at = time.time() + + return True + + def get_elapsed_time(self) -> Optional[float]: + """获取任务执行耗时(秒)""" + if not self.started_at: + return None + + end_time = self.completed_at or time.time() + return end_time - self.started_at + + def get_age(self) -> float: + """获取任务创建以来的时长(秒)""" + return time.time() - self.create_at + + def to_dict(self) -> dict: + """转换为字典格式""" + return { + 'task_id': self.task_id, + 'client_id': self.client_id, + 'status': self.status.value, + 'create_at': self.create_at, + 'started_at': self.started_at, + 'completed_at': self.completed_at, + 'elapsed_time': self.get_elapsed_time(), + 'age': self.get_age() + } + diff --git a/src/code/agent/services/gateway/queue/task_queue.py b/src/code/agent/services/gateway/queue/task_queue.py new file mode 100644 index 0000000..22f49de --- /dev/null +++ b/src/code/agent/services/gateway/queue/task_queue.py @@ -0,0 +1,819 @@ +import time +import threading +import uuid +from typing import Dict, Optional, Callable, List + +import constants +from .task_models import TaskStatus, TaskRequest +from utils.logger import log + + +class TaskQueue: + """任务队列 - 包含任务跟踪、状态轮询、状态广播功能 + + 注意:任务实际执行在GPU函数端,CPU端只负责跟踪和监控任务状态 + """ + + _instance = None + _instance_lock = threading.Lock() # 类级别的锁,用于单例创建 + + def __new__(cls, max_active_tasks: int = 20, max_total_tasks: int = 100): + if cls._instance is None: + with cls._instance_lock: + if cls._instance is None: + instance = super().__new__(cls) + instance._initialized = False + instance._max_active_tasks = max_active_tasks + instance._max_total_tasks = max_total_tasks + cls._instance = instance + return cls._instance + + def __init__(self, max_active_tasks: int = 20, max_total_tasks: int = 100): + # 只初始化一次 + if self._initialized: + return + + self._tasks: Dict[str, TaskRequest] = {} # 存储任务信息 + self._lock = threading.Lock() # 实例级别的锁,保护任务字典 + + self._status_pollers = {} # 存储每个任务的状态轮询器 + self._poller_lock = threading.Lock() + + self._queue_status_broadcast_enabled = True + self._queue_status_broadcast_interval = 1.0 + self._queue_status_broadcast_thread = None + self._last_queue_status = None + + self._initialized = True + log("INFO", f"TaskQueue initialized (max_active_tasks={self._max_active_tasks}, max_total_tasks={self._max_total_tasks})") + + def start(self): + # 启动队列状态广播 + self._start_queue_status_broadcast() + log("INFO", "TaskQueue started") + + def stop(self): + # 停止队列状态广播 + self._stop_queue_status_broadcast() + + # 停止所有状态轮询器 + self._stop_all_pollers() + log("INFO", "TaskQueue stopped") + + def submit_task(self, + prompt: dict, + client_id: str, + task_id: Optional[str] = None, + output_base64: bool = False, + output_oss: bool = False, + callback: Optional[Callable] = None) -> str: + if task_id is None: + task_id = str(uuid.uuid4()) + + task_request = TaskRequest( + task_id=task_id, + client_id=client_id, + prompt=prompt, + output_base64=output_base64, + output_oss=output_oss, + callback=callback + ) + + with self._lock: + # 检查活跃任务数量(只计算 pending、submitted 和 processing 状态的任务) + active_tasks = sum( + 1 for task in self._tasks.values() + if task.status in [TaskStatus.PENDING, TaskStatus.SUBMITTED, TaskStatus.PROCESSING] + ) + + # 检查总任务数量 + total_tasks = len(self._tasks) + + if active_tasks >= self._max_active_tasks: + raise RuntimeError( + f"Task queue is full. Active tasks: {active_tasks}/{self._max_active_tasks}. " + f"Please wait for some tasks to complete before submitting new ones." + ) + + if total_tasks >= self._max_total_tasks: + # 尝试清理已完成的任务 + self._cleanup_completed_tasks_unsafe() + + # 重新检查 + if len(self._tasks) >= self._max_total_tasks: + raise RuntimeError( + f"Task storage is full. Total tasks: {len(self._tasks)}/{self._max_total_tasks}. " + f"Some completed tasks could not be cleaned up. Please try again later." + ) + + self._tasks[task_id] = task_request + + log("INFO", f"Task {task_id} registered for tracking (client: {client_id})") + + # 启动状态轮询(监控GPU函数端的状态) + self._start_status_polling(task_id) + + return task_id + + def get_task(self, task_id: str) -> Optional[TaskRequest]: + with self._lock: + return self._tasks.get(task_id) + + def get_all_tasks(self) -> List[TaskRequest]: + with self._lock: + # 快速创建浅拷贝,减少持锁时间 + return list(self._tasks.values()) + + def clear_queue(self) -> int: + """清空等待中的任务(仅清理 pending 和 submitted 状态)""" + cleared_count = 0 + + with self._lock: + task_ids_to_remove = [] + for task_id, task in self._tasks.items(): + if task.status in [TaskStatus.PENDING, TaskStatus.SUBMITTED]: + task_ids_to_remove.append(task_id) + + for task_id in task_ids_to_remove: + self._tasks.pop(task_id, None) + cleared_count += 1 + + return cleared_count + + def update_task_status(self, task_id: str, new_status: TaskStatus) -> bool: + """线程安全地更新任务状态 + + Args: + task_id: 任务ID + new_status: 新状态 + + Returns: + bool: 是否成功更新 + """ + with self._lock: + task = self._tasks.get(task_id) + if not task: + log("WARNING", f"Cannot update status: task {task_id} not found") + return False + + old_status = task.status + + success = task.update_status(new_status) + + if success: + log("INFO", f"Task {task_id} status updated: {old_status.value} -> {new_status.value}") + else: + log("WARNING", f"Failed to update task {task_id} status from {old_status.value} to {new_status.value} (invalid transition)") + + return success + + def cancel_task(self, task_id: str) -> bool: + # 先停止状态轮询 + self._stop_status_polling(task_id) + + # 取消任务 + with self._lock: + task = self._tasks.get(task_id) + if not task: + return False + + # 只能取消未开始执行的任务 + if task.status not in [TaskStatus.PENDING, TaskStatus.SUBMITTED]: + return False + + # 从内存中删除 + self._tasks.pop(task_id, None) + + return True + + def _cleanup_completed_tasks_unsafe(self): + """清理已完成的任务(必须在持有 _lock 的情况下调用) + + 注意:此方法不加锁,调用者必须已经持有 self._lock + """ + now = time.time() + to_remove = [] + + for task_id, task in self._tasks.items(): + if task.status.is_terminal() and task.completed_at: + # 清理完成超过 30 秒的任务 + if now - task.completed_at > 30: + to_remove.append(task_id) + + for task_id in to_remove: + self._tasks.pop(task_id, None) + log("DEBUG", f"Cleaned up old task {task_id} during queue full check") + + return len(to_remove) + + def _stop_all_pollers(self): + with self._poller_lock: + poller_ids = list(self._status_pollers.keys()) + for task_id in poller_ids: + self._stop_status_polling(task_id, async_stop=False) + + def _start_status_polling(self, task_id: str): + """ + 为任务启动状态轮询 + """ + try: + from services.gateway.status import StatusPoller + + with self._poller_lock: + if task_id in self._status_pollers: + log("DEBUG", f"Task {task_id} already has status poller") + return + + poller = StatusPoller( + task_id=task_id, + poll_interval=0.5 + ) + + # 设置状态更新回调 + poller.set_status_callback(self._on_status_update) + + # 启动轮询 + poller.start() + + # 存储轮询器 + self._status_pollers[task_id] = poller + + log("INFO", f"Started status polling for task {task_id}") + + except Exception as e: + log("ERROR", f"Failed to start status polling for task {task_id}: {e}") + + def _stop_status_polling(self, task_id: str, async_stop: bool = False): + """停止任务的状态轮询 + + Args: + task_id: 任务ID + async_stop: 是否异步停止(用于避免死锁) + """ + with self._poller_lock: + if task_id not in self._status_pollers: + return + + poller = self._status_pollers.pop(task_id) + + if async_stop: + # 在状态回调中异步停止,避免死锁 + def stop_poller_with_timeout(): + try: + # 尝试正常停止 + if hasattr(poller, 'stop'): + poller.stop() + log("DEBUG", f"Async stopped status polling for task {task_id}") + except Exception as e: + log("ERROR", f"Error async stopping poller for task {task_id}: {e}") + # 如果正常停止失败,尝试强制清理 + try: + if hasattr(poller, 'force_stop'): + poller.force_stop() + elif hasattr(poller, '_stop_event'): + # 强制设置停止事件 + poller._stop_event.set() + except Exception as force_error: + log("ERROR", f"Failed to force stop poller for task {task_id}: {force_error}") + + # 使用新线程异步停止,设置超时保护 + stop_thread = threading.Thread(target=stop_poller_with_timeout, daemon=True) + stop_thread.start() + # 不等待线程完成,让它在后台运行 + else: + # 直接停止(用于除状态回调外的其他地方) + try: + if hasattr(poller, 'stop'): + poller.stop() + log("INFO", f"Stopped status polling for task {task_id}") + except Exception as e: + log("ERROR", f"Error stopping poller for task {task_id}: {e}") + + def _on_status_update(self, task_id: str, status_data: dict): + """状态更新回调函数(线程安全版本 - 修复竞态条件) + + 确保状态检查和更新是原子性的,避免在检查状态和更新状态之间 + 任务被删除或状态被修改导致的竞态条件。 + + 延迟链路: + 1. 状态轮询器检测到状态 (在 poller 中) + 2. -> 调用此回调函数 (此处开始计时) + 3. -> 在锁内更新内存任务状态 + 4. -> 提交异步广播到线程池 (广播提交耗时) + 5. -> 在线程中进行 WS 发送 (广播实际耗时) + """ + import time as callback_time + callback_start = callback_time.time() + + try: + # 根据状态类型处理更新 + status_type = status_data.get('type', '') + log("DEBUG", f"Processing status update for task {task_id}, type: {status_type}") + + # 原子性地检查任务状态并更新(在锁内完成所有检查和更新) + if status_type == 'serverless_api': + # 任务完成状态:原子性地检查并更新 + with self._lock: + task = self._tasks.get(task_id) + if not task: + log("WARNING", f"Task {task_id} not found in queue (completion update ignored)") + return + + # 在锁内检查状态并更新 + if task.status != TaskStatus.COMPLETED: + if task.update_status(TaskStatus.COMPLETED): + # 状态更新成功 + log("INFO", f"Task {task_id} marked as completed (status: {task.status.value})") + else: + # 状态更新失败(可能是无效的状态转换) + log("WARNING", f"Failed to update task {task_id} to COMPLETED (current status: {task.status.value}, invalid transition)") + return + else: + # 任务已经是完成状态,可能是重复的状态更新 + log("DEBUG", f"Task {task_id} already completed, ignoring duplicate update") + return + + # 在锁外执行异步操作,避免持锁时间过长 + self._async_handle_task_completion_after_update(task_id, status_data) + + elif status_type == 'error' or status_type == 'execution_error': + # 任务失败状态:原子性地检查并更新 + with self._lock: + task = self._tasks.get(task_id) + if not task: + log("WARNING", f"Task {task_id} not found in queue (failure update ignored)") + return + + if task.status != TaskStatus.FAILED: + if task.update_status(TaskStatus.FAILED): + log("INFO", f"Task {task_id} marked as failed (status: {task.status.value})") + else: + log("WARNING", f"Failed to update task {task_id} to FAILED (current status: {task.status.value}, invalid transition)") + return + else: + log("DEBUG", f"Task {task_id} already failed, ignoring duplicate update") + return + + # 在锁外执行异步操作 + self._async_handle_task_failure_after_update(task_id, status_data) + + elif status_type == 'executing': + # 任务执行中状态:原子性地检查并更新 + node = status_data.get('data', {}).get('node') + if node: + with self._lock: + task = self._tasks.get(task_id) + if not task: + # 任务不存在,忽略执行中状态更新 + return + + if task.status == TaskStatus.PENDING: + if task.update_status(TaskStatus.PROCESSING): + log("INFO", f"Task {task_id} marked as processing via status polling") + + # 生命周期日志:任务执行中 + log("INFO", f"[TaskLifecycle][{task_id}][EXECUTING] Task execution started on GPU") + + # 广播状态(不依赖任务状态,可以在锁外执行) + broadcast_start = callback_time.time() + self._async_broadcast_status(task_id, status_data) + broadcast_queued_time = callback_time.time() - broadcast_start + + callback_time_elapsed = callback_time.time() - callback_start + if callback_time_elapsed > 0.1: + log("DEBUG", f"Status update callback took {callback_time_elapsed:.3f}s for task {task_id} " + f"(broadcast queue time: {broadcast_queued_time*1000:.2f}ms, status_type: {status_type})") + + except Exception as e: + log("ERROR", f"Error in status update callback for task {task_id}: {e}") + import traceback + traceback.print_exc() + + def _start_queue_status_broadcast(self): + """启动队列状态广播""" + if not self._queue_status_broadcast_enabled: + return + + if self._queue_status_broadcast_thread and self._queue_status_broadcast_thread.is_alive(): + return + + self._queue_status_broadcast_thread = threading.Thread( + target=self._queue_status_broadcast_loop, + daemon=True, + name="queue-status-broadcast" + ) + self._queue_status_broadcast_thread.start() + log("INFO", "Started queue status broadcast") + + def _stop_queue_status_broadcast(self): + """停止队列状态广播""" + self._queue_status_broadcast_enabled = False + if self._queue_status_broadcast_thread and self._queue_status_broadcast_thread.is_alive(): + self._queue_status_broadcast_thread.join(timeout=5.0) + log("INFO", "Stopped queue status broadcast") + + def _queue_status_broadcast_loop(self): + """队列状态广播循环""" + log("INFO", "Queue status broadcast loop started") + + while self._queue_status_broadcast_enabled: + try: + # 获取当前队列状态 + current_status = self._get_queue_status_for_broadcast() + + # 检查状态是否有变化 + if self._has_queue_status_changed(current_status): + # 广播队列状态 + self._broadcast_queue_status(current_status) + self._last_queue_status = current_status + + # 等待下次广播 + time.sleep(self._queue_status_broadcast_interval) + + except Exception as e: + log("ERROR", f"Error in queue status broadcast loop: {e}") + time.sleep(self._queue_status_broadcast_interval) + + log("INFO", "Queue status broadcast loop stopped") + + def _get_queue_status_for_broadcast(self) -> dict: + """获取用于广播的队列状态""" + try: + # 快速获取队列状态 + with self._lock: + total_tasks = len(self._tasks) + pending_tasks = sum(1 for task in self._tasks.values() + if task.status in [TaskStatus.PENDING, TaskStatus.SUBMITTED]) + processing_tasks = sum(1 for task in self._tasks.values() + if task.status == TaskStatus.PROCESSING) + completed_tasks = sum(1 for task in self._tasks.values() + if task.status == TaskStatus.COMPLETED) + failed_tasks = sum(1 for task in self._tasks.values() + if task.status == TaskStatus.FAILED) + + # 获取轮询器数量 + with self._poller_lock: + active_pollers = len(self._status_pollers) + + # queue_size 表示等待处理的任务数(pending + processing) + queue_size = pending_tasks + processing_tasks + + return { + 'timestamp': time.time(), + 'queue_size': queue_size, # 等待和处理中的任务总数 + 'total_tasks': total_tasks, + 'pending_tasks': pending_tasks, + 'processing_tasks': processing_tasks, + 'completed_tasks': completed_tasks, + 'failed_tasks': failed_tasks, + 'active_pollers': active_pollers + } + + except Exception as e: + log("ERROR", f"Error getting queue status: {e}") + return { + 'timestamp': time.time(), + 'queue_size': 0, + 'total_tasks': 0, + 'pending_tasks': 0, + 'processing_tasks': 0, + 'completed_tasks': 0, + 'failed_tasks': 0, + 'active_pollers': 0 + } + + def _has_queue_status_changed(self, current_status: dict) -> bool: + """检查队列状态是否有变化""" + if not self._last_queue_status: + return True + + # 比较关键字段 + key_fields = ['queue_size', 'pending_tasks', 'processing_tasks', 'completed_tasks', 'failed_tasks'] + + for field in key_fields: + if current_status.get(field) != self._last_queue_status.get(field): + return True + + return False + + def _broadcast_queue_status(self, queue_status: dict): + """广播队列状态给所有WebSocket连接(异步非阻塞)""" + try: + # 只在CPU模式下广播 + if constants.COMFYUI_MODE != "cpu": + return + + # 构建ComfyUI格式的队列状态消息 + # ComfyUI的queue_remaining应该包含所有等待和正在处理的任务 + queue_remaining = queue_status['pending_tasks'] + queue_status['processing_tasks'] + comfyui_message = { + "type": "status", + "data": { + "status": { + "exec_info": { + "queue_remaining": queue_remaining + } + } + } + } + + # 广播给所有WebSocket连接 + from services.process.websocket.websocket_manager import ws_manager + + # 获取所有活跃连接 + active_connections = len(ws_manager.active_connections) + if active_connections > 0: + # 使用特殊的广播ID "queue_status" 来标识队列状态广播 + # 异步广播,不阻塞队列状态线程 + sent_count = ws_manager.broadcast_comfyui_message_async("queue_status", comfyui_message) + + log("DEBUG", f"Queued queue status broadcast to {sent_count} connections: " + f"pending={queue_status['pending_tasks']}, " + f"processing={queue_status['processing_tasks']}, " + f"total={queue_status['total_tasks']}") + + except Exception as e: + log("ERROR", f"Error broadcasting queue status: {e}") + + def _async_handle_task_completion_after_update(self, task_id: str, status_data: dict): + """异步处理任务完成(状态已更新版本) + + 在状态已经原子性地更新为 COMPLETED 后,执行后续操作: + - 记录生命周期日志 + - 停止状态轮询器 + - 调度任务清理 + + 注意:此方法假设状态已经在 _on_status_update 中原子性地更新了 + """ + def handle_completion(): + try: + # 验证任务仍然存在(可能在更新后很快被清理) + with self._lock: + task = self._tasks.get(task_id) + if not task: + log("WARNING", f"Task {task_id} was removed before completion handling") + return + + # 验证状态确实是 COMPLETED + if task.status != TaskStatus.COMPLETED: + log("WARNING", f"Task {task_id} status is {task.status.value}, expected COMPLETED") + return + + # 生命周期日志:任务完成 + data = status_data.get('data', {}) + output_count = len(data.get('results', [])) + log("INFO", f"[TaskLifecycle][{task_id}][COMPLETED] Task completed, outputs={output_count}") + + # 异步停止轮询器 + self._stop_status_polling(task_id, async_stop=True) + + # 延迟清理任务 + self._schedule_task_cleanup(task_id, "completed") + + except Exception as e: + log("ERROR", f"Error handling task completion for {task_id}: {e}") + + # 在单独线程中处理 + threading.Thread(target=handle_completion, daemon=True).start() + + def _async_handle_task_failure_after_update(self, task_id: str, status_data: dict): + """异步处理任务失败(状态已更新版本) + + 在状态已经原子性地更新为 FAILED 后,执行后续操作: + - 记录生命周期日志 + - 停止状态轮询器 + - 调度任务清理 + + 注意:此方法假设状态已经在 _on_status_update 中原子性地更新了 + """ + def handle_failure(): + try: + # 验证任务仍然存在(可能在更新后很快被清理) + with self._lock: + task = self._tasks.get(task_id) + if not task: + log("WARNING", f"Task {task_id} was removed before failure handling") + return + + # 验证状态确实是 FAILED + if task.status != TaskStatus.FAILED: + log("WARNING", f"Task {task_id} status is {task.status.value}, expected FAILED") + return + + # 生命周期日志:任务失败 + error_msg = status_data.get('data', {}).get('exception_message', 'Unknown error') + log("INFO", f"[TaskLifecycle][{task_id}][FAILED] Task failed, error={error_msg}") + + # 异步停止轮询器 + self._stop_status_polling(task_id, async_stop=True) + + # 延迟清理任务 + self._schedule_task_cleanup(task_id, "failed") + + except Exception as e: + log("ERROR", f"Error handling task failure for {task_id}: {e}") + + # 在单独线程中处理 + threading.Thread(target=handle_failure, daemon=True).start() + + def _async_broadcast_status(self, task_id: str, status_data: dict): + """异步广播状态更新""" + def broadcast(): + try: + # WebSocket广播 + self._broadcast_task_status_via_websocket(task_id, status_data) + except Exception as e: + log("ERROR", f"Error broadcasting status for task {task_id}: {e}") + + # 在单独线程中处理WebSocket广播 + threading.Thread(target=broadcast, daemon=True).start() + + def _schedule_task_cleanup(self, task_id: str, status: str, delay: float = 5.0): + """调度任务清理 + + Args: + task_id: 任务ID + status: 任务状态描述 + delay: 延迟清理时间(秒),默认5秒让前端获取最终状态 + """ + def cleanup(): + try: + time.sleep(delay) + + with self._lock: + task = self._tasks.pop(task_id, None) + if task: + log("DEBUG", f"Cleaned up {status} task {task_id} from queue") + else: + log("DEBUG", f"Task {task_id} already cleaned up") + + except Exception as e: + log("ERROR", f"Error cleaning up {status} task {task_id}: {e}") + # 确保任务最终被清理,即使出错 + try: + with self._lock: + self._tasks.pop(task_id, None) + except: + pass + + # 在单独线程中处理清理 + cleanup_thread = threading.Thread(target=cleanup, daemon=True, name=f"cleanup-{task_id}") + cleanup_thread.start() + + def _broadcast_task_status_via_websocket(self, task_id: str, status_data: dict): + """ + 通过WebSocket广播任务状态给订阅者,使用ComfyUI原生消息格式 + + Args: + task_id: 任务ID + status_data: 从 GPU函数获取的原始状态数据 + """ + try: + # 只在CPU模式下广播 + if constants.COMFYUI_MODE != "cpu": + return + + from services.process.websocket.websocket_manager import ws_manager + + # 转换为ComfyUI格式 + comfyui_message = self._convert_to_comfyui_message_format(task_id, status_data) + if not comfyui_message: + return + + # 广播消息 + if isinstance(comfyui_message, list): + for msg in comfyui_message: + sent_count = ws_manager.broadcast_comfyui_message(task_id, msg) + if sent_count > 0: + log("DEBUG", f"Broadcasted {msg.get('type')} message for task {task_id} to {sent_count} connections") + else: + sent_count = ws_manager.broadcast_comfyui_message(task_id, comfyui_message) + if sent_count > 0: + log("DEBUG", f"Broadcasted {comfyui_message.get('type')} message for task {task_id} to {sent_count} connections") + + except Exception as e: + log("ERROR", f"Error broadcasting task status via WebSocket: {e}") + + def _convert_to_comfyui_message_format(self, task_id: str, status_data: dict): + """将状态数据转换为ComfyUI消息格式""" + try: + status_type = status_data.get('type', '') + data = status_data.get('data', {}) + + if status_type == 'serverless_api': + # 最终结果 + return { + "type": "serverless_api", + "data": data + } + + elif status_type == 'executing': + # 执行中状态 + node = data.get('node') + if node: + return { + "type": "executing", + "data": { + "node": node, + "prompt_id": data.get('prompt_id', task_id) + } + } + + elif status_type == 'progress': + # 进度更新 + return { + "type": "progress", + "data": { + "value": data.get('value', 0), + "max": data.get('max', 100), + "prompt_id": data.get('prompt_id', task_id) + } + } + + elif status_type == 'status': + # 状态更新 + return { + "type": "status", + "data": { + "status": { + "exec_info": { + "queue_remaining": self._get_pending_task_count() + } + } + } + } + + elif status_type == 'error' or status_type == 'execution_error': + # 执行错误 + return { + "type": "execution_error", + "data": { + "prompt_id": data.get('prompt_id', task_id), + "node_id": data.get('node_id'), + "exception_message": data.get('exception_message', str(data.get('message', 'Unknown error'))), + "exception_type": data.get('exception_type', 'RuntimeError'), + "traceback": data.get('traceback', []) + } + } + + # 其他类型的消息直接返回 + return status_data + + except Exception as e: + log("ERROR", f"Error converting to ComfyUI message format: {e}") + return None + + def _get_pending_task_count(self) -> int: + """获取待处理任务数量(PENDING、SUBMITTED和PROCESSING状态)""" + with self._lock: + return sum( + 1 for task in self._tasks.values() + if task.status in [TaskStatus.PENDING, TaskStatus.SUBMITTED, TaskStatus.PROCESSING] + ) + + def associate_task_with_client_id(self, task_id: str, client_id: str): + """ + 将任务与指定的ComfyUI客户端关联,使前端能够接收到任务状态更新 + + Args: + task_id: 任务ID + client_id: ComfyUI客户端ID + """ + log("DEBUG", f"Attempting to associate task {task_id} with client_id {client_id}") + log("DEBUG", f"Current COMFYUI_MODE: {constants.COMFYUI_MODE}") + + # 只在CPU模式下才需要关联 + if constants.COMFYUI_MODE != "cpu": + log("DEBUG", f"Skipping client association - not in CPU mode") + return + + try: + from services.process.websocket.websocket_manager import ws_manager + + # 将任务与客户端关联 + associated_count = ws_manager.associate_task_with_client_id(task_id, client_id) + + if associated_count > 0: + log("INFO", f"Task {task_id} successfully associated with ComfyUI client {client_id} ({associated_count} connections)") + else: + log("WARNING", f"Failed to associate task {task_id} with client {client_id} - no connections found") + + except Exception as e: + import traceback + log("ERROR", f"Failed to associate task {task_id} with client {client_id}: {e}") + log("ERROR", f"Traceback: {traceback.format_exc()}") + +# 全局任务队列实例 - 延迟初始化 +_task_queue = None + +def get_task_queue() -> TaskQueue: + """获取全局任务队列实例""" + global _task_queue + if _task_queue is None: + log("DEBUG", f"Creating TaskQueue instance...") + _task_queue = TaskQueue() + # 启动队列状态广播线程 + _task_queue.start() + log("DEBUG", f"TaskQueue instance created and started") + return _task_queue + diff --git a/src/code/agent/services/gateway/status/__init__.py b/src/code/agent/services/gateway/status/__init__.py new file mode 100644 index 0000000..eda18ae --- /dev/null +++ b/src/code/agent/services/gateway/status/__init__.py @@ -0,0 +1,6 @@ +from .poller import StatusPoller + +__all__ = [ + 'StatusPoller', +] + diff --git a/src/code/agent/services/gateway/status/poller.py b/src/code/agent/services/gateway/status/poller.py new file mode 100644 index 0000000..b7bfd50 --- /dev/null +++ b/src/code/agent/services/gateway/status/poller.py @@ -0,0 +1,155 @@ +""" +Status Poller +处理状态轮询相关的逻辑 +""" +import json +import time +import threading +from typing import Optional, Callable, Any + +from utils.logger import log + + +class StatusPoller: + """ + 状态轮询器 + 用于定期从存储中轮询工作流执行状态 + """ + + def __init__(self, task_id: str, poll_interval: float = 1.0): + """ + 初始化状态轮询器 + + Args: + task_id: 要轮询的任务ID + poll_interval: 轮询间隔,单位秒,默认1.0秒 + """ + self.task_id = task_id + self.poll_interval = poll_interval + self.is_running = False + self.thread: Optional[threading.Thread] = None + self.on_status_update: Optional[Callable[[str, Any], None]] = None + self.last_status_count = 0 # 记录已处理的状态数量 + + def set_status_callback(self, callback: Callable[[str, Any], None]): + """ + 设置状态更新回调函数 + + Args: + callback: 回调函数,参数为 (task_id, status_data) + """ + self.on_status_update = callback + + def start(self): + """启动轮询线程""" + if self.is_running: + log("DEBUG", f"Task {self.task_id} poller is already running") + return + + self.is_running = True + self.thread = threading.Thread(target=self._poll_loop, daemon=True) + self.thread.start() + log("INFO", f"Started polling for task {self.task_id}") + + def stop(self): + """停止轮询线程""" + if not self.is_running: + return + + self.is_running = False + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=5.0) + log("INFO", f"Stopped polling for task {self.task_id}") + + def _poll_loop(self): + """轮询循环逻辑(通过serverless_api接口查询状态)""" + # 是否已找到任务完成状态 + task_completed = False + + while self.is_running and not task_completed: + try: + # 时间监控:轮询周期开始 + poll_start = time.time() + + # 通过serverless_api接口读取状态 + read_start = time.time() + all_statuses = self._get_status_from_api() + read_cost = (time.time() - read_start) * 1000 + + # 只处理新增的状态 + new_statuses = all_statuses[self.last_status_count:] + + if new_statuses: + log("INFO", f"[Perf][{self.task_id}] File read cost: {read_cost:.1f}ms, got {len(new_statuses)} new statuses (total: {len(all_statuses)})") + + # 调用回调函数处理新状态 + callback_start = time.time() + for status in new_statuses: + if self.on_status_update: + self.on_status_update(self.task_id, status) + else: + # 默认输出到控制台 + log("DEBUG", f"Task {self.task_id} status update: {json.dumps(status, ensure_ascii=False)}") + + # 更新已处理计数 + self.last_status_count += 1 + + # 检查是否为任务完成状态 + if self._is_status_completed(status): + task_completed = True + break + + callback_cost = (time.time() - callback_start) * 1000 + poll_total = (time.time() - poll_start) * 1000 + log("INFO", f"[Perf][{self.task_id}] Poll cycle: read={read_cost:.1f}ms, callback={callback_cost:.1f}ms, total={poll_total:.1f}ms") + + except Exception as e: + log("ERROR", f"Error polling task {self.task_id}: {e}") + from traceback import print_exception + print_exception(e) + + # 等待下次轮询 + if not task_completed: + time.sleep(self.poll_interval) + + self.is_running = False + + # 不再需要清理缓存(因为不使用增量读取缓存) + log("DEBUG", f"Polling stopped for task {self.task_id}, processed {self.last_status_count} statuses") + + def _get_status_from_api(self): + """ + 通过调用serverless_api接口获取状态 + + Returns: + list: 状态列表 + """ + try: + from services.serverlessapi.serverless_api_service import ServerlessApiService + service = ServerlessApiService() + return service.get_status_from_store(self.task_id) + except Exception as e: + log("ERROR", f"Failed to get status from serverless_api for task {self.task_id}: {e}") + return [] + + def _is_status_completed(self, status: dict) -> bool: + """ + 检查单个状态是否表示任务完成 + + Args: + status: 单个状态对象 + + Returns: + True if status indicates task completion, False otherwise + """ + if not status: + return False + + status_type = status.get("type", "") + + # 如果是最终结果(serverless_api类型)或错误状态,认为任务完成 + if status_type in ["serverless_api", "error", "execution_error"]: + return True + + return False + diff --git a/src/code/agent/services/process/backend_process_manager.py b/src/code/agent/services/process/backend_process_manager.py index 30a5d60..dd3b2d6 100644 --- a/src/code/agent/services/process/backend_process_manager.py +++ b/src/code/agent/services/process/backend_process_manager.py @@ -34,7 +34,8 @@ def is_alive(self) -> bool: Returns: bool: 如果在2秒内收到正常响应则返回True,否则返回False """ - if constants.USE_API_MODE: # TODO: 针对ComfyUI和SD各类子进程崩溃情况作梳理,此前API模式忽略监控检查 + # GPU模式下跳过健康检查 + if constants.USE_API_MODE or constants.COMFYUI_MODE == 'gpu': return True try: diff --git a/src/code/agent/services/process/websocket/websocket_manager.py b/src/code/agent/services/process/websocket/websocket_manager.py index 9553436..8ef6358 100644 --- a/src/code/agent/services/process/websocket/websocket_manager.py +++ b/src/code/agent/services/process/websocket/websocket_manager.py @@ -1,6 +1,11 @@ import json import threading +import time from datetime import datetime +from typing import Set, Dict +from concurrent.futures import ThreadPoolExecutor + +from utils.logger import log class WebSocketManager: @@ -8,6 +13,21 @@ def __init__(self): self.active_connections = set() # 存储和用户comfyui client建立的所有活跃WebSocket连接 self._lock = threading.Lock() self._connection_times = {} + + # 任务状态推送功能 + self._task_subscriptions: Dict[str, Set] = {} # task_id -> set of websockets + self._client_subscriptions: Dict = {} # websocket -> set of task_ids + self._client_id_mapping: Dict[str, Set] = {} # client_id -> set of websockets + self._ws_client_id_mapping: Dict = {} # websocket -> client_id + + # 阶段一优化:性能监控和异步广播 + self._thread_pool = ThreadPoolExecutor(max_workers=5, thread_name_prefix="ws-broadcast") + self._performance_metrics = { + 'broadcast_times': [], # 最近100次广播时间 + 'send_failures': 0, # 发送失败数 + 'total_broadcasts': 0, # 总广播次数 + 'connection_errors': 0 # 连接错误数 + } def get_connection_info(self, ws): environ = ws.environ @@ -16,25 +36,700 @@ def get_connection_info(self, ws): 'address': environ.get('REMOTE_ADDR'), 'port': environ.get('REMOTE_PORT') } + + def _set_tcp_nodelay(self, ws): + """ + 设置TCP_NODELAY禁用Nagle算法,确保消息立即发送而不是等待缓冲 + """ + try: + import socket + # 尝试多种方式获取底层socket + sock = None + + # 方法1: 通过environ获取werkzeug.socket + if hasattr(ws, 'environ'): + sock = ws.environ.get('werkzeug.socket') + + # 方法2: 通过sock属性 + if not sock and hasattr(ws, 'sock'): + sock = ws.sock + + # 方法3: 通过_sock属性 + if not sock and hasattr(ws, '_sock'): + sock = ws._sock + + # 方法4: simple-websocket的ws对象可能有connected属性 + if not sock and hasattr(ws, 'connected') and hasattr(ws.connected, 'sock'): + sock = ws.connected.sock + + if sock and hasattr(sock, 'setsockopt'): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + log("DEBUG", f"[WebSocketManager] Set TCP_NODELAY for connection {id(ws)}") + else: + log("DEBUG", f"[WebSocketManager] Could not find socket for connection {id(ws)}") + + except Exception as e: + log("ERROR", f"[WebSocketManager] Failed to set TCP_NODELAY: {e}") def add_connection(self, ws): with self._lock: self.active_connections.add(ws) self._connection_times[id(ws)] = datetime.now() + self._client_subscriptions[ws] = set() # 初始化订阅记录 conn_info = self.get_connection_info(ws) - print(f"ws connected: {json.dumps(conn_info, indent=2)}") + log("INFO", f"ws connected: {json.dumps(conn_info, indent=2)}") + + # 设置TCP_NODELAY禁用Nagle算法,确保消息立即发送 + self._set_tcp_nodelay(ws) + + # 发送初始队列状态 + self._send_initial_status(ws) def remove_connection(self, ws): with self._lock: self.active_connections.discard(ws) conn_id = id(ws) start_time = self._connection_times.pop(conn_id, None) + + # 清理该连接的所有任务订阅 + subscribed_tasks = self._client_subscriptions.get(ws, set()) + for task_id in subscribed_tasks: + if task_id in self._task_subscriptions: + self._task_subscriptions[task_id].discard(ws) + # 如果没有其他订阅者,删除任务记录 + if not self._task_subscriptions[task_id]: + del self._task_subscriptions[task_id] + + # 清理ComfyUI客户端ID映射 + client_id = self._ws_client_id_mapping.get(ws) + if client_id: + if client_id in self._client_id_mapping: + self._client_id_mapping[client_id].discard(ws) + if not self._client_id_mapping[client_id]: + del self._client_id_mapping[client_id] + self._ws_client_id_mapping.pop(ws, None) + + # 清理客户端订阅记录 + self._client_subscriptions.pop(ws, None) + conn_info = self.get_connection_info(ws) if start_time: duration = (datetime.now() - start_time).total_seconds() conn_info['duration'] = f"{duration:.2f}s" - print(f"ws disconnected, {json.dumps(conn_info, indent=2)}") + log("INFO", f"ws disconnected, {json.dumps(conn_info, indent=2)}") + def subscribe_task_status(self, ws, task_id: str) -> bool: + """ + 订阅任务状态推送 + + Args: + ws: WebSocket连接 + task_id: 任务ID + + Returns: + bool: 订阅是否成功 + """ + try: + with self._lock: + # 检查连接是否有效 + if ws not in self.active_connections: + return False + + # 添加任务订阅 + if task_id not in self._task_subscriptions: + self._task_subscriptions[task_id] = set() + self._task_subscriptions[task_id].add(ws) + + # 添加客户端订阅记录 + if ws not in self._client_subscriptions: + self._client_subscriptions[ws] = set() + self._client_subscriptions[ws].add(task_id) + + subscriber_count = len(self._task_subscriptions[task_id]) + log("DEBUG", f"[WebSocketManager] Client subscribed to task {task_id} " + f"(total subscribers: {subscriber_count})") + return True + + except Exception as e: + log("ERROR", f"[WebSocketManager] Failed to subscribe to task {task_id}: {e}") + return False + + def unsubscribe_task_status(self, ws, task_id: str) -> bool: + """ + 取消订阅任务状态 + + Args: + ws: WebSocket连接 + task_id: 任务ID + + Returns: + bool: 取消订阅是否成功 + """ + try: + with self._lock: + # 从任务订阅中移除 + if task_id in self._task_subscriptions: + self._task_subscriptions[task_id].discard(ws) + # 如果没有其他订阅者,删除任务记录 + if not self._task_subscriptions[task_id]: + del self._task_subscriptions[task_id] + + # 从客户端订阅记录中移除 + if ws in self._client_subscriptions: + self._client_subscriptions[ws].discard(task_id) + + remaining_subscribers = len(self._task_subscriptions.get(task_id, set())) + log("DEBUG", f"[WebSocketManager] Client unsubscribed from task {task_id} " + f"(remaining subscribers: {remaining_subscribers})") + return True + + except Exception as e: + log("ERROR", f"[WebSocketManager] Failed to unsubscribe from task {task_id}: {e}") + return False + + def broadcast_task_status(self, task_id: str, status_data: dict) -> int: + """ + 广播任务状态给所有订阅者(阶段一优化:带超时和性能监控) + + Args: + task_id: 任务ID + status_data: 状态数据 + + Returns: + int: 成功发送的连接数 + """ + broadcast_start_time = time.time() + + with self._lock: + subscribers = self._task_subscriptions.get(task_id, set()).copy() + + if not subscribers: + return 0 + + # 使用线程池异步处理广播,避免阻塞主线程 + # 注意:这里只记录提交到线程池的时间,实际发送在线程中进行 + future = self._thread_pool.submit( + self._do_broadcast_with_timeout, + task_id, + status_data, + subscribers, + broadcast_start_time + ) + + # 异步提交耗时(只包括线程池队列时间,不包括实际发送) + submit_time = time.time() - broadcast_start_time + if submit_time > 0.01: # 如果提交耗时超过10ms,记录警告 + log("WARNING", f"[WebSocketManager] Thread pool submit time {submit_time*1000:.2f}ms for task {task_id} " + f"(may indicate thread pool is busy)") + + # 不等待结果,立即返回,提高响应速度 + # 后续可通过future.result()获取结果,但为了不阻塞这里直接返回订阅者数 + return len(subscribers) + + def _do_broadcast_with_timeout(self, task_id: str, status_data: dict, + subscribers: set, broadcast_start_time: float) -> int: + """ + 执行实际的广播逻辑,带超时和异常处理 + + Args: + task_id: 任务ID + status_data: 状态数据 + subscribers: 订阅者集合 + broadcast_start_time: 广播开始时间 + + Returns: + int: 成功发送的连接数 + """ + try: + message = json.dumps(status_data, ensure_ascii=False) + successful_sends = 0 + disconnected = set() + send_timeout = 2.0 # 2秒发送超时 + + # 记录每个客户端的发送耗时 + send_times = {} + + for ws in subscribers: + try: + # 阶段一优化:带超时的WebSocket发送 + send_start = time.time() + self._send_message_with_timeout(ws, message, send_timeout) + send_duration = time.time() - send_start + send_times[id(ws)] = send_duration + successful_sends += 1 + except Exception as e: + log("ERROR", f"[WebSocketManager] Failed to send to client (task {task_id}): {e}") + disconnected.add(ws) + # 更新性能指标(线程安全) + with self._lock: + self._performance_metrics['send_failures'] += 1 + self._performance_metrics['connection_errors'] += 1 + + # 清理断开的连接(异步执行避免阻塞) + if disconnected: + for ws in disconnected: + try: + self.remove_connection(ws) + except Exception as e: + log("ERROR", f"[WebSocketManager] Error removing connection: {e}") + + # 记录性能指标 + broadcast_duration = time.time() - broadcast_start_time + self._record_broadcast_performance(broadcast_duration, successful_sends) + + if successful_sends > 0: + status_type = status_data.get('type', 'unknown') + avg_send_time = sum(send_times.values()) / len(send_times) if send_times else 0 + max_send_time = max(send_times.values()) if send_times else 0 + min_send_time = min(send_times.values()) if send_times else 0 + + log("DEBUG", f"[WebSocketManager] Broadcasted {status_type} for task {task_id} to {successful_sends} clients " + f"(total: {broadcast_duration*1000:.1f}ms, " + f"avg_send: {avg_send_time*1000:.1f}ms, " + f"max_send: {max_send_time*1000:.1f}ms, " + f"min_send: {min_send_time*1000:.1f}ms)") + + return successful_sends + + except Exception as e: + log("ERROR", f"[WebSocketManager] Error in broadcast thread: {e}") + return 0 + + def _send_message_with_timeout(self, ws, message: str, timeout: float = 2.0): + """ + 带超时的WebSocket消息发送 + + Args: + ws: WebSocket连接 + message: 要发送的消息 + timeout: 超时时间(秒) + """ + try: + # 直接发送消息,不进行flush操作 + # TCP_NODELAY已经在连接初始化时设置,消息会立即发送 + # 每次send后尝试flush可能增加不必要的开销 + ws.send(message) + except Exception as e: + # 记录错误类型以便分析 + error_type = type(e).__name__ + log("ERROR", f"[WebSocketManager] WebSocket send failed ({error_type}): {str(e)[:100]}") + raise # 重新抛出异常由调用者处理 + + def associate_client_id_with_connection(self, ws, client_id: str): + """ + 将ComfyUI客户端ID与连接关联 + + Args: + ws: WebSocket连接 + client_id: ComfyUI客户端ID + """ + with self._lock: + # 如果是重连,移除旧连接 + if client_id in self._client_id_mapping: + old_connections = self._client_id_mapping[client_id].copy() + for old_ws in old_connections: + if old_ws != ws: + log("INFO", f"[WebSocketManager] Removing old connection for client_id {client_id} (reconnect)") + # 不调用 remove_connection,只清理映射 + self._client_id_mapping[client_id].discard(old_ws) + self._ws_client_id_mapping.pop(old_ws, None) + + if client_id not in self._client_id_mapping: + self._client_id_mapping[client_id] = set() + self._client_id_mapping[client_id].add(ws) + self._ws_client_id_mapping[ws] = client_id + log("DEBUG", f"[WebSocketManager] Associated client_id {client_id} with WebSocket connection") + + def associate_task_with_client_id(self, task_id: str, client_id: str) -> int: + """ + 将任务与指定的ComfyUI客户端关联,并自动订阅该任务的状态 + + Args: + task_id: 任务ID + client_id: ComfyUI客户端ID + + Returns: + int: 成功关联的连接数 + """ + log("DEBUG", f"[WebSocketManager] Attempting to associate task {task_id} with client_id {client_id}") + + with self._lock: + connections = self._client_id_mapping.get(client_id, set()).copy() + log("DEBUG", f"[WebSocketManager] Found {len(connections)} connections for client_id {client_id}") + + if not connections: + log("WARNING", f"[WebSocketManager] No connections found for client_id {client_id}") + log("DEBUG", f"[WebSocketManager] Available client IDs: {list(self._client_id_mapping.keys())}") + return 0 + + associated_count = 0 + for ws in connections: + log("DEBUG", f"[WebSocketManager] Subscribing connection {id(ws)} to task {task_id}") + if self.subscribe_task_status(ws, task_id): + associated_count += 1 + + log("DEBUG", f"[WebSocketManager] Associated task {task_id} with client_id {client_id} " + f"({associated_count} connections)") + + return associated_count + + def resubscribe_client_tasks(self, ws, client_id: str): + """ + 当客户端重连时,重新订阅该客户端的所有进行中的任务 + + Args: + ws: 新的WebSocket连接 + client_id: 客户端ID + """ + try: + # 获取所有任务订阅,找到属于该 client_id 的任务 + from services.gateway import get_task_queue + from services.gateway.queue.task_models import TaskStatus + + task_queue = get_task_queue() + all_tasks = task_queue.get_all_tasks() + + # 过滤出该客户端的进行中的任务 + active_tasks = [ + task for task in all_tasks + if task.client_id == client_id and + task.status in [TaskStatus.PENDING, TaskStatus.SUBMITTED, TaskStatus.PROCESSING] + ] + + if not active_tasks: + log("DEBUG", f"[WebSocketManager] No active tasks found for client_id {client_id} on reconnect") + return + + # 为每个活跃任务重新订阅 + resubscribed_count = 0 + for task in active_tasks: + if self.subscribe_task_status(ws, task.task_id): + resubscribed_count += 1 + log("DEBUG", f"[WebSocketManager] Resubscribed task {task.task_id} (status={task.status.value}) for client_id {client_id}") + + if resubscribed_count > 0: + log("INFO", f"[WebSocketManager] Resubscribed {resubscribed_count} active tasks for client_id {client_id} on reconnect") + + except Exception as e: + log("ERROR", f"[WebSocketManager] Failed to resubscribe tasks for client_id {client_id}: {e}") + import traceback + log("ERROR", f"Traceback: {traceback.format_exc()}") + + def broadcast_comfyui_message(self, task_id: str, comfyui_message: dict) -> int: + """ + 广播ComfyUI原生格式的消息给任务订阅者(同步) + + Args: + task_id: 任务ID(特殊值"queue_status"表示广播给所有连接) + comfyui_message: ComfyUI原生格式的消息 + + Returns: + int: 成功发送的连接数 + """ + # 特殊处理队列状态广播 + if task_id == "queue_status": + return self._broadcast_to_all_connections(comfyui_message) + else: + return self.broadcast_task_status(task_id, comfyui_message) + + def broadcast_comfyui_message_async(self, task_id: str, comfyui_message: dict) -> int: + """ + 异步广播ComfyUI原生格式的消息(不阻塞调用者) + + Args: + task_id: 任务ID(特殊值"queue_status"表示广播给所有连接) + comfyui_message: ComfyUI原生格式的消息 + + Returns: + int: 订阅者/连接数(实际发送在后台线程中进行) + """ + # 特殊处理队列状态广播 + if task_id == "queue_status": + return self._broadcast_to_all_connections_async(comfyui_message) + else: + # 任务状态广播已经是异步的 + return self.broadcast_task_status(task_id, comfyui_message) + + def _broadcast_to_all_connections(self, message: dict) -> int: + """ + 广播消息给所有活跃连接(同步发送,确保立即交付) + + Args: + message: 要广播的消息 + + Returns: + int: 成功发送的连接数 + """ + broadcast_start_time = time.time() + + with self._lock: + all_connections = self.active_connections.copy() + + if not all_connections: + return 0 + + # 队列状态广播改为同步发送,避免延迟 + return self._do_broadcast_to_all_with_timeout(message, all_connections, broadcast_start_time) + + def _broadcast_to_all_connections_async(self, message: dict) -> int: + """ + 异步广播消息给所有活跃连接(不阻塞调用者) + + Args: + message: 要广播的消息 + + Returns: + int: 连接数(实际发送在后台线程中进行) + """ + broadcast_start_time = time.time() + + with self._lock: + all_connections = self.active_connections.copy() + + if not all_connections: + return 0 + + # 提交到线程池异步执行,不阻塞调用者 + self._thread_pool.submit( + self._do_broadcast_to_all_with_timeout, + message, + all_connections, + broadcast_start_time + ) + + return len(all_connections) + + def _do_broadcast_to_all_with_timeout(self, message: dict, + connections: set, broadcast_start_time: float) -> int: + """ + 执行实际的广播逻辑,带超时和异常处理(并行发送优化) + + Args: + message: 要广播的消息 + connections: 连接集合 + broadcast_start_time: 广播开始时间 + + Returns: + int: 成功发送的连接数 + """ + try: + message_str = json.dumps(message, ensure_ascii=False) + successful_sends = 0 + disconnected = set() + send_timeout = 2.0 # 2秒发送超时 + + # 记录每个客户端的发送耗时 + send_times = {} + send_lock = threading.Lock() + + def send_to_client(ws): + """并行发送给单个客户端""" + nonlocal successful_sends + try: + send_start = time.time() + self._send_message_with_timeout(ws, message_str, send_timeout) + send_duration = time.time() - send_start + + with send_lock: + send_times[id(ws)] = send_duration + successful_sends += 1 + except Exception as e: + log("ERROR", f"[WebSocketManager] Failed to send queue status to client: {e}") + with send_lock: + disconnected.add(ws) + # 更新性能指标 + with self._lock: + self._performance_metrics['send_failures'] += 1 + self._performance_metrics['connection_errors'] += 1 + + # 并行发送给所有连接(使用线程池) + from concurrent.futures import ThreadPoolExecutor, wait, FIRST_EXCEPTION + + # 根据连接数动态调整并行度(最多10个并行) + max_workers = min(10, len(connections)) + with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="ws-send") as executor: + futures = [executor.submit(send_to_client, ws) for ws in connections] + # 等待所有发送完成,最多等待3秒 + wait(futures, timeout=3.0) + + # 清理断开的连接 + if disconnected: + for ws in disconnected: + try: + self.remove_connection(ws) + except Exception as e: + log("ERROR", f"[WebSocketManager] Error removing connection: {e}") + + # 记录性能指标 + broadcast_duration = time.time() - broadcast_start_time + self._record_broadcast_performance(broadcast_duration, successful_sends) + + if successful_sends > 0: + avg_send_time = sum(send_times.values()) / len(send_times) if send_times else 0 + max_send_time = max(send_times.values()) if send_times else 0 + min_send_time = min(send_times.values()) if send_times else 0 + queue_remaining = message.get('data', {}).get('status', {}).get('exec_info', {}).get('queue_remaining', 0) + + log("DEBUG", f"[WebSocketManager] Broadcasted queue_status (remaining={queue_remaining}) to {successful_sends}/{len(connections)} clients " + f"(total: {broadcast_duration*1000:.1f}ms, " + f"avg_send: {avg_send_time*1000:.1f}ms, " + f"max_send: {max_send_time*1000:.1f}ms, " + f"min_send: {min_send_time*1000:.1f}ms, " + f"parallel_workers: {max_workers})") + + return successful_sends + + except Exception as e: + log("ERROR", f"[WebSocketManager] Error in queue status broadcast thread: {e}") + return 0 + + def get_task_subscribers(self, task_id: str) -> int: + """获取任务的订阅者数量""" + with self._lock: + return len(self._task_subscriptions.get(task_id, set())) + + # NOTE: update_task_id_association方法已废弃 - 使用x-fc-trace-id保持ID一致性 + # def update_task_id_association(self, old_task_id: str, new_task_id: str) -> bool: + # """更新任务ID关联 - 已废弃,通过x-fc-trace-id保持ID一致性""" + # # 通过传递x-fc-trace-id给GPU函数,CPU和GPU两边的requestId保持一致 + # # 无需动态更新WebSocket任务ID关联 + # log("WARNING", "[WebSocketManager] update_task_id_association method deprecated") + # return True + + def _send_initial_status(self, ws): + """ + 向新连接的WebSocket客户端发送初始队列状态 + """ + try: + # 只在CPU模式下发送初始状态 + import constants + if constants.COMFYUI_MODE != "cpu": + return + + # 获取当前队列状态 + from services.gateway import get_task_queue + task_queue = get_task_queue() + # 获取待处理任务数量 + pending_count = task_queue._get_pending_task_count() + + # 构建ComfyUI状态消息 + initial_status = { + "type": "status", + "data": { + "status": { + "exec_info": { + "queue_remaining": pending_count + } + } + } + } + + # 发送初始状态消息 + message = json.dumps(initial_status, ensure_ascii=False) + ws.send(message) + + log("DEBUG", f"[WebSocketManager] Sent initial status to new connection (queue_remaining: {pending_count})") + + except Exception as e: + log("ERROR", f"[WebSocketManager] Failed to send initial status: {e}") + + def _record_broadcast_performance(self, duration: float, successful_sends: int): + """ + 记录广播性能指标 + + Args: + duration: 广播耗时(秒) + successful_sends: 成功发送数 + """ + with self._lock: + # 记录广播时间,只保留最近100次 + self._performance_metrics['broadcast_times'].append(duration) + if len(self._performance_metrics['broadcast_times']) > 100: + self._performance_metrics['broadcast_times'].pop(0) + + # 更新统计 + self._performance_metrics['total_broadcasts'] += 1 + + def get_performance_metrics(self) -> dict: + """ + 获取WebSocket性能指标(阶段一优化:用于监控优化效果) + + Returns: + dict: 包含各种性能指标的字典 + """ + with self._lock: + broadcast_times = self._performance_metrics['broadcast_times'].copy() + send_failures = self._performance_metrics['send_failures'] + total_broadcasts = self._performance_metrics['total_broadcasts'] + connection_errors = self._performance_metrics['connection_errors'] + + # 计算统计数据 + metrics = { + 'active_connections': len(self.active_connections), + 'active_tasks': len(self._task_subscriptions), + 'total_broadcasts': total_broadcasts, + 'send_failures': send_failures, + 'connection_errors': connection_errors, + 'failure_rate': round(send_failures / max(total_broadcasts, 1) * 100, 2) + } + + if broadcast_times: + avg_time = sum(broadcast_times) / len(broadcast_times) + max_time = max(broadcast_times) + min_time = min(broadcast_times) + + metrics.update({ + 'avg_broadcast_ms': round(avg_time * 1000, 2), + 'max_broadcast_ms': round(max_time * 1000, 2), + 'min_broadcast_ms': round(min_time * 1000, 2), + 'broadcast_samples': len(broadcast_times) + }) + else: + metrics.update({ + 'avg_broadcast_ms': 0, + 'max_broadcast_ms': 0, + 'min_broadcast_ms': 0, + 'broadcast_samples': 0 + }) + + return metrics + + def print_performance_summary(self): + """ + 输出WebSocket性能摘要(阶段一优化:用于调试和监控) + """ + metrics = self.get_performance_metrics() + + log("INFO", "[WebSocketManager] === Performance Summary ===") + log("INFO", f"[WebSocketManager] Active connections: {metrics['active_connections']}") + log("INFO", f"[WebSocketManager] Active task subscriptions: {metrics['active_tasks']}") + log("INFO", f"[WebSocketManager] Total broadcasts: {metrics['total_broadcasts']}") + log("INFO", f"[WebSocketManager] Send failures: {metrics['send_failures']} ({metrics['failure_rate']}%)") + log("INFO", f"[WebSocketManager] Connection errors: {metrics['connection_errors']}") + + if metrics['broadcast_samples'] > 0: + log("INFO", f"[WebSocketManager] Broadcast timing (last {metrics['broadcast_samples']} samples):") + log("INFO", f"[WebSocketManager] Average: {metrics['avg_broadcast_ms']}ms") + log("INFO", f"[WebSocketManager] Min: {metrics['min_broadcast_ms']}ms") + log("INFO", f"[WebSocketManager] Max: {metrics['max_broadcast_ms']}ms") + else: + log("INFO", "[WebSocketManager] No broadcast timing data available") + + log("INFO", "[WebSocketManager] ================================") + + def reset_performance_metrics(self): + """ + 重置性能指标计数器 + """ + with self._lock: + self._performance_metrics = { + 'broadcast_times': [], + 'send_failures': 0, + 'total_broadcasts': 0, + 'connection_errors': 0 + } + log("INFO", "[WebSocketManager] Performance metrics reset") + def close_all_connections(self): with self._lock: for ws in self.active_connections: @@ -42,8 +737,15 @@ def close_all_connections(self): ws.send('Server shutting down') ws.close() except Exception as e: - print(f"Error closing WebSocket connection: {e}") + log("ERROR", f"Error closing WebSocket connection: {e}") self.active_connections.clear() + # 清理任务订阅记录 + self._task_subscriptions.clear() + self._client_subscriptions.clear() + + # 关闭线程池 + self._thread_pool.shutdown(wait=False) + log("INFO", "[WebSocketManager] All connections closed and thread pool shutdown") ws_manager = WebSocketManager() diff --git a/src/code/agent/services/serverlessapi/serverless_api_service.py b/src/code/agent/services/serverlessapi/serverless_api_service.py index 7d7f7bc..9a41af9 100644 --- a/src/code/agent/services/serverlessapi/serverless_api_service.py +++ b/src/code/agent/services/serverlessapi/serverless_api_service.py @@ -57,17 +57,21 @@ def __init__(self): # # 必要时,也可以参考对应代码实现基于 Redis、TableStore、MySQL 等方式的状态持久化 self.store: Store = FileSystem(f"{constants.MNT_DIR}/output/serverless_api") - + # 记录启动信息和日志级别 log("INFO", f"ServerlessApiService initialized with endpoint: {self.endpoint}") log("INFO", f"Current log level: {constants.LOG_LEVEL}") + # 用于检测状态变化的缓存 + self._last_status_cache = {} # task_id -> last_status_summary + self._cache_lock = threading.Lock() + def get_credentials(self): """ 获取阿里云访问凭证 - + 优先从 HTTP 请求 header 中获取,如果获取失败则从环境变量中获取。 - + Returns: tuple: (access_key_id, access_key_secret, security_token) """ @@ -98,9 +102,9 @@ def get_credentials(self): def get_oss_store(self): """ 创建 OSS 存储客户端 - + 使用当前请求的凭证创建 OSS 客户端实例,用于上传生成的图片/视频到阿里云 OSS。 - + Returns: OSS: OSS 客户端实例 """ @@ -119,16 +123,16 @@ def get_oss_store(self): def api_prompt(self, client_id: str, prompt: Any): """ 提交 ComfyUI 工作流任务 - + 将处理后的 prompt(工作流定义)提交给 ComfyUI 后端执行。 - + Args: client_id: WebSocket 客户端 ID,用于关联 WebSocket 连接 prompt: ComfyUI 工作流定义(节点图) - + Returns: dict: 包含 prompt_id 等信息的响应 - + Raises: ComfyUIException: 当 ComfyUI API 调用失败时抛出 """ @@ -159,13 +163,13 @@ def api_prompt(self, client_id: str, prompt: Any): def api_websocket(self, client_id: str, on_message): """ 创建 WebSocket 连接到 ComfyUI - + 用于实时接收任务执行状态更新(如进度、完成、错误等)。 - + Args: client_id: 客户端 ID,用于标识此连接 on_message: 消息回调函数,接收 (ws, message) 参数 - + Returns: WebSocketApp: WebSocket 应用实例 """ @@ -182,13 +186,13 @@ def api_websocket(self, client_id: str, on_message): def api_upload_image(self, content: bytes, overwrite: bool): """ 上传图片到 ComfyUI - + 将图片内容上传到 ComfyUI 的 input 目录,供工作流节点使用。 - + Args: content: 图片二进制内容 overwrite: 是否覆盖同名文件 - + Returns: dict: 包含上传后的文件信息(如 name 字段) """ @@ -212,10 +216,10 @@ def api_upload_image(self, content: bytes, overwrite: bool): def api_get_history(self, prompt_id: str): """ 获取任务执行历史 - + Args: prompt_id: 任务 ID - + Returns: dict: 包含任务执行结果、输出文件等信息 """ @@ -224,12 +228,12 @@ def api_get_history(self, prompt_id: str): def api_view_image(self, filename: str, img_type: str, sub_folder: str): """ 下载生成的图片/视频文件 - + Args: filename: 文件名 img_type: 文件类型(如 "output", "temp") sub_folder: 子目录名 - + Returns: bytes: 文件二进制内容 """ @@ -246,7 +250,7 @@ def api_view_image(self, filename: str, img_type: str, sub_folder: str): def api_clear_history(self): """ 清除 ComfyUI 的历史记录 - + 释放内存和磁盘空间。 """ requests.post(os.path.join(self.endpoint, "history"), json={"clear": True}) @@ -254,18 +258,18 @@ def api_clear_history(self): def parse_prompt(self, prompt: map): """ 预处理工作流定义 - + 自动处理以下内容: 1. LoadImage/LoadImageMask 节点:支持 HTTP URL、OSS URL、Base64 格式的图片输入 2. KSampler 节点:自动生成随机种子(当 seed=-1 时) 3. SaveImage 节点:自动添加实例 ID 到文件名前缀 - + Args: prompt: ComfyUI 工作流定义(字典格式) - + Returns: map: 处理后的工作流定义 - + Raises: Exception: 当图片加载失败时抛出 """ @@ -287,7 +291,7 @@ def parse_prompt(self, prompt: map): else: input_key = "image" file_type = "image" - + file_url = value.get("inputs", {}).get(input_key, "") content = "" @@ -296,7 +300,7 @@ def parse_prompt(self, prompt: map): log("DEBUG", f"downloading {file_type} from HTTP URL: {file_url}") start_time = time.perf_counter() response = requests.get(file_url) - + if response.status_code >= 400: raise Exception( f"can not get {file_type} {file_url} from http url, got status code {response.status_code}" @@ -305,7 +309,7 @@ def parse_prompt(self, prompt: map): content = response.content if content == "": raise Exception(f"can not get {file_type} {file_url} from http url") - + elapsed = time.perf_counter() - start_time log("INFO", f"successfully downloaded {file_type} from HTTP URL ({len(content)} bytes) in {elapsed:.2f}s") @@ -322,9 +326,9 @@ def parse_prompt(self, prompt: map): if content == "": raise Exception(f"can not get {file_type} {file_url} from oss") - + log("DEBUG", f"successfully downloaded {file_type} from OSS ({len(content)} bytes) in {elapsed:.2f}s") - + elif len(file_url) > 64: # 文件可能是 base64,尝试解析 log("DEBUG", f"decoding {file_type} from Base64") @@ -337,7 +341,7 @@ def parse_prompt(self, prompt: map): elapsed = time.perf_counter() - start_time log("DEBUG", f"failed to decode {file_type} from Base64 in {elapsed:.2f}s") pass - + if content: # 上传文件并更新对应的输入字段 log("DEBUG", f"uploading {file_type} to ComfyUI") @@ -346,7 +350,7 @@ def parse_prompt(self, prompt: map): elapsed = time.perf_counter() - start_time log("INFO", f"successfully uploaded {file_type} to ComfyUI as '{res['name']}' in {elapsed:.2f}s") prompt[key]["inputs"][input_key] = res["name"] - + except Exception as e: raise Exception(f"{class_type} failed: {e}") @@ -369,16 +373,16 @@ def parse_prompt(self, prompt: map): def get_history_result(self, prompt_id: str, output_base64=False, output_oss=False): """ 获取任务执行结果并处理输出文件 - + 从 ComfyUI 历史记录中提取输出文件(图片/视频),并根据参数选择: - 下载文件并转换为 Base64 - 上传文件到 OSS 并生成签名 URL - + Args: prompt_id: 任务 ID output_base64: 是否将输出文件转换为 Base64 编码 output_oss: 是否上传输出文件到 OSS - + Returns: dict: 包含所有输出结果的字典,格式: { @@ -412,7 +416,7 @@ def get_history_result(self, prompt_id: str, output_base64=False, output_oss=Fal # 调试日志:查看实际的数据类型和结构 log("DEBUG", f"node_id={node_id}, output_type={output_type}, index={index}") log("DEBUG", f"img type: {type(img)}, img: {img}") - + if type(img) != dict or not img.get("filename"): log("DEBUG", f"skipping: type check={type(img) != dict}, filename check={not img.get('filename') if isinstance(img, dict) else 'N/A'}") continue @@ -492,7 +496,7 @@ def get_history_result(self, prompt_id: str, output_base64=False, output_oss=Fal def put_status_to_store(self, task_id: str, status: str): """ 保存任务状态到持久化存储 - + 用于异步场景,状态信息以追加方式存储,每条状态占一行。 这样多个实例可以通过共享存储(如 NAS)查询任务状态。 @@ -512,19 +516,41 @@ def put_status_to_store(self, task_id: str, status: str): def get_status_from_store(self, task_id: str): """ 从持久化存储中读取任务状态历史 - + Args: task_id: 任务 ID - + Returns: list: 状态历史列表,每个元素是一条状态消息(已解析为字典) """ if self.store: value = self.store.get(task_id) - return [json.loads(line) for line in value.split("\n") if line] + result = [json.loads(line) for line in value.split("\n") if line] + + # 检测状态变化并打印日志 + if result: + # 生成当前状态摘要 + last_status = result[-1] if result else {} + status_type = last_status.get('type', 'unknown') + current_summary = f"{len(result)}:{status_type}" + + # 使用锁保护对_last_status_cache的访问 + with self._cache_lock: + # 获取上一次的状态摘要 + last_summary = self._last_status_cache.get(task_id) + + # 如果状态发生变化,打印日志 + if last_summary != current_summary: + log("INFO", f"[Status Change] Task {task_id}: {last_summary or 'initial'} -> {current_summary}") + + # 更新缓存 + self._last_status_cache[task_id] = current_summary + + return result else: return [] + def run( self, prompt: map, @@ -535,7 +561,7 @@ def run( ): """ 执行 ComfyUI 工作流(Serverless API 核心方法) - + 完整流程: 1. 预处理工作流定义(parse_prompt) 2. 建立 WebSocket 连接监听任务状态 @@ -543,17 +569,17 @@ def run( 4. 等待任务完成 5. 获取并处理输出结果 6. 保存状态到持久化存储 - + Args: prompt: ComfyUI 工作流定义 output_base64: 是否将输出转换为 Base64(适用于小文件) output_oss: 是否上传输出到 OSS(推荐用于生产环境) callback: WebSocket 消息回调函数,接收原始消息 task_id: 任务 ID,用于状态持久化(默认使用 prompt_id) - + Returns: dict: 任务执行结果,包含所有输出文件信息 - + Raises: ComfyUIException: 当 ComfyUI 执行出错时 Exception: 其他异常 @@ -574,7 +600,7 @@ def on_message(ws: websocket.WebSocket, message: str): # 忽略空消息 if not message or not message.strip(): return - + # 尝试解析 JSON try: msg = json.loads(message) @@ -677,13 +703,22 @@ def on_message(ws: websocket.WebSocket, message: str): prompt_id = prompt_result.get("prompt_id", "") log("DEBUG", f"workflow submitted, prompt_id: {prompt_id}") + print(f"[ServerlessApiService] Before task_id processing: received_task_id='{task_id}' (type={type(task_id)}), generated_prompt_id='{prompt_id}'") + # 如果 task id 未指定,则使用 prompt id + original_task_id = task_id if not task_id: task_id = prompt_id + print(f"[ServerlessApiService] Task ID was empty/None, replaced with prompt_id: '{original_task_id}' -> '{task_id}'") + else: + print(f"[ServerlessApiService] Using provided task_id: '{task_id}'") if not prompt_id: raise Exception("can not get prompt_id from ComfyUI") + # 记录执行开始时间 + execution_start_time = time.time() + # 已经有结果,则不必等待 if len(self.api_get_history(prompt_id)) > 0: ws.close() @@ -691,16 +726,15 @@ def on_message(ws: websocket.WebSocket, message: str): # 等待工作流完成:WebSocket(主) + 轮询历史记录(备用) check_interval = int(os.getenv("SERVERLESS_API_CHECK_INTERVAL", "60")) log("INFO", f"waiting for prompt {prompt_id} to complete (check_interval={check_interval}s)") - start_time = time.time() - + while ws_threading.is_alive(): # 等待一小段时间 ws_threading.join(timeout=check_interval) - + # 如果线程已结束,退出循环。仅当 ws.close() 被调用时,线程才会结束。 if not ws_threading.is_alive(): break - + # 定期检查历史记录(备用检测) try: if len(self.api_get_history(prompt_id)) > 0: @@ -709,9 +743,10 @@ def on_message(ws: websocket.WebSocket, message: str): break except Exception as e: log("DEBUG", f"history check failed: {e}") - - elapsed = time.time() - start_time - log("INFO", f"workflow completed in {elapsed:.1f}s") + + # 计算执行时间 + execution_time = time.time() - execution_start_time + log("INFO", f"workflow completed in {execution_time:.2f}s") if ws_err: log("ERROR", f"websocket error occurred: {ws_err}") @@ -721,9 +756,14 @@ def on_message(ws: websocket.WebSocket, message: str): result = self.get_history_result( prompt_id, output_base64=output_base64, output_oss=output_oss ) + + # 添加执行时间到结果数据中 + if result and "data" in result: + result["data"]["execution_time"] = execution_time + log("DEBUG", f"saving result to store for task_id: {task_id}") self.put_status_to_store(task_id, json.dumps(result)) - log("INFO", f"finished running prompt: {prompt_id}") + log("INFO", f"finished running prompt: {prompt_id} in {execution_time:.2f}s") return result except ComfyUIException as e: self.put_status_to_store( diff --git a/src/code/agent/store/filesystem.py b/src/code/agent/store/filesystem.py index c31cc78..2a1dae3 100644 --- a/src/code/agent/store/filesystem.py +++ b/src/code/agent/store/filesystem.py @@ -11,7 +11,19 @@ def __file_path(self, key: str): def get(self, key: str) -> str: try: - with open(self.__file_path(key), "r", encoding="utf-8") as f: + file_path = self.__file_path(key) + + # 刷新NFS/网络文件系统缓存 + try: + # 1. 刷新目录缓存 - 触发目录元数据更新 + os.listdir(os.path.dirname(file_path)) + # 2. 使用os.stat()获取最新文件元数据 + os.stat(file_path) + except: + pass + + # 读取文件内容 + with open(file_path, "r", encoding="utf-8") as f: return f.read() except Exception as e: if "No such file or directory" not in str(e): diff --git a/src/code/agent/test/__init__.py b/src/code/agent/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/code/agent/test/unit/__init__.py b/src/code/agent/test/unit/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/code/agent/test/unit/services/__init__.py b/src/code/agent/test/unit/services/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/code/agent/test/unit/services/backend_process_manager_test.py b/src/code/agent/test/unit/services/backend_process_manager_test.py deleted file mode 100644 index c929030..0000000 --- a/src/code/agent/test/unit/services/backend_process_manager_test.py +++ /dev/null @@ -1,153 +0,0 @@ -import signal -import time -from pathlib import Path - -import pytest - -from services.process.backend_process_manager import BackendProcessManager - - -@pytest.fixture -def process_manager(): - return BackendProcessManager() - - -def test_start_real_process(process_manager): - script_path = Path(__file__).parent / "mock_comfyui_process.py" - - # 启动进程 - command = ['python3', str(script_path)] - process_manager.start(command) - - assert process_manager.process is not None - assert process_manager.process.pid > 0 - assert process_manager.is_ready() is False # 子进程Readiness探针未就绪 - - # 等待子进程启动完成 - process_manager.wait_until_ready() - - # 等待正常运行N秒 - time.sleep(10) - - # 停止进程 - process_manager.stop() - - # 验证进程已终止 - time.sleep(1) - assert process_manager.is_ready() is False - - -def test_process_manually_restart_on_unexpected_exit(process_manager): - script_path = Path(__file__).parent / "mock_comfyui_process.py" - - # 首次启动进程 - command = ['python3', str(script_path)] - process_manager.start(command) - - # 等待进程就绪 - process_manager.wait_until_ready() - - # 记录第一次启动的进程ID - first_pid = process_manager.process.pid - - # 等待几秒确保进程稳定运行 - time.sleep(2) - - # 模拟进程意外退出(发送SIGTERM信号) - process_manager.process.terminate() - - # 等待进程完全退出 - process_manager.process.wait() - time.sleep(1) - - # 验证进程已经不再运行 - assert process_manager.is_ready() is False - - # 重新启动进程 - process_manager.start(command) - - # 等待新进程就绪 - process_manager.wait_until_ready() - - # 验证是新的进程(PID不同) - assert process_manager.process.pid != first_pid - - # 最后停止进程 - process_manager.stop() - - # 验证进程已完全终止 - time.sleep(1) - assert process_manager.is_ready() is False - - -def test_process_auto_restart_on_unexpected_exit(process_manager): - script_path = Path(__file__).parent / "mock_comfyui_process.py" - - # 启动进程 - command = ['python3', str(script_path)] - process_manager.start(command) - - # 等待进程就绪 - process_manager.wait_until_ready() - - # 记录第一次启动的进程ID - first_pid = process_manager.process.pid - - # 等待几秒确保进程稳定运行和健康检查线程启动 - time.sleep(2) - - # 模拟进程意外退出(发送SIGTERM信号) - process_manager.process.send_signal(signal.SIGHUP) - - # 等待足够长的时间,让健康检查线程检测到进程死亡并记录日志 - time.sleep(10) - - # 验证健康检查线程仍在运行 - assert process_manager.health_check_thread.is_alive() - - # 捕获并验证控制台输出包含预期的消息 - - - # 最后停止进程 - process_manager.stop() - - # 验证进程已完全终止 - time.sleep(1) - assert not process_manager.is_ready() - assert not process_manager.is_alive() - - -def test_wait_until_ready_timeout(process_manager): - script_path = Path(__file__).parent / "mock_never_ready_process.py" - - # 创建一个永不就绪的mock进程脚本 - with open(script_path, 'w') as f: - f.write(''' - import time - while True: - time.sleep(1) - ''') - - # 启动进程 - command = ['python3', str(script_path)] - process_manager.start(command) - - # 设置一个较短的超时时间进行测试 - short_timeout = 3 - - # 验证等待超时会抛出RuntimeError - with pytest.raises(RuntimeError) as exc_info: - process_manager.wait_until_ready(timeout=short_timeout) - - assert "Process startup timed out" in str(exc_info.value) - - # 清理临时文件 - script_path.unlink() - - -@pytest.fixture(autouse=True) -def cleanup(process_manager): - yield - if process_manager.process: - process_manager.stop() - time.sleep(1) # 确保进程完全终止 diff --git a/src/code/agent/test/unit/services/management_service_test.py b/src/code/agent/test/unit/services/management_service_test.py deleted file mode 100644 index 475b02d..0000000 --- a/src/code/agent/test/unit/services/management_service_test.py +++ /dev/null @@ -1,194 +0,0 @@ -import time -import pytest -from unittest.mock import Mock, patch -from concurrent.futures import ThreadPoolExecutor, as_completed - -from exceptions.exceptions import StateTransitionError -from services.management_service import ManagementService, BackendStatus, Action - -@pytest.fixture -def mock_process_mgr(): - with patch('services.process.backend_process_manager.BackendProcessManager') as mock: - instance = mock.return_value - instance.start = Mock() - instance.wait_until_ready = Mock(side_effect=lambda: time.sleep(1)) - instance.stop = Mock() - yield instance - - -@pytest.fixture -def mock_snapshot_mgr(): - with patch('services.workspace.snapshot_manager.SnapshotManager') as mock: - instance = mock.return_value - - def load_with_delay(snapshot_name): - time.sleep(1) # 延迟1秒 - return {"time_load": 1.0} - - def save_with_delay(snapshot_type): - time.sleep(1) # 延迟1秒 - return {"time_save": 1.0} - - instance.load = Mock(side_effect=load_with_delay) - instance.save = Mock(side_effect=save_with_delay) - yield instance - -@pytest.fixture -def service(mock_process_mgr, mock_snapshot_mgr): - # 确保每次测试都使用新的 service 实例 - ManagementService._instances = {} # 清除单例缓存 - service = ManagementService() - service._process_mgr = mock_process_mgr - service._snapshot_mgr = mock_snapshot_mgr - service._status = BackendStatus.STOPPED # 确保初始状态为 STOPPED - service._latest_action = None # 重置最后操作 - service._sub_status = "" # 重置子状态 - return service - -def test_concurrent_start(service, mock_process_mgr, mock_snapshot_mgr): - thread_count = 3 - - def start_service(): - try: - initial_status = service.status - service.start("test_snapshot") - return { - 'success': True, - 'initial_status': initial_status, - 'final_status': service.status - } - except Exception as e: - return { - 'success': False, - 'initial_status': service.status, - 'final_status': service.status, - 'error': str(e) - } - - with ThreadPoolExecutor(max_workers=thread_count) as executor: - futures = [executor.submit(start_service) for _ in range(thread_count)] - start_attempts = [f.result() for f in as_completed(futures)] - - successful_starts = [attempt for attempt in start_attempts if attempt['success']] - failed_starts = [attempt for attempt in start_attempts if not attempt['success']] - - assert len(successful_starts) == 1 - assert len(failed_starts) == thread_count - 1 - - successful_start = successful_starts[0] - assert successful_start['initial_status'] == BackendStatus.STOPPED - assert successful_start['final_status'] == BackendStatus.RUNNING - - for failed_start in failed_starts: - assert "Illegal state transition" in str(failed_start['error']) - - assert service.status == BackendStatus.RUNNING - assert mock_snapshot_mgr.load.call_count == 1 - assert mock_process_mgr.start.call_count == 1 - assert mock_process_mgr.wait_until_ready.call_count == 1 - -def test_start_during_starting_fails(service): - def slow_start(): - service._transition_to(BackendStatus.STARTING, Action.START) - time.sleep(0.5) - service._transition_to(BackendStatus.RUNNING, Action.START) - - with ThreadPoolExecutor(max_workers=2) as executor: - first_start = executor.submit(slow_start) - time.sleep(0.1) - - with pytest.raises(StateTransitionError): - service.start("test_snapshot") - - first_start.result() - - assert service.status == BackendStatus.RUNNING - -def test_initial_status(service): - assert service.status == BackendStatus.STOPPED - -def test_start_success(service, mock_process_mgr, mock_snapshot_mgr): - service.start("test_snapshot") - - mock_snapshot_mgr.load.assert_called_once_with("test_snapshot") - mock_process_mgr.start.assert_called_once() - mock_process_mgr.wait_until_ready.assert_called_once() - assert service.status == BackendStatus.RUNNING - -def test_start_failure(service, mock_process_mgr, mock_snapshot_mgr): - mock_process_mgr.start.side_effect = Exception("Start failed") - - with pytest.raises(Exception): - service.start("test_snapshot") - - assert service.status == BackendStatus.STOPPED - -def test_save_success(service, mock_snapshot_mgr): - service._status = BackendStatus.RUNNING - service.save("test_type") - - mock_snapshot_mgr.save.assert_called_once_with("test_type") - assert service.status == BackendStatus.RUNNING - -def test_save_failure(service, mock_snapshot_mgr): - service._status = BackendStatus.RUNNING - mock_snapshot_mgr.save.side_effect = Exception("Save failed") - - with pytest.raises(Exception): - service.save("test_type") - - assert service.status == BackendStatus.RUNNING - -def test_stop_success(service, mock_process_mgr): - service._status = BackendStatus.RUNNING - service.stop() - - mock_process_mgr.stop.assert_called_once() - assert service.status == BackendStatus.STOPPED - -def test_stop_failure(service, mock_process_mgr): - service._status = BackendStatus.RUNNING - mock_process_mgr.stop.side_effect = Exception("Stop failed") - - with pytest.raises(Exception): - service.stop() - - assert service.status == BackendStatus.RUNNING - -def test_save_and_stop(service): - service._status = BackendStatus.RUNNING - with patch.object(service, 'save') as mock_save: - with patch.object(service, 'stop') as mock_stop: - service.save_and_stop("test_type") - - mock_save.assert_called_once_with("test_type") - mock_stop.assert_called_once() - -def test_invalid_transition(service): - with pytest.raises(StateTransitionError): - service._transition_to(BackendStatus.RUNNING, Action.START) - -@pytest.mark.parametrize("current_status,new_status,action", [ - (BackendStatus.STOPPED, BackendStatus.STARTING, Action.START), - (BackendStatus.STARTING, BackendStatus.RUNNING, Action.START), - (BackendStatus.RUNNING, BackendStatus.SAVING, Action.SAVE), - (BackendStatus.SAVING, BackendStatus.RUNNING, Action.SAVE), - (BackendStatus.RUNNING, BackendStatus.STOPPING, Action.STOP), - (BackendStatus.STOPPING, BackendStatus.STOPPED, Action.STOP), -]) -def test_valid_transitions(service, current_status, new_status, action): - service._status = current_status - service._transition_to(new_status, action) - assert service.status == new_status - assert service.latest_action == action - -@pytest.mark.parametrize("current_status,new_status,action", [ - (BackendStatus.STOPPED, BackendStatus.RUNNING, Action.START), - (BackendStatus.RUNNING, BackendStatus.STARTING, Action.START), - (BackendStatus.SAVING, BackendStatus.STOPPED, Action.STOP), - (BackendStatus.STOPPING, BackendStatus.SAVING, Action.SAVE), -]) -def test_invalid_transitions(service, current_status, new_status, action): - service._status = current_status - with pytest.raises(StateTransitionError): - service._transition_to(new_status, action) diff --git a/src/code/agent/test/unit/services/mock_comfyui_process.py b/src/code/agent/test/unit/services/mock_comfyui_process.py deleted file mode 100644 index ee0a361..0000000 --- a/src/code/agent/test/unit/services/mock_comfyui_process.py +++ /dev/null @@ -1,84 +0,0 @@ -import signal -import time -import sys -import os -from flask import Flask -import threading -from werkzeug.serving import make_server - -app = Flask(__name__) -server = None - - -class ServerThread(threading.Thread): - def __init__(self, app): - threading.Thread.__init__(self) - self.server = make_server('127.0.0.1', 8188, app) - self.ctx = app.app_context() - self.ctx.push() - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -@app.route('/', methods=['GET']) -def hello(): - return "Flask server is running!" - - -def signal_handler(signum, frame): - print(f"Received signal {signum}") - if signum == signal.SIGTERM: - print("Shutting down...") - if server: - server.shutdown() - sys.exit(0) - elif signum == signal.SIGHUP: - print("Restarting server...") - if server: - server.shutdown() - os.execv(sys.executable, ['python3'] + sys.argv) - - -def main(): - current_pid = os.getpid() - print(f"Current process PID: {current_pid}") - - # 注册信号处理器 - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGHUP, signal_handler) - - # 模拟启动过程 - print("Flask server starting...") - counter = 0 - while counter < 3: - print(f"Flask server boot log message #{counter}") - counter += 1 - time.sleep(1) - - try: - # 创建并启动服务器线程 - global server - server = ServerThread(app) - server.start() - print("Server is listening on port 8188") - - # 主线程继续输出日志 - counter = 0 - while True: - print(f"Flask server log message #{counter}") - counter += 1 - time.sleep(1) - - except Exception as e: - print(f"Failed to start server: {e}") - if server: - server.shutdown() - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/src/code/agent/test/unit/services/pip/__init__.py b/src/code/agent/test/unit/services/pip/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/code/agent/test/unit/services/pip/pip_installer_test.py b/src/code/agent/test/unit/services/pip/pip_installer_test.py deleted file mode 100644 index 45610a9..0000000 --- a/src/code/agent/test/unit/services/pip/pip_installer_test.py +++ /dev/null @@ -1,1347 +0,0 @@ -import os -import shutil -import tempfile -import unittest -from unittest.mock import patch, MagicMock, mock_open - -from services.pip.pip_installer import PIPInstaller, DependencyInfo, InstallRecord, DependencyInstallRecord - - -class TestPIPInstaller(unittest.TestCase): - def setUp(self): - """测试初始化""" - self.test_dir = tempfile.mkdtemp() - self.comfyui_dir = os.path.join(self.test_dir, "comfyui") - self.custom_nodes_dir = os.path.join(self.comfyui_dir, "custom_nodes") - os.makedirs(self.custom_nodes_dir, exist_ok=True) - - # Mock constants - self.constants_patcher = patch('services.pip.pip_installer.constants') - self.mock_constants = self.constants_patcher.start() - self.mock_constants.COMFYUI_DIR = self.comfyui_dir - self.mock_constants.VENV_EXECUTABLE = "/test/venv/bin/python" - - def tearDown(self): - """测试清理""" - self.constants_patcher.stop() - shutil.rmtree(self.test_dir, ignore_errors=True) - - def create_test_node_dir(self, node_name): - """创建测试节点目录""" - node_dir = os.path.join(self.custom_nodes_dir, node_name) - os.makedirs(node_dir, exist_ok=True) - return node_dir - - def create_requirements_file(self, node_dir, content): - """创建 requirements.txt 文件""" - req_file = os.path.join(node_dir, "requirements.txt") - with open(req_file, "w") as f: - f.write(content) - return req_file - - def create_install_script(self, node_dir, content): - """创建 install.py 脚本""" - install_file = os.path.join(node_dir, "install.py") - with open(install_file, "w") as f: - f.write(content) - return install_file - - -class TestDataStructures(TestPIPInstaller): - """测试数据结构""" - - def test_install_record_creation(self): - """测试 InstallRecord 创建""" - record = InstallRecord( - node_name="test_node", - script_name="install.py", - duration=10.5, - success=True, - error_msg="" - ) - - self.assertEqual(record.node_name, "test_node") - self.assertEqual(record.script_name, "install.py") - self.assertEqual(record.duration, 10.5) - self.assertTrue(record.success) - self.assertEqual(record.error_msg, "") - - # 测试 to_dict 方法 - record_dict = record.to_dict() - expected = { - "node_name": "test_node", - "script_name": "install.py", - "duration": 10.5, - "success": True, - "error_msg": "" - } - self.assertEqual(record_dict, expected) - - def test_dependency_install_record_creation(self): - """测试 DependencyInstallRecord 创建""" - requirements_txt = "torch>=1.8.0\nnumpy>=1.20.0" - record = DependencyInstallRecord( - requirements_txt=requirements_txt, - duration=120.0, - success=False, - error_msg="Installation timeout" - ) - - self.assertEqual(record.requirements_txt, requirements_txt) - self.assertEqual(record.duration, 120.0) - self.assertFalse(record.success) - self.assertEqual(record.error_msg, "Installation timeout") - - # 测试 to_dict 方法 - record_dict = record.to_dict() - expected = { - "requirements_txt": requirements_txt, - "duration": 120.0, - "success": False, - "error_msg": "Installation timeout" - } - self.assertEqual(record_dict, expected) - - -class TestPIPInstallerInit(TestPIPInstaller): - """测试初始化功能""" - - @patch('services.pip.pip_installer.subprocess.check_output') - def test_init_without_blacklist(self, mock_subprocess): - """测试不带黑名单的初始化""" - mock_subprocess.return_value = "Package Version\nrequests 2.28.0\nnumpy 1.21.0" - - installer = PIPInstaller() - - self.assertEqual(installer.blacklist, set()) - self.assertIsInstance(installer._origin_packages, dict) - # 不再有 _history 成员变量 - - @patch('services.pip.pip_installer.subprocess.check_output') - def test_init_with_blacklist(self, mock_subprocess): - """测试带黑名单的初始化""" - mock_subprocess.return_value = "Package Version\nrequests 2.28.0" - blacklist = ["torch", "tensorflow"] - - installer = PIPInstaller(blacklist=blacklist) - - self.assertEqual(installer.blacklist, {"torch", "tensorflow"}) - - -class TestPackageSpecParsing(TestPIPInstaller): - """测试包规范解析功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_parse_regular_package_with_version(self): - """测试普通包带版本的解析""" - test_cases = [ - ("torch>=1.8.0", ("torch", ">=1.8.0")), - ("numpy==1.21.0", ("numpy", "==1.21.0")), - ("requests>=2.25.0,<3.0.0", ("requests", ">=2.25.0,<3.0.0")), - ("pillow>=8.0.0,!=8.3.0", ("pillow", ">=8.0.0,!=8.3.0")), - ] - - for package_spec, expected in test_cases: - with self.subTest(package_spec=package_spec): - result = self.installer._parse_package_spec(package_spec) - self.assertEqual(result, expected) - - def test_parse_regular_package_without_version(self): - """测试普通包不带版本的解析""" - result = self.installer._parse_package_spec("requests") - self.assertEqual(result, ("requests", "")) - - def test_parse_git_dependencies(self): - """测试 git+ 依赖的解析""" - test_cases = [ - "git+https://github.com/user/repo.git", - "git+https://github.com/user/repo.git@main", - "git+https://github.com/user/repo.git#egg=package", - "hg+https://bitbucket.org/user/repo", - "svn+https://svn.example.com/repo", - "bzr+https://bzr.example.com/repo", - ] - - for git_url in test_cases: - with self.subTest(git_url=git_url): - result = self.installer._parse_package_spec(git_url) - self.assertEqual(result, (git_url, "")) - - def test_parse_invalid_package_spec(self): - """测试无效包规范的解析""" - result = self.installer._parse_package_spec("invalid@#$%^&*()") - self.assertEqual(result, ("invalid@#$%^&*()", "")) - - -class TestVersionConflictResolution(TestPIPInstaller): - """测试版本冲突解决功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_exact_vs_exact_newer_wins(self): - """测试精确版本 vs 精确版本:选择较新的""" - result = self.installer._resolve_version_conflict("==1.0.0", "==2.0.0", "torch") - self.assertEqual(result, "==2.0.0") - - def test_exact_vs_exact_keep_newer(self): - """测试精确版本 vs 精确版本:保持较新的""" - result = self.installer._resolve_version_conflict("==2.0.0", "==1.0.0", "torch") - self.assertEqual(result, "==2.0.0") - - def test_exact_vs_range_keep_exact(self): - """测试精确版本 vs 范围版本:保持精确版本""" - result = self.installer._resolve_version_conflict("==1.5.0", ">=1.0.0", "numpy") - self.assertEqual(result, "==1.5.0") - - def test_range_vs_exact_choose_exact(self): - """测试范围版本 vs 精确版本:选择精确版本""" - result = self.installer._resolve_version_conflict(">=1.0.0", "==1.5.0", "numpy") - self.assertEqual(result, "==1.5.0") - - @patch('services.pip.pip_installer.PIPInstaller._try_merge_version_ranges') - def test_range_vs_range_merge(self, mock_merge): - """测试范围版本 vs 范围版本:尝试合并""" - mock_merge.return_value = ">=1.5.0" - - result = self.installer._resolve_version_conflict(">=1.0.0", ">=1.5.0", "torch") - - mock_merge.assert_called_once_with(">=1.0.0", ">=1.5.0") - self.assertEqual(result, ">=1.5.0") - - def test_same_version_specs(self): - """测试相同版本规范""" - result = self.installer._resolve_version_conflict(">=1.0.0", ">=1.0.0", "torch") - self.assertEqual(result, ">=1.0.0") - - def test_empty_specs(self): - """测试空版本规范""" - result = self.installer._resolve_version_conflict("", ">=1.0.0", "torch") - self.assertEqual(result, ">=1.0.0") - - result = self.installer._resolve_version_conflict(">=1.0.0", "", "torch") - self.assertEqual(result, ">=1.0.0") - - -class TestVersionRangeMerging(TestPIPInstaller): - """测试版本范围合并功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - @patch('services.pip.pip_installer.PIPInstaller._try_merge_with_packaging') - def test_merge_with_packaging_success(self, mock_packaging): - """测试使用 packaging 库成功合并""" - mock_packaging.return_value = ">=1.5.0,<2.0.0" - - result = self.installer._try_merge_version_ranges(">=1.0.0,<2.0.0", ">=1.5.0,<3.0.0") - - self.assertEqual(result, ">=1.5.0,<2.0.0") - mock_packaging.assert_called_once() - - @patch('services.pip.pip_installer.PIPInstaller._try_merge_with_packaging') - @patch('services.pip.pip_installer.PIPInstaller._try_merge_simple_ranges') - def test_merge_fallback_to_simple(self, mock_simple, mock_packaging): - """测试回退到简单合并逻辑""" - mock_packaging.return_value = None - mock_simple.return_value = ">=1.5.0" - - result = self.installer._try_merge_version_ranges(">=1.0.0", ">=1.5.0") - - self.assertEqual(result, ">=1.5.0") - mock_packaging.assert_called_once() - mock_simple.assert_called_once() - - def test_simple_range_merge_success(self): - """测试简单范围合并成功""" - result = self.installer._try_merge_simple_ranges(">=1.0.0", ">=1.5.0") - self.assertEqual(result, ">=1.5.0") - - result = self.installer._try_merge_simple_ranges(">=2.0.0", ">=1.5.0") - self.assertEqual(result, ">=2.0.0") - - def test_simple_range_merge_failure(self): - """测试简单范围合并失败""" - result = self.installer._try_merge_simple_ranges(">=1.0.0,<2.0.0", ">=1.5.0") - self.assertIsNone(result) - - @patch('builtins.__import__') - def test_merge_with_packaging_import_error(self, mock_import): - """测试 packaging 库导入失败""" - mock_import.side_effect = ImportError("No module named 'packaging'") - - result = self.installer._try_merge_with_packaging(">=1.0.0", ">=1.5.0") - - self.assertIsNone(result) - - -class TestDependencyFiltering(TestPIPInstaller): - """测试依赖过滤功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output') as mock_subprocess: - mock_subprocess.return_value = "Package Version\nrequests 2.28.0\nnumpy 1.21.0" - self.installer = PIPInstaller(blacklist=["torch", "tensorflow"]) - - def test_filter_git_dependencies(self): - """测试过滤 git+ 依赖""" - self.installer._merged_dependencies = { - "git+https://github.com/user/repo.git": DependencyInfo( - package_name="git+https://github.com/user/repo.git", - version_spec="", - original_line="git+https://github.com/user/repo.git", - source_nodes=["node1"] - ), - "numpy": DependencyInfo( - package_name="numpy", - version_spec=">=1.20.0", - original_line="numpy>=1.20.0", - source_nodes=["node1"] - ) - } - - with patch('builtins.print'): # Suppress print output - filtered = self.installer._filter_merged_dependencies() # 更新方法名 - - self.assertNotIn("git+https://github.com/user/repo.git", filtered) - # numpy虽然已安装,但有版本要求,不应被过滤(让pip处理版本升级) - self.assertIn("numpy", filtered) - - def test_filter_blacklisted_packages(self): - """测试过滤黑名单包""" - self.installer._merged_dependencies = { - "torch": DependencyInfo( - package_name="torch", - version_spec=">=1.8.0", - original_line="torch>=1.8.0", - source_nodes=["node1"] - ), - "requests": DependencyInfo( - package_name="requests", - version_spec=">=2.25.0", - original_line="requests>=2.25.0", - source_nodes=["node1"] - ) - } - - with patch('builtins.print'): - filtered = self.installer._filter_merged_dependencies() - - self.assertNotIn("torch", filtered) # Blacklisted - # requests虽然已安装,但有版本要求,不应被过滤(让pip处理版本升级) - self.assertIn("requests", filtered) - - def test_filter_already_installed_packages_with_version_requirement(self): - """测试已安装但有版本要求的包不被过滤(交给pip处理升级)""" - self.installer._merged_dependencies = { - "numpy": DependencyInfo( # numpy已安装,但有版本要求 - package_name="numpy", - version_spec=">=1.25.0", # 需要升级到 1.25.0+ - original_line="numpy>=1.25.0", - source_nodes=["node1"] - ), - "new-package": DependencyInfo( # 新包,未安装 - package_name="new-package", - version_spec=">=1.0.0", - original_line="new-package>=1.0.0", - source_nodes=["node1"] - ) - } - - with patch('builtins.print'): - filtered = self.installer._filter_merged_dependencies() - - # numpy虽然已安装,但有版本要求,不应被过滤 - self.assertIn("numpy", filtered) - # 新包未安装,不应被过滤 - self.assertIn("new-package", filtered) - - def test_filter_already_installed_packages_without_version_requirement(self): - """测试已安装且无版本要求的包被过滤""" - self.installer._merged_dependencies = { - "numpy": DependencyInfo( # numpy已安装,无版本要求 - package_name="numpy", - version_spec="", # 无版本要求 - original_line="numpy", - source_nodes=["node1"] - ), - "requests": DependencyInfo( # requests已安装,无版本要求 - package_name="requests", - version_spec="", - original_line="requests", - source_nodes=["node1"] - ) - } - - with patch('builtins.print'): - filtered = self.installer._filter_merged_dependencies() - - # 无版本要求且已安装的包应被过滤 - self.assertNotIn("numpy", filtered) - self.assertNotIn("requests", filtered) - - -class TestRequirementsMerging(TestPIPInstaller): - """测试 requirements.txt 合并功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - @patch('services.pip.pip_installer.file_ops.robust_readlines') - def test_merge_requirements_file_new_dependency(self, mock_readlines): - """测试合并新依赖""" - mock_readlines.return_value = ["torch>=1.8.0\n", "numpy==1.21.0\n"] - - self.installer._merge_requirements_file("/fake/requirements.txt", "test_node") - - self.assertEqual(len(self.installer._merged_dependencies), 2) - self.assertIn("torch", self.installer._merged_dependencies) - self.assertIn("numpy", self.installer._merged_dependencies) - - torch_dep = self.installer._merged_dependencies["torch"] - self.assertEqual(torch_dep.version_spec, ">=1.8.0") - self.assertEqual(torch_dep.source_nodes, ["test_node"]) - - @patch('services.pip.pip_installer.file_ops.robust_readlines') - def test_merge_requirements_file_conflicting_dependency(self, mock_readlines): - """测试合并冲突依赖""" - # 先添加一个依赖 - self.installer._merged_dependencies["torch"] = DependencyInfo( - package_name="torch", - version_spec=">=1.7.0", - original_line="torch>=1.7.0", - source_nodes=["node1"] - ) - - mock_readlines.return_value = ["torch>=1.8.0\n"] - - with patch.object(self.installer, '_resolve_version_conflict') as mock_resolve: - mock_resolve.return_value = ">=1.8.0" - - self.installer._merge_requirements_file("/fake/requirements.txt", "node2") - - mock_resolve.assert_called_once_with(">=1.7.0", ">=1.8.0", "torch") - torch_dep = self.installer._merged_dependencies["torch"] - self.assertEqual(torch_dep.version_spec, ">=1.8.0") - self.assertEqual(set(torch_dep.source_nodes), {"node1", "node2"}) - - @patch('services.pip.pip_installer.file_ops.robust_readlines') - def test_merge_requirements_with_git_dependencies(self, mock_readlines): - """测试合并包含 git+ 依赖的 requirements""" - mock_readlines.return_value = [ - "git+https://github.com/user/repo.git\n", - "torch>=1.8.0\n" - ] - - self.installer._merge_requirements_file("/fake/requirements.txt", "test_node") - - self.assertEqual(len(self.installer._merged_dependencies), 2) - self.assertIn("git+https://github.com/user/repo.git", self.installer._merged_dependencies) - self.assertIn("torch", self.installer._merged_dependencies) - - -class TestRequirementsGeneration(TestPIPInstaller): - """测试 requirements.txt 生成功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_generate_requirements_content_with_dependencies(self): - """测试生成含有依赖的 requirements.txt 内容""" - filtered_deps = { - "torch": DependencyInfo( - package_name="torch", - version_spec=">=1.8.0", - original_line="torch>=1.8.0", - source_nodes=["node1", "node2"] - ), - "numpy": DependencyInfo( - package_name="numpy", - version_spec="==1.21.0", - original_line="numpy==1.21.0", - source_nodes=["node1"] - ) - } - - with patch('builtins.print'): # Suppress print output - content = self.installer._generate_requirements_content(filtered_deps) - - expected_lines = ["numpy==1.21.0", "torch>=1.8.0"] # 按字母顺序排列 - self.assertEqual(content, "\n".join(expected_lines)) - - def test_generate_requirements_content_empty(self): - """测试生成空的 requirements.txt 内容""" - with patch('builtins.print'): # Suppress print output - content = self.installer._generate_requirements_content({}) - - self.assertEqual(content, "") - - -class TestInstallationProcess(TestPIPInstaller): - """测试安装过程""" - - def setUp(self): - super().setUp() - # 创建测试节点 - self.node1_dir = self.create_test_node_dir("node1") - self.node2_dir = self.create_test_node_dir("node2") - - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - @patch('tempfile.NamedTemporaryFile') - @patch('subprocess.run') - @patch('os.unlink') - def test_install_merged_dependencies_success(self, mock_unlink, mock_subprocess_run, mock_tempfile): - """测试依赖批量安装成功""" - # Mock 临时文件 - mock_file = MagicMock() - mock_file.name = "/tmp/test_requirements.txt" - mock_tempfile.return_value.__enter__.return_value = mock_file - - # Mock subprocess.run 返回成功 - mock_result = MagicMock() - mock_result.returncode = 0 - mock_result.stdout = "Successfully installed torch numpy" - mock_result.stderr = "" - mock_subprocess_run.return_value = mock_result - - requirements_content = "torch>=1.8.0\nnumpy>=1.20.0" - - with patch('time.time', side_effect=[0, 10]): # Mock start and end time - result = self.installer._install_merged_dependencies(requirements_content, timeout=300) - - # 验证结果 - self.assertIsInstance(result, DependencyInstallRecord) - self.assertTrue(result.success) - self.assertEqual(result.requirements_txt, requirements_content) - self.assertEqual(result.duration, 10.0) - self.assertEqual(result.error_msg, "") - - # 验证调用 - mock_tempfile.assert_called_once() - mock_file.write.assert_called_once_with(requirements_content) - mock_subprocess_run.assert_called_once() - mock_unlink.assert_called_once_with("/tmp/test_requirements.txt") - - @patch('tempfile.NamedTemporaryFile') - @patch('subprocess.run') - @patch('os.unlink') - def test_install_merged_dependencies_failure(self, mock_unlink, mock_subprocess_run, mock_tempfile): - """测试依赖批量安装失败""" - # Mock 临时文件 - mock_file = MagicMock() - mock_file.name = "/tmp/test_requirements.txt" - mock_tempfile.return_value.__enter__.return_value = mock_file - - # Mock subprocess.run 返回失败 - mock_result = MagicMock() - mock_result.returncode = 1 - mock_result.stdout = "" - mock_result.stderr = "ERROR: No matching distribution found for invalid-package" - mock_subprocess_run.return_value = mock_result - - requirements_content = "invalid-package>=1.0.0" - - with patch('time.time', side_effect=[0, 5]): # Mock start and end time - result = self.installer._install_merged_dependencies(requirements_content, timeout=300) - - # 验证结果 - self.assertIsInstance(result, DependencyInstallRecord) - self.assertFalse(result.success) - self.assertEqual(result.requirements_txt, requirements_content) - self.assertEqual(result.duration, 5.0) - self.assertIn("pip install failed with return code 1", result.error_msg) - - def test_do_install_script_success(self): - """测试 install.py 脚本执行成功""" - with patch.object(self.installer, '_install_script', return_value=True), \ - patch('time.time', side_effect=[0, 2.5]): - - result = self.installer._do_install_script( - self.node1_dir, "node1", "install.py", ["python", "install.py"] - ) - - # 验证结果是字典形式 - self.assertIsInstance(result, dict) - self.assertEqual(result["node_name"], "node1") - self.assertEqual(result["script_name"], "install.py") - self.assertTrue(result["success"]) - self.assertEqual(result["duration"], 2.5) - self.assertEqual(result["error_msg"], "") - - def test_do_install_script_failure(self): - """测试 install.py 脚本执行失败""" - with patch.object(self.installer, '_install_script', side_effect=Exception("Script failed")), \ - patch('time.time', side_effect=[0, 1.0]): - - result = self.installer._do_install_script( - self.node1_dir, "node1", "install.py", ["python", "install.py"] - ) - - # 验证结果是字典形式 - self.assertIsInstance(result, dict) - self.assertEqual(result["node_name"], "node1") - self.assertEqual(result["script_name"], "install.py") - self.assertFalse(result["success"]) - self.assertEqual(result["duration"], 1.0) - self.assertEqual(result["error_msg"], "Script failed") - - -class TestFullIntegration(TestPIPInstaller): - """测试完整安装流程""" - - def setUp(self): - super().setUp() - # 创建测试节点和文件 - self.node1_dir = self.create_test_node_dir("node1") - self.node2_dir = self.create_test_node_dir("node2") - self.node3_dir = self.create_test_node_dir("node3") - - # 创建 requirements.txt 文件 - self.create_requirements_file(self.node1_dir, "torch>=1.8.0\nnumpy>=1.20.0\n") - self.create_requirements_file(self.node2_dir, "torch>=1.9.0\nrequests>=2.25.0\n") - self.create_requirements_file(self.node3_dir, "git+https://github.com/user/repo.git\n") - - # 创建 install.py 脚本 - self.create_install_script(self.node2_dir, "print('Installing node2')\n") - - @patch('services.pip.pip_installer.subprocess.check_output') - @patch('builtins.print') - def test_install_all_full_process(self, mock_print, mock_check_output): - """测试完整安装流程(不真正执行 pip)""" - # Mock pip list 输出 - mock_check_output.return_value = "Package Version\nrequests 2.28.0" - - installer = PIPInstaller() - - # Mock 生成 requirements 内容 - with patch.object(installer, '_install_merged_dependencies') as mock_dep_install, \ - patch.object(installer, '_execute_install_scripts') as mock_scripts: - - # 假设依赖安装成功 - dep_record = DependencyInstallRecord(requirements_txt="torch>=1.9.0\nnumpy>=1.20.0", duration=12.3, success=True, error_msg="") - mock_dep_install.return_value = dep_record - - # 假设脚本执行成功 - mock_scripts.return_value = [ - InstallRecord(node_name="node2", script_name="install.py", duration=1.2, success=True, error_msg="").to_dict() - ] - - result_map = installer.install_all() - - # 验证返回结构 - self.assertIsInstance(result_map, dict) - self.assertIn("baseline", result_map) - self.assertIn("dependencies", result_map) - self.assertIn("scripts", result_map) - - self.assertTrue(result_map["dependencies"]["success"]) # 依赖安装成功 - self.assertEqual(len(result_map["scripts"]), 1) # 一个脚本被执行 - - # 验证合并的依赖 - self.assertIn("torch", installer._merged_dependencies) - self.assertIn("numpy", installer._merged_dependencies) - self.assertIn("git+https://github.com/user/repo.git", installer._merged_dependencies) - - # 验证版本冲突解决(torch: >=1.8.0 vs >=1.9.0 -> >=1.9.0) - torch_dep = installer._merged_dependencies["torch"] - self.assertEqual(torch_dep.version_spec, ">=1.9.0") - - @patch('services.pip.pip_installer.subprocess.check_output') - def test_install_all_with_nodes_map(self, mock_subprocess): - """测试指定节点的安装""" - mock_subprocess.return_value = "Package Version\n" - - installer = PIPInstaller() - nodes_map = {"node1": {}, "node3": {}} - - with patch.object(installer, '_install_merged_dependencies') as mock_dep_install, \ - patch.object(installer, '_execute_install_scripts') as mock_scripts, \ - patch('builtins.print'): - - dep_record = DependencyInstallRecord(requirements_txt="numpy>=1.20.0\ntorch>=1.8.0", duration=1.0, success=True, error_msg="") - mock_dep_install.return_value = dep_record - mock_scripts.return_value = [] - - result_map = installer.install_all(nodes_map=nodes_map) - - # 验证只处理了指定的节点 - merged_deps = installer._merged_dependencies - self.assertIn("torch", merged_deps) # from node1 - self.assertIn("numpy", merged_deps) # from node1 - self.assertIn("git+https://github.com/user/repo.git", merged_deps) # from node3 - self.assertNotIn("requests", merged_deps) # node2 not selected - - @patch('services.pip.pip_installer.subprocess.check_output') - @patch('builtins.print') - def test_install_all_timeout_behavior(self, mock_print, mock_subprocess): - """测试超时捕获行为(install_all 会捕获 TimeoutError 不再抛出)""" - mock_subprocess.return_value = "Package Version\n" - - installer = PIPInstaller() - - # 创建一个测试节点目录,确保 _merge_requirements_from_nodes 有内容 - test_node_dir = self.create_test_node_dir("test_timeout_node") - self.create_requirements_file(test_node_dir, "some-package>=1.0.0\n") - - # 模拟时间推进,触发 _merge_requirements_from_nodes 的 timeout - # install_all 会捕获 TimeoutError 并打印日志,然后正常返回 - with patch('time.time', side_effect=[0, 1000, 1001, 1002]): - result_map = installer.install_all(timeout=300) - - # 验证返回结果是有效的 - self.assertIsInstance(result_map, dict) - self.assertIn("baseline", result_map) - self.assertIn("dependencies", result_map) - self.assertIn("scripts", result_map) - - # 验证打印了超时日志 - printed_timeout = any("Timeout (" in str(call) for call in mock_print.call_args_list) - self.assertTrue(printed_timeout) # 应该打印超时信息 - - -class TestUtilityMethods(TestPIPInstaller): - """测试工具方法""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_extract_version_from_exact(self): - """测试从精确版本约束中提取版本号""" - self.assertEqual(self.installer._extract_version_from_exact("==1.8.0"), "1.8.0") - self.assertEqual(self.installer._extract_version_from_exact("==2.1.5"), "2.1.5") - self.assertEqual(self.installer._extract_version_from_exact("invalid"), "invalid") - - def test_is_version_newer(self): - """测试版本比较""" - self.assertTrue(self.installer._is_version_newer("2.0.0", "1.0.0")) - self.assertTrue(self.installer._is_version_newer("1.5.0", "1.4.9")) - self.assertFalse(self.installer._is_version_newer("1.0.0", "2.0.0")) - self.assertFalse(self.installer._is_version_newer("1.4.9", "1.5.0")) - self.assertFalse(self.installer._is_version_newer("1.0.0", "1.0.0")) - - def test_is_version_newer_invalid_format(self): - """测试无效版本格式的比较""" - # 当版本格式无效时,应该进行字符串比较 - self.assertTrue(self.installer._is_version_newer("b", "a")) - self.assertFalse(self.installer._is_version_newer("a", "b")) - @patch('services.pip.pip_installer.subprocess.check_output') - def test_try_get_installed_packages_success(self, mock_subprocess): - """测试获取已安装包列表成功""" - mock_subprocess.return_value = """Package Version ----------- ------- -requests 2.28.0 -numpy 1.21.0""" - - installer = PIPInstaller() - packages = installer.get_origin_packages() - - expected = {"requests": "2.28.0", "numpy": "1.21.0"} - self.assertEqual(packages, expected) - - @patch('services.pip.pip_installer.subprocess.check_output') - def test_try_get_installed_packages_failure(self, mock_subprocess): - """测试获取已安装包列表失败""" - from subprocess import CalledProcessError - mock_subprocess.side_effect = CalledProcessError(1, "pip list") - - with patch('builtins.print'): # Suppress error output - installer = PIPInstaller() - - self.assertEqual(installer.get_origin_packages(), {}) - - -class TestNewIntegrationFlow(TestPIPInstaller): - """测试重构后的集成流程""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_empty_nodes_map(self): - """测试空节点映射的情况""" - with patch('services.pip.pip_installer.subprocess.check_output') as mock_subprocess: - mock_subprocess.return_value = "Package Version\n" - installer = PIPInstaller() - - # 传递空字典,应该不安装任何节点 - result_map = installer.install_all(timeout=60, nodes_map={}) - - # 验证返回结构 - self.assertIsInstance(result_map, dict) - self.assertIn("baseline", result_map) - self.assertIn("dependencies", result_map) - self.assertIn("scripts", result_map) - - # 验证空节点映射的情况 - self.assertEqual(len(result_map["scripts"]), 0) # 无脚本执行 - self.assertTrue(result_map["dependencies"]["success"]) # 依赖安装成功(但无内容) - self.assertEqual(result_map["dependencies"]["requirements_txt"], "") # 无依赖内容 - - def test_return_structure_consistency(self): - """测试返回结构的一致性""" - with patch('services.pip.pip_installer.subprocess.check_output') as mock_subprocess: - mock_subprocess.return_value = "Package Version\n" - - installer = PIPInstaller() - - # 测试不同参数的情况 - test_cases = [ - None, # 安装所有节点 - {}, # 不安装任何节点 - {"nonexistent": {}}, # 不存在的节点 - ] - - for nodes_map in test_cases: - with self.subTest(nodes_map=nodes_map): - result_map = installer.install_all(timeout=10, nodes_map=nodes_map) - - # 验证返回结构一致性 - self.assertIsInstance(result_map, dict) - self.assertIn("baseline", result_map) - self.assertIn("dependencies", result_map) - self.assertIn("scripts", result_map) - - # 验证 dependencies 结构 - deps = result_map["dependencies"] - self.assertIn("requirements_txt", deps) - self.assertIn("duration", deps) - self.assertIn("success", deps) - self.assertIn("error_msg", deps) - - # 验证 scripts 结构 - self.assertIsInstance(result_map["scripts"], list) - - -class TestCustomDependencyStrategies(TestPIPInstaller): - """测试定制化依赖策略功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_apply_custom_dependency_strategies_basic(self): - """测试定制化依赖策略基本功能""" - filtered_deps = { - "requests": DependencyInfo( - package_name="requests", - version_spec=">=2.25.0", - original_line="requests>=2.25.0", - source_nodes=["test_node"] - ) - } - - nodes_to_install = ["test_node"] - nodes_map = {"test_node": {}} - - with patch('builtins.print'): # Suppress print output - result = self.installer._apply_custom_dependency_strategies( - filtered_deps, nodes_to_install, nodes_map - ) - - # 应该返回原始的 filtered_deps(无 nunchaku 节点) - self.assertEqual(result, filtered_deps) - self.assertIn("requests", result) - - def test_handle_nunchaku_strategy_no_nunchaku_node(self): - """测试无 ComfyUI-nunchaku 节点的情况""" - filtered_deps = { - "torch": DependencyInfo( - package_name="torch", - version_spec=">=1.8.0", - original_line="torch>=1.8.0", - source_nodes=["other_node"] - ) - } - - nodes_to_install = ["other_node"] - nodes_map = {"other_node": {}} - - result = self.installer._handle_nunchaku_strategy( - filtered_deps, nodes_to_install, nodes_map - ) - - # 无 nunchaku 节点,应该直接返回原始 deps - self.assertEqual(result, filtered_deps) - self.assertNotIn("https://modelscope.cn", str(result)) - - def test_handle_nunchaku_strategy_v1_0_0(self): - """测试 ComfyUI-nunchaku v1.0.0 的特殊处理""" - filtered_deps = { - "requests": DependencyInfo( - package_name="requests", - version_spec=">=2.25.0", - original_line="requests>=2.25.0", - source_nodes=["ComfyUI-nunchaku"] - ) - } - - nodes_to_install = ["ComfyUI-nunchaku"] - nodes_map = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "source": { - "webUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku", - "type": "github", - "cloneUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku.git" - }, - "version": { - "type": "tag", - "value": "v1.0.0" - } - } - } - - with patch('builtins.print'): # Suppress print output - result = self.installer._handle_nunchaku_strategy( - filtered_deps, nodes_to_install, nodes_map - ) - - # 应该添加 nunchaku wheel URL - expected_wheel_url = "https://modelscope.cn/models/nunchaku-tech/nunchaku/resolve/master/nunchaku-1.0.0+torch2.8-cp310-cp310-linux_x86_64.whl" - - self.assertIn(expected_wheel_url, result) - self.assertIn("requests", result) # 原有依赖仍在 - - # 验证 wheel 依赖的结构 - wheel_dep = result[expected_wheel_url] - self.assertEqual(wheel_dep.package_name, expected_wheel_url) - self.assertEqual(wheel_dep.version_spec, "") - self.assertEqual(wheel_dep.source_nodes, ["ComfyUI-nunchaku"]) - - def test_handle_nunchaku_strategy_v1_0_1(self): - """测试 ComfyUI-nunchaku v1.0.1 的特殊处理""" - filtered_deps = { - "requests": DependencyInfo( - package_name="requests", - version_spec=">=2.25.0", - original_line="requests>=2.25.0", - source_nodes=["ComfyUI-nunchaku"] - ) - } - - nodes_to_install = ["ComfyUI-nunchaku"] - nodes_map = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "source": { - "webUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku", - "type": "github", - "cloneUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku.git" - }, - "version": { - "type": "tag", - "value": "v1.0.1" - } - } - } - - with patch('builtins.print'): # Suppress print output - result = self.installer._handle_nunchaku_strategy( - filtered_deps, nodes_to_install, nodes_map - ) - - # 应该添加 nunchaku v1.0.1 的 wheel URL(与 v1.0.0 相同) - expected_wheel_url = "https://modelscope.cn/models/nunchaku-tech/nunchaku/resolve/master/nunchaku-1.0.0+torch2.8-cp310-cp310-linux_x86_64.whl" - - self.assertIn(expected_wheel_url, result) - self.assertIn("requests", result) # 原有依赖仍在 - - # 验证 wheel 依赖的结构 - wheel_dep = result[expected_wheel_url] - self.assertEqual(wheel_dep.package_name, expected_wheel_url) - self.assertEqual(wheel_dep.version_spec, "") - self.assertEqual(wheel_dep.source_nodes, ["ComfyUI-nunchaku"]) - - def test_handle_nunchaku_strategy_other_version(self): - """测试 ComfyUI-nunchaku 其他版本的处理""" - filtered_deps = { - "requests": DependencyInfo( - package_name="requests", - version_spec=">=2.25.0", - original_line="requests>=2.25.0", - source_nodes=["ComfyUI-nunchaku"] - ) - } - - nodes_to_install = ["ComfyUI-nunchaku"] - nodes_map = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "version": { - "type": "tag", - "value": "v0.2.0" - } - } - } - - with patch('builtins.print'): # Suppress print output - result = self.installer._handle_nunchaku_strategy( - filtered_deps, nodes_to_install, nodes_map - ) - - # v0.2.0 不应该添加 wheel URL - self.assertEqual(result, filtered_deps) - self.assertNotIn("https://modelscope.cn", str(result)) - - def test_extract_nunchaku_version_valid_structure(self): - """测试从正确的 nodes_map 结构中提取版本""" - nodes_map = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "version": { - "type": "tag", - "value": "v1.0.0" - } - } - } - - with patch('builtins.print'): # Suppress print output - version = self.installer._extract_nunchaku_version(nodes_map, "ComfyUI-nunchaku") - - self.assertEqual(version, "v1.0.0") - - def test_extract_nunchaku_version_missing_node(self): - """测试从缺少节点的 nodes_map 中提取版本""" - nodes_map = {"other_node": {}} - - with patch('builtins.print'): # Suppress print output - version = self.installer._extract_nunchaku_version(nodes_map, "ComfyUI-nunchaku") - - self.assertEqual(version, "unknown") - - def test_extract_nunchaku_version_invalid_structure(self): - """测试从无效结构的 nodes_map 中提取版本""" - test_cases = [ - # 缺少 version 字段 - {"ComfyUI-nunchaku": {"name": "ComfyUI-nunchaku"}}, - # version 不是字典 - {"ComfyUI-nunchaku": {"version": "v1.0.0"}}, - # version 字典缺少 value - {"ComfyUI-nunchaku": {"version": {"type": "tag"}}}, - # 空的 nodes_map - None, - # 空字典 - {} - ] - - for nodes_map in test_cases: - with self.subTest(nodes_map=nodes_map): - with patch('builtins.print'): # Suppress print output - version = self.installer._extract_nunchaku_version(nodes_map, "ComfyUI-nunchaku") - - self.assertEqual(version, "unknown") - - -class TestProxyEnvironmentHandling(TestPIPInstaller): - """测试代理环境变量处理功能""" - - def setUp(self): - super().setUp() - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_get_script_env_keeps_proxy_vars(self): - """测试 _get_script_env 方法保留代理环境变量(用于install.py脚本)""" - original_env = { - 'PATH': '/usr/bin', - 'HOME': '/home/user', - 'http_proxy': 'http://proxy.example.com:8080', - 'https_proxy': 'https://proxy.example.com:8080', - 'HTTP_PROXY': 'http://proxy.example.com:8080', - 'HTTPS_PROXY': 'https://proxy.example.com:8080', - 'OTHER_VAR': 'value' - } - - with patch('os.environ.copy', return_value=original_env.copy()): - result_env = self.installer._get_script_env() - - # 验证代理变量被保留(不被移除) - proxy_vars = ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY'] - for proxy_var in proxy_vars: - self.assertIn(proxy_var, result_env) - self.assertEqual(result_env[proxy_var], original_env[proxy_var]) - - # 验证其他变量仍在 - self.assertIn('PATH', result_env) - self.assertIn('HOME', result_env) - self.assertIn('OTHER_VAR', result_env) - - # 验证 ComfyUI 相关变量被添加 - self.assertEqual(result_env['COMFYUI_PATH'], self.comfyui_dir) - self.assertEqual(result_env['COMFYUI_FOLDERS_BASE_PATH'], self.comfyui_dir) - - def test_get_pip_install_env_removes_proxy_vars(self): - """测试 _get_pip_install_env 方法移除代理环境变量(用于pip install -r)""" - original_env = { - 'PATH': '/usr/bin', - 'HOME': '/home/user', - 'http_proxy': 'http://proxy.example.com:8080', - 'https_proxy': 'https://proxy.example.com:8080', - 'HTTP_PROXY': 'http://proxy.example.com:8080', - 'HTTPS_PROXY': 'https://proxy.example.com:8080', - 'OTHER_VAR': 'value' - } - - with patch('os.environ.copy', return_value=original_env.copy()): - result_env = self.installer._get_pip_install_env() - - # 验证代理变量被移除 - proxy_vars = ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY'] - for proxy_var in proxy_vars: - self.assertNotIn(proxy_var, result_env) - - # 验证其他变量仍在 - self.assertIn('PATH', result_env) - self.assertIn('HOME', result_env) - self.assertIn('OTHER_VAR', result_env) - - # 验证 ComfyUI 相关变量被添加 - self.assertEqual(result_env['COMFYUI_PATH'], self.comfyui_dir) - self.assertEqual(result_env['COMFYUI_FOLDERS_BASE_PATH'], self.comfyui_dir) - - def test_get_script_env_no_proxy_vars_initially(self): - """测试初始环境中无代理变量的情况""" - original_env = { - 'PATH': '/usr/bin', - 'HOME': '/home/user' - } - - with patch('os.environ.copy', return_value=original_env.copy()): - result_env = self.installer._get_script_env() - - # 验证不会报错 - self.assertIn('PATH', result_env) - self.assertIn('HOME', result_env) - self.assertEqual(result_env['COMFYUI_PATH'], self.comfyui_dir) - - def test_get_pip_install_env_no_proxy_vars_initially(self): - """测试初始环境中无代理变量的情况(pip install env)""" - original_env = { - 'PATH': '/usr/bin', - 'HOME': '/home/user' - } - - with patch('os.environ.copy', return_value=original_env.copy()): - result_env = self.installer._get_pip_install_env() - - # 验证不会报错 - self.assertIn('PATH', result_env) - self.assertIn('HOME', result_env) - self.assertEqual(result_env['COMFYUI_PATH'], self.comfyui_dir) - - def test_get_pip_install_env_partial_proxy_vars(self): - """测试部分代理变量存在的情况(pip install env)""" - original_env = { - 'PATH': '/usr/bin', - 'http_proxy': 'http://proxy.example.com:8080', - 'HTTPS_PROXY': 'https://proxy.example.com:8080' - # 只有部分代理变量 - } - - with patch('os.environ.copy', return_value=original_env.copy()): - result_env = self.installer._get_pip_install_env() - - # 验证存在的代理变量被移除 - self.assertNotIn('http_proxy', result_env) - self.assertNotIn('HTTPS_PROXY', result_env) - - # 验证不存在的代理变量不会引起错误 - self.assertNotIn('https_proxy', result_env) - self.assertNotIn('HTTP_PROXY', result_env) - - -class TestMergeRequirementsFromNodesParameterUpdate(TestPIPInstaller): - """测试 _merge_requirements_from_nodes 方法参数更新""" - - def setUp(self): - super().setUp() - # 创建测试节点 - self.node_dir = self.create_test_node_dir("test_node") - self.create_requirements_file(self.node_dir, "requests>=2.25.0\n") - - with patch('services.pip.pip_installer.subprocess.check_output'): - self.installer = PIPInstaller() - - def test_merge_requirements_from_nodes_new_signature(self): - """测试新的方法签名支持 nodes_map 参数""" - nodes_to_install = ["test_node"] - nodes_map = {"test_node": {}} - timeout = 300 - start_time = 0 - - with patch('time.time', return_value=0), \ - patch.object(self.installer, '_apply_custom_dependency_strategies') as mock_custom: - - # Mock 定制策略返回原始依赖 - mock_custom.return_value = { - "requests": DependencyInfo( - package_name="requests", - version_spec=">=2.25.0", - original_line="requests>=2.25.0", - source_nodes=["test_node"] - ) - } - - with patch('builtins.print'): # Suppress print output - result = self.installer._merge_requirements_from_nodes( - nodes_to_install, nodes_map, timeout, start_time - ) - - # 验证定制策略被调用 - mock_custom.assert_called_once() - args = mock_custom.call_args[0] - self.assertEqual(args[1], nodes_to_install) # nodes_to_install - self.assertEqual(args[2], nodes_map) # nodes_map - - # 验证返回的 requirements.txt 内容 - self.assertIn("requests>=2.25.0", result) - - def test_merge_requirements_calls_custom_strategies(self): - """测试 _merge_requirements_from_nodes 调用定制策略""" - nodes_to_install = ["test_node"] - nodes_map = {"test_node": {}} - - with patch('time.time', return_value=0), \ - patch.object(self.installer, '_apply_custom_dependency_strategies') as mock_custom: - - # 设置 mock 返回值 - mock_custom.return_value = {} - - with patch('builtins.print'): - self.installer._merge_requirements_from_nodes( - nodes_to_install, nodes_map, 300, 0 - ) - - # 验证定制策略被调用 - mock_custom.assert_called_once() - - # 验证调用参数 - call_args = mock_custom.call_args[0] - self.assertEqual(len(call_args), 3) # filtered_deps, nodes_to_install, nodes_map - self.assertEqual(call_args[1], nodes_to_install) - self.assertEqual(call_args[2], nodes_map) - - -class TestInstallAllIntegrationWithCustomStrategies(TestPIPInstaller): - """测试 install_all 与定制策略的集成""" - - def setUp(self): - super().setUp() - # 创建 ComfyUI-nunchaku 测试节点 - self.nunchaku_dir = self.create_test_node_dir("ComfyUI-nunchaku") - self.create_requirements_file(self.nunchaku_dir, "requests>=2.25.0\n") - - @patch('services.pip.pip_installer.subprocess.check_output') - @patch('builtins.print') - def test_install_all_with_nunchaku_v1_0_0_integration(self, mock_print, mock_subprocess): - """测试 install_all 与 nunchaku v1.0.0 的集成""" - mock_subprocess.return_value = "Package Version\n" - - installer = PIPInstaller() - - nodes_map = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "version": { - "type": "tag", - "value": "v1.0.0" - } - } - } - - with patch.object(installer, '_install_merged_dependencies') as mock_dep_install, \ - patch.object(installer, '_execute_install_scripts') as mock_scripts: - - # 设置 mock 返回值 - dep_record = DependencyInstallRecord( - requirements_txt="", # 将在 mock 中检查实际内容 - duration=1.0, - success=True, - error_msg="" - ) - mock_dep_install.return_value = dep_record - mock_scripts.return_value = [] - - result_map = installer.install_all(timeout=10, nodes_map=nodes_map) - - # 验证返回结构 - self.assertIn("dependencies", result_map) - self.assertIn("scripts", result_map) - self.assertIn("baseline", result_map) - - # 验证 _install_merged_dependencies 被调用 - mock_dep_install.assert_called_once() - - # 检查传递给 _install_merged_dependencies 的 requirements_txt 内容 - called_requirements = mock_dep_install.call_args[0][0] # 第一个参数 - - # 应该包含 nunchaku wheel URL - expected_wheel_url = "https://modelscope.cn/models/nunchaku-tech/nunchaku/resolve/master/nunchaku-1.0.0+torch2.8-cp310-cp310-linux_x86_64.whl" - self.assertIn(expected_wheel_url, called_requirements) - - # 也应该包含原有的 requirements - self.assertIn("requests>=2.25.0", called_requirements) - - @patch('services.pip.pip_installer.subprocess.check_output') - @patch('builtins.print') - def test_install_all_with_nunchaku_other_version_integration(self, mock_print, mock_subprocess): - """测试 install_all 与 nunchaku 非 v1.0.0 版本的集成""" - mock_subprocess.return_value = "Package Version\n" - - installer = PIPInstaller() - - nodes_map = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "version": { - "type": "tag", - "value": "v0.2.0" - } - } - } - - with patch.object(installer, '_install_merged_dependencies') as mock_dep_install, \ - patch.object(installer, '_execute_install_scripts') as mock_scripts: - - dep_record = DependencyInstallRecord( - requirements_txt="", - duration=1.0, - success=True, - error_msg="" - ) - mock_dep_install.return_value = dep_record - mock_scripts.return_value = [] - - result_map = installer.install_all(timeout=10, nodes_map=nodes_map) - - # 检查传递给 _install_merged_dependencies 的 requirements_txt 内容 - called_requirements = mock_dep_install.call_args[0][0] - - # 不应该包含 nunchaku wheel URL - self.assertNotIn("https://modelscope.cn", called_requirements) - - # 但应该包含原有的 requirements - self.assertIn("requests>=2.25.0", called_requirements) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/code/agent/test/unit/services/pip/real_scenario_test.py b/src/code/agent/test/unit/services/pip/real_scenario_test.py deleted file mode 100644 index dbc7b10..0000000 --- a/src/code/agent/test/unit/services/pip/real_scenario_test.py +++ /dev/null @@ -1,767 +0,0 @@ -import unittest -import os -import tempfile -import shutil -import sys -import subprocess -from pathlib import Path - -# 添加当前项目路径到Python路径 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))) - -from services.pip.pip_installer import PIPInstaller - - -class RealScenarioTest(unittest.TestCase): - """真实场景测试:创建真实的venv环境,查看完整的安装日志效果""" - - def setUp(self): - """测试初始化 - 创建真实的venv环境""" - self.test_dir = tempfile.mkdtemp(prefix="pip_installer_test_") - self.comfyui_dir = os.path.join(self.test_dir, "comfyui") - self.custom_nodes_dir = os.path.join(self.comfyui_dir, "custom_nodes") - self.venv_dir = os.path.join(self.test_dir, "test_venv") - - os.makedirs(self.custom_nodes_dir, exist_ok=True) - - print(f"\n创建测试环境在: {self.test_dir}") - print(f"ComfyUI目录: {self.comfyui_dir}") - print(f"虚拟环境目录: {self.venv_dir}") - - # 使用命令创建真实的虚拟环境 - print("\n创建虚拟环境...") - try: - subprocess.run([sys.executable, "-m", "venv", self.venv_dir], - check=True, capture_output=True, timeout=60) - print("✓ 虚拟环境创建成功") - except subprocess.CalledProcessError as e: - print(f"✗ 虚拟环境创建失败: {e.stderr.decode()}") - raise - except subprocess.TimeoutExpired: - print("✗ 虚拟环境创建超时") - raise - - # 设置虚拟环境的Python路径 - if sys.platform == "win32": - self.venv_python = os.path.join(self.venv_dir, "Scripts", "python.exe") - else: - self.venv_python = os.path.join(self.venv_dir, "bin", "python") - - # 验证虚拟环境可用 - try: - result = subprocess.run([self.venv_python, "--version"], - capture_output=True, text=True, timeout=10) - if result.returncode == 0: - print(f"✓ 虚拟环境Python: {result.stdout.strip()}") - else: - raise Exception(f"Python版本检查失败: {result.stderr}") - except subprocess.TimeoutExpired: - print("✗ Python版本检查超时") - raise - except Exception as e: - print(f"✗ 虚拟环境验证失败: {e}") - raise - - # 升级pip到最新版本 - print("\n升级pip...") - try: - subprocess.run([self.venv_python, "-m", "pip", "install", "--upgrade", "pip"], - check=True, capture_output=True, timeout=120) - print("✓ pip升级完成") - except subprocess.CalledProcessError as e: - print(f"⚠ pip升级失败: {e.stderr.decode()}") - # 不抛出异常,继续测试 - except subprocess.TimeoutExpired: - print("⚠ pip升级超时,继续使用现有版本") - - # 设置constants使用真实路径 - import constants - self.original_comfyui_dir = getattr(constants, 'COMFYUI_DIR', None) - self.original_venv_executable = getattr(constants, 'VENV_EXECUTABLE', None) - - constants.COMFYUI_DIR = self.comfyui_dir - constants.VENV_EXECUTABLE = self.venv_python - - self._create_simple_test_nodes() - - def tearDown(self): - """测试清理 - 删除venv环境""" - # 恢复constants - import constants - if self.original_comfyui_dir is not None: - constants.COMFYUI_DIR = self.original_comfyui_dir - if self.original_venv_executable is not None: - constants.VENV_EXECUTABLE = self.original_venv_executable - - # 删除测试目录 - print(f"清理测试环境: {self.test_dir}") - try: - shutil.rmtree(self.test_dir, ignore_errors=True) - print("✓ 测试环境清理完成") - except Exception as e: - print(f"✗ 测试环境清理失败: {e}") - - def create_test_node_dir(self, node_name): - """创建测试节点目录""" - node_dir = os.path.join(self.custom_nodes_dir, node_name) - os.makedirs(node_dir, exist_ok=True) - return node_dir - - def create_requirements_file(self, node_dir, content): - """创建 requirements.txt 文件""" - req_file = os.path.join(node_dir, "requirements.txt") - with open(req_file, "w") as f: - f.write(content) - return req_file - - def _create_simple_test_nodes(self): - """创建3个简单的测试插件""" - - # 插件1: 基础工具插件 - self.node1_dir = self.create_test_node_dir("basic-tools") - self.create_requirements_file(self.node1_dir, """ -# 基础工具依赖 - 使用小的包便于测试 -click>=7.0 -colorama>=0.4.0 -pytz>=2021.1 - """.strip()) - - # 插件2: 数据处理插件 - self.node2_dir = self.create_test_node_dir("data-processor") - self.create_requirements_file(self.node2_dir, """ -# 数据处理依赖 -click>=8.0 # 版本冲突测试 -requests>=2.25.0 -json5>=0.9.0 -pyyaml>=5.4.0 - """.strip()) - - # 插件3: 图像处理插件 - self.node3_dir = self.create_test_node_dir("image-helper") - self.create_requirements_file(self.node3_dir, """ -# 图像处理依赖 -colorama>=0.3.0 # 低于上面的版本 -chardet>=4.0.0 -urllib3>=1.26.0 -click # 无版本要求 -git+https://github.com/python/cpython.git # git依赖测试(仅用于演示) - """.strip()) - - print(f"创建了3个测试插件: basic-tools, data-processor, image-helper") - - def test_install_all_direct(self): - """直接调用install_all方法,观察完整安装日志""" - - print("\n==== 直接调用install_all方法测试 ====\n") - - # 创建安装器实例 - installer = PIPInstaller() - - # 打印环境信息 - print(f"虚拟环境: {self.venv_python}") - print(f"自定义节点目录: {self.custom_nodes_dir}") - - # 打印安装前的包基线 - baseline = installer.get_origin_packages() - print(f"\n安装前包基线: {len(baseline)} 个包") - print(f"示例: {list(baseline.keys())[:5]}" + ("..." if len(baseline) > 5 else "")) - - # 只用特定的几个节点进行测试,避免安装所有节点 - nodes_map = { - "basic-tools": {}, - "data-processor": {}, - "image-helper": {} - } - - # 直接调用install_all方法(使用较短的超时时间以便测试) - try: - result_map = installer.install_all(timeout=300, nodes_map=nodes_map) # 5分钟超时 - - # 打印新的结果结构 - print("\n\n==== 安装结果分析 ====\n") - - # 1. 基线信息 - print(f"1. 包基线: {len(result_map['baseline'])} 个包") - - # 2. 依赖安装结果 - dep_result = result_map.get('dependencies') - if dep_result: - dep_success = dep_result.get('success', False) - dep_duration = dep_result.get('duration', 0) - dep_requirements = dep_result.get('requirements_txt', '') - dep_error = dep_result.get('error_msg', '') - - print(f"2. 依赖安装: {'✓ 成功' if dep_success else '✗ 失败'} (耗时: {dep_duration}s)") - - if dep_requirements: - dep_lines = dep_requirements.strip().split('\n') - print(f" 安装的依赖包 ({len(dep_lines)} 个):") - for line in dep_lines[:10]: # 只显示前10个 - print(f" - {line}") - if len(dep_lines) > 10: - print(f" ... 及其他 {len(dep_lines) - 10} 个包") - else: - print(" 无依赖包需要安装") - - if dep_error: - print(f" 错误信息: {dep_error}") - else: - print("2. 依赖安装: 未执行") - - # 3. install.py 脚本安装结果 - script_results = result_map.get('scripts', []) - print(f"\n3. install.py 脚本安装: {len(script_results)} 个脚本") - - success_count = len([r for r in script_results if r.get('success', False)]) - print(f" 成功/总数: {success_count}/{len(script_results)}") - - for script_result in script_results: - node_name = script_result.get('node_name', 'Unknown') - script_name = script_result.get('script_name', 'install.py') - success = script_result.get('success', False) - duration = script_result.get('duration', 0) - error_msg = script_result.get('error_msg', '') - - status = '✓ 成功' if success else '✗ 失败' - print(f" - {node_name}/{script_name}: {status} ({duration}s)") - if error_msg: - print(f" 错误: {error_msg[:100]}" + ("..." if len(error_msg) > 100 else "")) - - # 4. 总体结果评估 - overall_success = ( - dep_result and dep_result.get('success', False) and - success_count == len(script_results) - ) - print(f"\n4. 总体结果: {'✓ 完全成功' if overall_success else '⚠ 部分成功或失败'}") - - return result_map - - except Exception as e: - print(f"安装过程出错: {e}") - import traceback - traceback.print_exc() - return None - - def test_install_with_blacklist(self): - """测试带黑名单的安装""" - print("\n==== 黑名单过滤测试 ====\n") - - # 创建带黑名单的安装器 - blacklist = ["colorama", "click"] # 将一些常用包加入黑名单 - installer = PIPInstaller(blacklist=blacklist) - - print(f"黑名单: {blacklist}") - - # 使用部分节点进行测试 - nodes_map = { - "basic-tools": {}, # 这个插件依赖于 click 和 colorama - "data-processor": {} # 这个插件依赖于 click - } - - try: - result_map = installer.install_all(timeout=200, nodes_map=nodes_map) - - # 分析黑名单过滤效果 - dep_result = result_map.get('dependencies') - if dep_result and dep_result.get('requirements_txt'): - requirements_lines = dep_result['requirements_txt'].strip().split('\n') - print(f"\n实际安装的依赖 ({len(requirements_lines)} 个):") - for line in requirements_lines: - print(f" - {line}") - - # 检查黑名单是否生效 - blacklisted_found = [] - for blacklisted_pkg in blacklist: - for line in requirements_lines: - if line.lower().startswith(blacklisted_pkg.lower()): - blacklisted_found.append(line) - - if blacklisted_found: - print(f"\n⚠ 发现黑名单包未被过滤: {blacklisted_found}") - else: - print(f"\n✓ 黑名单过滤正常,无黑名单包被安装") - else: - print("\n无依赖包需要安装") - - return result_map - - except Exception as e: - print(f"黑名单测试出错: {e}") - import traceback - traceback.print_exc() - return None - - def test_empty_nodes_map(self): - """测试空节点映射的情况""" - print("\n==== 空节点映射测试 ====\n") - - installer = PIPInstaller() - - try: - # 传递空字典,应该不安装任何节点 - result_map = installer.install_all(timeout=60, nodes_map={}) - - print("\n空节点映射结果:") - print(f" 基线包数量: {len(result_map['baseline'])}") - - dep_result = result_map.get('dependencies') - if dep_result: - print(f" 依赖安装: {'成功' if dep_result.get('success') else '失败'} (耗时: {dep_result.get('duration', 0)}s)") - print(f" requirements.txt: '{dep_result.get('requirements_txt', '')}'") - - script_results = result_map.get('scripts', []) - print(f" install.py 脚本: {len(script_results)} 个") - - expected_empty = ( - len(script_results) == 0 and - dep_result and dep_result.get('success') and - not dep_result.get('requirements_txt', '').strip() - ) - - print(f"\n结果检查: {'✓ 符合预期(无安装任何内容)' if expected_empty else '⚠ 与预期不符'}") - - return result_map - - except Exception as e: - print(f"空节点映射测试出错: {e}") - import traceback - traceback.print_exc() - return None - - def test_timeout_scenario(self): - """测试超时停止场景 - 模拟长时间安装触发超时机制""" - print("\n==== 超时停止测试 ====\n") - - # 创建一个特殊的测试节点,包含大型包或慢速安装的包 - timeout_node_dir = self.create_test_node_dir("timeout-test-node") - - # 创建一个会导致较长安装时间的 requirements.txt - # 这里使用一些需要编译的包或者大型包 - timeout_requirements = """ -# 这些包可能需要较长时间安装(用于触发超时) -numpy>=1.20.0 # 编译需要时间 -scipy>=1.7.0 # 大型包,依赖较多 -pandas>=1.3.0 # 大型数据分析包 -matplotlib>=3.4.0 # 图形包,依赖较多 -scikit-learn>=1.0.0 # 机器学习包,非常大 -tensorflow>=2.6.0 # 最大的机器学习框架之一 -torch>=1.9.0 # PyTorch,另一个大型框架 - """.strip() - - self.create_requirements_file(timeout_node_dir, timeout_requirements) - print(f"创建了超时测试节点: timeout-test-node") - print("包含的大型包: numpy, scipy, pandas, matplotlib, scikit-learn, tensorflow, torch") - - installer = PIPInstaller() - - # 使用非常短的超时时间来快速触发超时 - short_timeout = 30 # 30秒超时,对于安装 TensorFlow 等大包远远不够 - - nodes_map = { - "timeout-test-node": {} - } - - print(f"\n设置超时时间: {short_timeout} 秒") - print("注意: 由于需要安装大型包,超时几乎是必然的") - print("开始安装测试...") - - import time - start_test_time = time.time() - - try: - result_map = installer.install_all(timeout=short_timeout, nodes_map=nodes_map) - - test_duration = time.time() - start_test_time - print(f"\n测试完成,总耗时: {test_duration:.1f} 秒") - - # 分析超时结果 - print("\n==== 超时结果分析 ====\n") - - dep_result = result_map.get('dependencies') - if dep_result: - dep_success = dep_result.get('success', False) - dep_duration = dep_result.get('duration', 0) - dep_error = dep_result.get('error_msg', '') - - print(f"1. 依赖安装结果: {'✓ 成功' if dep_success else '✗ 失败'} (耗时: {dep_duration}s)") - - if not dep_success and dep_error: - print(f" 错误信息: {dep_error}") - - # 检查是否是超时错误 - timeout_indicators = [ - "timed out", - "timeout", - "TimeoutExpired", - "Hard timeout", - "Timeout (" - ] - - is_timeout_error = any(indicator in dep_error for indicator in timeout_indicators) - - if is_timeout_error: - print(f" ✓ 确认超时错误:超时机制正常工作") - - # 检查是否在预期时间内停止 - if dep_duration <= short_timeout + 10: # 允许10秒误差 - print(f" ✓ 超时控制正常:在 {short_timeout}s(+10s误差) 内停止") - else: - print(f" ⚠ 超时控制异常:实际耗时 {dep_duration}s 超出预期") - else: - print(f" ⚠ 非超时错误,可能是其他问题") - else: - print(f" ⚠ 意外成功:在 {short_timeout}s 内完成了大型包安装") - - # 检查 install.py 脚本是否被执行 - script_results = result_map.get('scripts', []) - print(f"\n2. install.py 脚本结果: {len(script_results)} 个脚本") - - if not script_results: - print(" ✓ 正常:由于依赖安装超时,未执行 install.py 脚本") - else: - for script_result in script_results: - node_name = script_result.get('node_name', 'Unknown') - success = script_result.get('success', False) - print(f" - {node_name}: {'✓ 成功' if success else '✗ 失败'}") - - # 总结 - print(f"\n3. 超时测试总结:") - if dep_result and not dep_result.get('success') and 'timeout' in dep_result.get('error_msg', '').lower(): - print(" ✓ 超时机制正常工作") - print(" ✓ 安装进程在超时后被正确停止") - print(" ✓ 返回了完整的错误信息") - else: - print(" ⚠ 超时测试结果与预期不符") - - return result_map - - except Exception as e: - test_duration = time.time() - start_test_time - print(f"\n超时测试抛出异常 (耗时: {test_duration:.1f}s): {e}") - - # 检查是否是预期的超时异常 - if "TimeoutError" in str(type(e)) or "timeout" in str(e).lower(): - print(" ✓ 确认为超时异常,超时机制工作正常") - else: - print(" ⚠ 非超时异常,可能是其他问题") - - import traceback - traceback.print_exc() - return None - - def test_lightweight_timeout(self): - """轻量级超时测试 - 使用模拟慢速安装的小包""" - print("\n==== 轻量级超时测试 ====\n") - - # 创建一个包含一些小包但使用极短超时的测试 - lightweight_node_dir = self.create_test_node_dir("lightweight-timeout-node") - - lightweight_requirements = """ -# 轻量级包,但使用极短超时来模拟超时场景 -requests>=2.25.0 -click>=8.0 -colorama>=0.4.0 -pytz>=2021.1 -json5>=0.9.0 - """.strip() - - self.create_requirements_file(lightweight_node_dir, lightweight_requirements) - print(f"创建了轻量级超时测试节点") - - installer = PIPInstaller() - - # 使用极短的超时时间(2秒) - very_short_timeout = 2 - - nodes_map = { - "lightweight-timeout-node": {} - } - - print(f"\n设置极短超时时间: {very_short_timeout} 秒") - print("注意: 即使是小包,5秒也很难完成安装") - - import time - start_test_time = time.time() - - try: - result_map = installer.install_all(timeout=very_short_timeout, nodes_map=nodes_map) - - test_duration = time.time() - start_test_time - print(f"\n轻量级测试完成,耗时: {test_duration:.1f} 秒") - - # 分析结果 - dep_result = result_map.get('dependencies') - if dep_result: - dep_success = dep_result.get('success', False) - dep_error = dep_result.get('error_msg', '') - - print(f"依赖安装: {'✓ 成功' if dep_success else '✗ 失败'}") - - if not dep_success and 'timeout' in dep_error.lower(): - print(f"✓ 轻量级超时测试成功:触发超时机制") - elif dep_success: - print(f"⚠ 意外完成:在 {very_short_timeout}s 内完成安装 (可能网络非常快或包已缓存)") - - return result_map - - except Exception as e: - test_duration = time.time() - start_test_time - print(f"\n轻量级超时测试异常 (耗时: {test_duration:.1f}s): {e}") - return None - - def test_nunchaku_custom_strategy(self): - """测试 ComfyUI-nunchaku 定制化依赖策略""" - print("\n==== ComfyUI-nunchaku 定制化策略测试 ====\n") - - # 创建 ComfyUI-nunchaku 测试节点 - nunchaku_node_dir = self.create_test_node_dir("ComfyUI-nunchaku") - - # 为 nunchaku 节点创建一个基本的 requirements.txt - nunchaku_requirements = """ -# ComfyUI-nunchaku 基础依赖 -requests>=2.25.0 -click>=8.0 - """.strip() - - self.create_requirements_file(nunchaku_node_dir, nunchaku_requirements) - print(f"创建了 ComfyUI-nunchaku 测试节点") - - installer = PIPInstaller() - - # 测试情景1: nunchaku 版本为 v1.0.0,应该添加定制 wheel - print("\n--- 测试情景1: nunchaku v1.0.0 ---") - - nodes_map_v1 = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "source": { - "webUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku", - "type": "github", - "cloneUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku.git" - }, - "version": { - "type": "tag", - "value": "v1.0.0" - } - } - } - - try: - result_map_v1 = installer.install_all(timeout=10, nodes_map=nodes_map_v1) - - # 分析结果 - 检查是否添加了 nunchaku wheel - dep_result_v1 = result_map_v1.get('dependencies') - if dep_result_v1 and dep_result_v1.get('requirements_txt'): - requirements_lines = dep_result_v1['requirements_txt'].strip().split('\n') - print(f"\n实际安装的依赖 ({len(requirements_lines)} 个):") - - nunchaku_wheel_found = False - expected_wheel_url = "https://modelscope.cn/models/nunchaku-tech/nunchaku/resolve/master/nunchaku-1.0.0+torch2.8-cp310-cp310-linux_x86_64.whl" - - for line in requirements_lines: - print(f" - {line}") - if expected_wheel_url in line: - nunchaku_wheel_found = True - - if nunchaku_wheel_found: - print(f"\n✓ 成功: 检测到 nunchaku v1.0.0,已添加定制 wheel URL") - print(f" Wheel URL: {expected_wheel_url}") - else: - print(f"\n⚠ 失败: 未找到预期的 nunchaku wheel URL") - print(f" 预期: {expected_wheel_url}") - else: - print("\n⚠ 无依赖包需要安装") - - except Exception as e: - print(f"nunchaku v1.0.0 测试出错: {e}") - import traceback - traceback.print_exc() - - # 测试情景2: nunchaku 版本为 v0.2.0,不应该添加定制 wheel - print("\n--- 测试情景2: nunchaku v0.2.0 ---") - - nodes_map_v0 = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "source": { - "webUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku", - "type": "github", - "cloneUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku.git" - }, - "version": { - "type": "tag", - "value": "v0.2.0" - } - } - } - - try: - # 创建新的安装器实例以避免状态干扰 - installer_v0 = PIPInstaller() - result_map_v0 = installer_v0.install_all(timeout=10, nodes_map=nodes_map_v0) - - # 分析结果 - 检查是否没有添加 nunchaku wheel - dep_result_v0 = result_map_v0.get('dependencies') - if dep_result_v0 and dep_result_v0.get('requirements_txt'): - requirements_lines = dep_result_v0['requirements_txt'].strip().split('\n') - print(f"\n实际安装的依赖 ({len(requirements_lines)} 个):") - - nunchaku_wheel_found = False - expected_wheel_url = "https://modelscope.cn/models/nunchaku-tech/nunchaku/resolve/master/nunchaku-1.0.0+torch2.8-cp310-cp310-linux_x86_64.whl" - - for line in requirements_lines: - print(f" - {line}") - if expected_wheel_url in line: - nunchaku_wheel_found = True - - if not nunchaku_wheel_found: - print(f"\n✓ 成功: nunchaku v0.2.0 未添加定制 wheel,符合预期") - else: - print(f"\n⚠ 失败: nunchaku v0.2.0 不应该添加定制 wheel") - else: - print(f"\n✓ 正常: nunchaku v0.2.0 无需安装额外依赖") - - except Exception as e: - print(f"nunchaku v0.2.0 测试出错: {e}") - import traceback - traceback.print_exc() - - # 测试情景3: 没有 ComfyUI-nunchaku 节点 - print("\n--- 测试情景3: 无 nunchaku 节点 ---") - - nodes_map_no_nunchaku = { - "basic-tools": {}, - "data-processor": {} - } - - try: - installer_no_nunchaku = PIPInstaller() - result_map_no_nunchaku = installer_no_nunchaku.install_all(timeout=10, nodes_map=nodes_map_no_nunchaku) - - # 分析结果 - 确认没有 nunchaku 相关处理 - dep_result_no_nunchaku = result_map_no_nunchaku.get('dependencies') - if dep_result_no_nunchaku and dep_result_no_nunchaku.get('requirements_txt'): - requirements_lines = dep_result_no_nunchaku['requirements_txt'].strip().split('\n') - - nunchaku_wheel_found = any( - "nunchaku" in line.lower() and "modelscope.cn" in line - for line in requirements_lines - ) - - if not nunchaku_wheel_found: - print(f"\n✓ 成功: 无 nunchaku 节点时未触发定制策略") - else: - print(f"\n⚠ 异常: 无 nunchaku 节点但检测到 nunchaku wheel") - else: - print(f"\n✓ 正常: 无 nunchaku 节点时无需安装任何依赖") - - except Exception as e: - print(f"无 nunchaku 节点测试出错: {e}") - import traceback - traceback.print_exc() - - # 测试情景4: 无效的 nodes_map 结构 - print("\n--- 测试情景4: 无效的 nodes_map 结构 ---") - - # 测试缺少 version 字段的情况 - nodes_map_invalid = { - "ComfyUI-nunchaku": { - "name": "ComfyUI-nunchaku", - "source": { - "webUrl": "https://github.com/nunchaku-tech/ComfyUI-nunchaku", - "type": "github" - } - # 没有 version 字段 - } - } - - try: - installer_invalid = PIPInstaller() - result_map_invalid = installer_invalid.install_all(timeout=10, nodes_map=nodes_map_invalid) - - # 分析结果 - 应该能够处理无效结构而不崩溃 - dep_result_invalid = result_map_invalid.get('dependencies') - if dep_result_invalid: - print(f"\n✓ 成功: 处理无效 nodes_map 结构而不崩溃") - - # 检查是否没有添加 nunchaku wheel - if dep_result_invalid.get('requirements_txt'): - requirements_lines = dep_result_invalid['requirements_txt'].strip().split('\n') - nunchaku_wheel_found = any( - "nunchaku" in line.lower() and "modelscope.cn" in line - for line in requirements_lines - ) - - if not nunchaku_wheel_found: - print(f" ✓ 正确: 无效版本信息时未添加 nunchaku wheel") - else: - print(f" ⚠ 异常: 无效版本信息但仍添加了 nunchaku wheel") - else: - print(f" ✓ 无依赖需要安装") - - except Exception as e: - print(f"无效 nodes_map 测试出错: {e}") - import traceback - traceback.print_exc() - - print("\n==== nunchaku 定制化策略测试完成 ====\n") - - return { - "v1.0.0": result_map_v1 if 'result_map_v1' in locals() else None, - "v0.2.0": result_map_v0 if 'result_map_v0' in locals() else None, - "no_nunchaku": result_map_no_nunchaku if 'result_map_no_nunchaku' in locals() else None, - "invalid": result_map_invalid if 'result_map_invalid' in locals() else None - } - - -if __name__ == "__main__": - import sys - - print("="*60) - print("PIP安装器优化版 - 真实场景测试") - print("="*60) - print("注意: 此测试会创建真实的venv环境并执行实际安装") - print("测试完成后会自动清理环境") - print() - - # 可用的测试方法 - available_tests = { - '1': 'test_install_all_direct', - '2': 'test_install_with_blacklist', - '3': 'test_empty_nodes_map', - '4': 'test_timeout_scenario', - '5': 'test_lightweight_timeout', - '6': 'test_nunchaku_custom_strategy', - 'all': 'RealScenarioTest' # 运行所有测试 - } - - print("可用的测试:") - print(" 1. 直接安装测试 (test_install_all_direct)") - print(" 2. 黑名单过滤测试 (test_install_with_blacklist)") - print(" 3. 空节点映射测试 (test_empty_nodes_map)") - print(" 4. 超时停止测试 (test_timeout_scenario) - 使用大型包") - print(" 5. 轻量级超时测试 (test_lightweight_timeout) - 使用极短超时") - print(" 6. ComfyUI-nunchaku 定制策略测试 (test_nunchaku_custom_strategy)") - print(" all. 运行所有测试") - print() - - # 检查命令行参数 - if len(sys.argv) > 1: - test_choice = sys.argv[1] - else: - test_choice = input("请选择要运行的测试 (1/2/3/4/5/6/all, 默认为 1): ").strip() or '1' - - if test_choice in available_tests: - test_name = available_tests[test_choice] - print(f"\n开始运行测试: {test_name}") - print("-" * 40) - - if test_choice == 'all': - # 运行所有测试 - unittest.main(argv=['first-arg-is-ignored'], exit=False, verbosity=2) - else: - # 运行特定测试 - unittest.main(argv=['first-arg-is-ignored', test_name], exit=False, verbosity=2) - else: - print(f"无效的选择: {test_choice}") - print("请选择 1, 2, 3, 4, 5, 6 或 all") diff --git a/src/code/agent/test/unit/services/snapshot_manager_test.py b/src/code/agent/test/unit/services/snapshot_manager_test.py deleted file mode 100644 index 1cb77b8..0000000 --- a/src/code/agent/test/unit/services/snapshot_manager_test.py +++ /dev/null @@ -1,166 +0,0 @@ -import io -import tarfile -import pytest -import os -import shutil -from datetime import datetime -from services.workspace.snapshot_manager import SnapshotManager -import constants - - -@pytest.fixture -def setup_snapshot_files(tmp_path): - # 创建快照目录结构 - snapshot_dir = tmp_path / "snapshots" - snapshot_dir.mkdir() - - # 创建多个快照目录 - snapshots = [ - "dev-20231201-120000", - "dev-20231202-115959", - "dev-20231202-120000", - "prod-20231202-120000" - ] - - for snapshot in snapshots: - snapshot_path = snapshot_dir / snapshot - snapshot_path.mkdir() - - # 创建comfyui.zip文件 - import zipfile - comfyui_zip = snapshot_path / "comfyui.zip" - with zipfile.ZipFile(comfyui_zip, 'w') as zf: - zf.writestr('comfyui/test.txt', 'test content') - - # 创建venv.tar文件 - tar_path = snapshot_path / "venv.tar" - with tarfile.open(tar_path, "w") as tar: - test_content = "This is a test file in venv" - test_file = io.BytesIO(test_content.encode()) - tarinfo = tarfile.TarInfo(name="venv/test_venv.txt") - tarinfo.size = len(test_content) - tar.addfile(tarinfo, test_file) - - # 创建挂载目录和必要的子目录 - mnt_dir = tmp_path / "mnt" - mnt_dir.mkdir() - - # 创建models目录 - models_dir = mnt_dir / "models" - models_dir.mkdir() - (models_dir / "test_model.bin").write_text("test model content") - - # 创建custom_nodes目录 - custom_nodes_dir = mnt_dir / "custom_nodes" - custom_nodes_dir.mkdir() - (custom_nodes_dir / "test_node.py").write_text("test node content") - - # 设置常量 - constants.SNAPSHOT_DIR = str(snapshot_dir) - constants.WORK_DIR = str(tmp_path / "work") - constants.MNT_DIR = str(mnt_dir) - constants.MODEL_DIR = str(mnt_dir / "models") - constants.COMFYUI_DIR = str(tmp_path / "work/comfyui") - constants.BACKEND_TYPE = constants.TYPE_COMFYUI - constants.USE_API_MODE = False - - os.makedirs(constants.WORK_DIR, exist_ok=True) - - return tmp_path - - -def test_load_latest_dev_snapshot(setup_snapshot_files): - manager = SnapshotManager() - result = manager.load(SnapshotManager.USE_LATEST_DEV) - - assert result["snapshot"] == "dev-20231202-120000" - assert manager.snapshot_name == "dev-20231202-120000" - - # 验证模型目录软链接 - models_link = os.path.join(constants.COMFYUI_DIR, "models") - assert os.path.islink(models_link) - assert os.readlink(models_link) == os.path.join(constants.MNT_DIR, "models") - - -def test_load_latest_prod_snapshot(setup_snapshot_files): - manager = SnapshotManager() - result = manager.load(SnapshotManager.USE_LATEST_PROD) - - assert result["snapshot"] == "prod-20231202-120000" - assert manager.snapshot_name == "prod-20231202-120000" - - -def test_load_specific_snapshot(setup_snapshot_files): - manager = SnapshotManager() - result = manager.load("dev-20231202-115959") - - assert result["snapshot"] == "dev-20231202-115959" - assert manager.snapshot_name == "dev-20231202-115959" - - -def test_load_nonexistent_snapshot(setup_snapshot_files): - manager = SnapshotManager() - with pytest.raises(RuntimeError): - manager.load("nonexistent") - - -def test_load_same_snapshot_twice(setup_snapshot_files): - manager = SnapshotManager() - first_load = manager.load("dev-20231202-120000") - second_load = manager.load("dev-20231202-120000") - - assert first_load["snapshot"] == second_load["snapshot"] - assert manager.snapshot_name == "dev-20231202-120000" - - -def test_select_latest_snapshot_with_invalid_format(setup_snapshot_files): - invalid_snapshot = os.path.join(constants.SNAPSHOT_DIR, "invalid_format") - os.makedirs(invalid_snapshot) - - manager = SnapshotManager() - latest = manager._select_latest_snapshot(SnapshotManager.TYPE_DEV) - - assert latest == "dev-20231202-120000" - - -def test_select_latest_snapshot_empty_dir(tmp_path): - empty_dir = tmp_path / "empty" - empty_dir.mkdir() - constants.SNAPSHOT_DIR = str(empty_dir) - - manager = SnapshotManager() - latest = manager._select_latest_snapshot(SnapshotManager.TYPE_DEV) - - assert latest is None - - -def test_save_dev_snapshot(setup_snapshot_files): - manager = SnapshotManager() - result = manager.save(SnapshotManager.TYPE_DEV) - - assert "snapshot" in result - snapshot_name = result["snapshot"] - assert snapshot_name.startswith("dev-") - assert manager.snapshot_name == snapshot_name - - -def test_save_prod_snapshot(setup_snapshot_files): - manager = SnapshotManager() - result = manager.save(SnapshotManager.TYPE_PROD) - - assert "snapshot" in result - snapshot_name = result["snapshot"] - assert snapshot_name.startswith("prod-") - assert manager.snapshot_name == snapshot_name - - -def test_save_invalid_type(setup_snapshot_files): - manager = SnapshotManager() - with pytest.raises(RuntimeError): - manager.save("invalid") - - -@pytest.fixture(autouse=True) -def cleanup(setup_snapshot_files): - yield - shutil.rmtree(setup_snapshot_files) diff --git a/src/code/agent/test/unit/services/workspace/__init__.py b/src/code/agent/test/unit/services/workspace/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/code/agent/test/unit/services/workspace/snapshot_loader_test.py b/src/code/agent/test/unit/services/workspace/snapshot_loader_test.py deleted file mode 100644 index 2ee24c1..0000000 --- a/src/code/agent/test/unit/services/workspace/snapshot_loader_test.py +++ /dev/null @@ -1,344 +0,0 @@ -import zipfile -import pytest -import os -import json -import shutil -import tarfile -import io - -import constants -from services.workspace.snapshot_loader import ComfyUIDevSnapshotLoader, SDSnapshotLoader, ComfyUIProdSnapshotLoader - - -class MockTimer: - def __init__(self, name): - self.name = name - self.elapsed = 1.5 - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - -@pytest.fixture -def setup_dirs(tmp_path): - # 设置基础目录 - constants.WORK_DIR = str(tmp_path / "work") - constants.MNT_DIR = str(tmp_path / "mnt") - constants.MODEL_DIR = str(tmp_path / "mnt/models") - constants.COMFYUI_DIR = str(tmp_path / "work/comfyui") - constants.SD_DIR = str(tmp_path / "work/stable-diffusion-webui") - constants.VENV_DIR = str(tmp_path / "work/venv") - - # 创建必要的目录 - os.makedirs(constants.WORK_DIR, exist_ok=True) - os.makedirs(constants.MNT_DIR, exist_ok=True) - os.makedirs(constants.MODEL_DIR, exist_ok=True) - os.makedirs(os.path.join(constants.MNT_DIR, "custom_nodes"), exist_ok=True) - - return tmp_path - - -def create_mock_zip(path, content_dir=None, extra_files=None, flat=False): - """ - 创建测试用zip文件 - :param path: zip文件路径 - :param content_dir: 内容目录 - :param extra_files: 额外要添加的文件字典 {"相对路径": "内容"} - :param flat: 是否将文件直接放在根目录下,而不是在子目录中 - """ - with zipfile.ZipFile(path, 'w') as zf: - if content_dir: - # 确保目录存在 - os.makedirs(content_dir, exist_ok=True) - - # 添加测试文件 - test_file = os.path.join(content_dir, "test.txt") - with open(test_file, 'w') as f: - f.write("test content") - - arcname = "test.txt" if flat else os.path.join(os.path.basename(content_dir), "test.txt") - zf.write(test_file, arcname) - - # 添加额外文件 - if extra_files: - for file_path, content in extra_files.items(): - full_path = os.path.join(content_dir, file_path) - os.makedirs(os.path.dirname(full_path), exist_ok=True) - with open(full_path, 'w') as f: - if isinstance(content, dict): - json.dump(content, f, indent=4) - else: - f.write(content) - arcname = file_path if flat else os.path.join(os.path.basename(content_dir), file_path) - zf.write(full_path, arcname) - else: - zf.writestr("test.txt", "test content") - - -def create_mock_tar(path): - """创建一个有效的测试用 tar 文件""" - with tarfile.open(path, "w:gz") as tar: - test_content = b"Test venv content" - test_file = io.BytesIO(test_content) - tarinfo = tarfile.TarInfo(name="venv/test.txt") - tarinfo.size = len(test_content) - tar.addfile(tarinfo, test_file) - - -@pytest.fixture -def mock_timer(): - return lambda x: MockTimer(x) - - -class TestComfyUIDevSnapshotLoader: - @pytest.fixture - def loader(self, mock_timer): - return ComfyUIDevSnapshotLoader(mock_timer) - - def test_load_snapshot(self, setup_dirs, loader): - snapshot_path = os.path.join(setup_dirs, "snapshot") - os.makedirs(snapshot_path) - - # 创建测试文件 - create_mock_tar(os.path.join(snapshot_path, "venv.tar")) - create_mock_zip(os.path.join(snapshot_path, "comfyui.zip"), content_dir=os.path.join(setup_dirs, "comfyui")) - create_mock_zip(os.path.join(snapshot_path, ".cache.zip"), content_dir=os.path.join(setup_dirs, ".cache")) - - stage_cost = loader.load(snapshot_path) - - # 验证时间统计 - assert all(key in stage_cost for key in ["time_clear", "time_download", "time_extract"]) - - # 验证目录结构 - assert os.path.exists(constants.COMFYUI_DIR) - assert os.path.exists(os.path.join(constants.COMFYUI_DIR, "test.txt")) - - # 验证软链接 - models_link = os.path.join(constants.COMFYUI_DIR, "models") - custom_nodes_link = os.path.join(constants.COMFYUI_DIR, "custom_nodes") - assert os.path.islink(models_link) - assert os.path.islink(custom_nodes_link) - assert os.readlink(models_link) == os.path.join(constants.MNT_DIR, "models") - assert os.readlink(custom_nodes_link) == os.path.join(constants.MNT_DIR, "custom_nodes") - - -class TestComfyUIProdSnapshotLoader: - @pytest.fixture - def loader(self, mock_timer): - return ComfyUIProdSnapshotLoader(mock_timer) - - def test_load_snapshot(self, setup_dirs, loader): - snapshot_path = os.path.join(setup_dirs, "snapshot") - os.makedirs(snapshot_path) - - # 创建测试文件 - create_mock_tar(os.path.join(snapshot_path, "venv.tar")) - create_mock_zip(os.path.join(snapshot_path, "comfyui.zip"), - content_dir=os.path.join(setup_dirs, "comfyui")) - create_mock_zip(os.path.join(snapshot_path, "custom_nodes.zip"), - content_dir=os.path.join(setup_dirs, "custom_nodes"), - flat=True) - create_mock_zip(os.path.join(snapshot_path, ".cache.zip"), - content_dir=os.path.join(setup_dirs, ".cache")) - - stage_cost = loader.load(snapshot_path) - - # 验证时间统计 - assert all(key in stage_cost for key in ["time_clear", "time_download", "time_extract"]) - - # 验证目录结构 - assert os.path.exists(constants.COMFYUI_DIR) - assert os.path.exists(os.path.join(constants.COMFYUI_DIR, "test.txt")) - - # 验证软链接 - models_link = os.path.join(constants.COMFYUI_DIR, "models") - assert os.path.islink(models_link) - assert os.readlink(models_link) == os.path.join(constants.MNT_DIR, "models") - - # 验证custom_nodes目录 - custom_nodes_dir = os.path.join(constants.COMFYUI_DIR, "custom_nodes") - assert os.path.exists(custom_nodes_dir) - assert not os.path.islink(custom_nodes_dir) # 确保不是软链接 - assert os.path.exists(os.path.join(custom_nodes_dir, "test.txt")) # 验证custom_nodes内容 - - def test_load_snapshot_without_cache(self, setup_dirs, loader): - """测试没有cache文件的情况""" - snapshot_path = os.path.join(setup_dirs, "snapshot") - os.makedirs(snapshot_path) - - create_mock_tar(os.path.join(snapshot_path, "venv.tar")) - create_mock_zip(os.path.join(snapshot_path, "comfyui.zip"), - content_dir=os.path.join(setup_dirs, "comfyui")) - create_mock_zip(os.path.join(snapshot_path, "custom_nodes.zip"), - content_dir=os.path.join(setup_dirs, "custom_nodes"), - flat=True) - - stage_cost = loader.load(snapshot_path) - - # 验证时间统计 - assert all(key in stage_cost for key in ["time_clear", "time_download", "time_extract"]) - - # 验证目录结构 - assert os.path.exists(constants.COMFYUI_DIR) - assert os.path.exists(os.path.join(constants.COMFYUI_DIR, "test.txt")) - - # 验证软链接 - models_link = os.path.join(constants.COMFYUI_DIR, "models") - assert os.path.islink(models_link) - assert os.readlink(models_link) == os.path.join(constants.MNT_DIR, "models") - - # 验证custom_nodes目录 - custom_nodes_dir = os.path.join(constants.COMFYUI_DIR, "custom_nodes") - assert os.path.exists(custom_nodes_dir) - assert not os.path.islink(custom_nodes_dir) - assert os.path.exists(os.path.join(custom_nodes_dir, "test.txt")) - - -class TestSDSnapshotLoader: - @pytest.fixture - def loader(self, mock_timer): - return SDSnapshotLoader(mock_timer) - - def test_load_snapshot(self, setup_dirs, loader): - snapshot_path = os.path.join(setup_dirs, "snapshot") - os.makedirs(snapshot_path) - - # 创建测试文件 - create_mock_tar(os.path.join(snapshot_path, "venv.tar")) - - # 创建stable-diffusion-webui.zip,包含config.json - create_mock_zip( - os.path.join(snapshot_path, "stable-diffusion-webui.zip"), - content_dir=os.path.join(setup_dirs, "stable-diffusion-webui"), - extra_files={ - "config.json": { - "test_key": "test_value", - "outdir_samples": "old_path" - } - } - ) - - create_mock_zip( - os.path.join(snapshot_path, ".cache.zip"), - content_dir=os.path.join(setup_dirs, ".cache") - ) - - stage_cost = loader.load(snapshot_path) - - # 验证时间统计 - assert all(key in stage_cost for key in ["time_clear", "time_download", "time_extract"]) - - # 验证目录结构 - assert os.path.exists(constants.SD_DIR) - assert os.path.exists(os.path.join(constants.SD_DIR, "test.txt")) - - # 验证软链接 - models_link = os.path.join(constants.SD_DIR, "models") - assert os.path.islink(models_link) - assert os.readlink(models_link) == os.path.join(constants.MNT_DIR, "models") - - # 验证配置文件 - config_link = os.path.join(constants.SD_DIR, "config.json") - assert os.path.islink(config_link) - assert os.readlink(config_link) == os.path.join(constants.MNT_DIR, "config.json") - - # 验证配置内容 - with open(os.path.join(constants.MNT_DIR, "config.json"), "r") as f: - new_config = json.load(f) - assert new_config["outdir_samples"] == f"{constants.MNT_DIR}/output" - assert new_config["outdir_grids"] == f"{constants.MNT_DIR}/output" - assert new_config["outdir_save"] == f"{constants.MNT_DIR}/output/saves" - assert new_config["outdir_init_images"] == f"{constants.MNT_DIR}/input" - assert new_config["save_init_img"] is True - assert new_config["test_key"] == "test_value" - - def test_load_snapshot_without_cache(self, setup_dirs, loader): - """测试没有cache文件的情况""" - snapshot_path = os.path.join(setup_dirs, "snapshot") - os.makedirs(snapshot_path) - - # 创建测试文件 - create_mock_tar(os.path.join(snapshot_path, "venv.tar")) - create_mock_zip( - os.path.join(snapshot_path, "stable-diffusion-webui.zip"), - content_dir=os.path.join(setup_dirs, "stable-diffusion-webui"), - extra_files={ - "config.json": { - "test_key": "test_value", - "outdir_samples": "old_path" - } - } - ) - - stage_cost = loader.load(snapshot_path) - - # 验证时间统计 - assert all(key in stage_cost for key in ["time_clear", "time_download", "time_extract"]) - - # 验证目录结构和软链接 - assert os.path.exists(constants.SD_DIR) - assert os.path.exists(os.path.join(constants.SD_DIR, "test.txt")) - assert os.path.islink(os.path.join(constants.SD_DIR, "models")) - assert os.path.islink(os.path.join(constants.SD_DIR, "config.json")) - - # 验证配置内容 - with open(os.path.join(constants.MNT_DIR, "config.json"), "r") as f: - new_config = json.load(f) - assert new_config["outdir_samples"] == f"{constants.MNT_DIR}/output" - assert new_config["test_key"] == "test_value" - - def test_load_snapshot_existing_config(self, setup_dirs, loader): - """测试已存在配置文件的情况""" - # 创建已存在的配置文件 - existing_config = { - "existing": "config", - "outdir_samples": "/custom/path", - "outdir_grids": "/custom/grid/path", - "outdir_save": "/custom/save/path", - "outdir_init_images": "/custom/init/path", - "save_init_img": False - } - with open(os.path.join(constants.MNT_DIR, "config.json"), "w") as f: - json.dump(existing_config, f) - - snapshot_path = os.path.join(setup_dirs, "snapshot") - os.makedirs(snapshot_path) - - # 创建测试文件 - create_mock_tar(os.path.join(snapshot_path, "venv.tar")) - create_mock_zip( - os.path.join(snapshot_path, "stable-diffusion-webui.zip"), - content_dir=os.path.join(setup_dirs, "stable-diffusion-webui"), - extra_files={ - "config.json": { - "test_key": "test_value", - "outdir_samples": "old_path" - } - } - ) - - loader.load(snapshot_path) - - # 验证配置文件未被修改 - with open(os.path.join(constants.MNT_DIR, "config.json"), "r") as f: - config = json.load(f) - assert config == existing_config - - # 验证软链接 - config_link = os.path.join(constants.SD_DIR, "config.json") - assert os.path.islink(config_link) - assert os.readlink(config_link) == os.path.join(constants.MNT_DIR, "config.json") - - # 验证目录结构 - assert os.path.exists(constants.SD_DIR) - assert os.path.exists(os.path.join(constants.SD_DIR, "test.txt")) - assert os.path.islink(os.path.join(constants.SD_DIR, "models")) - - -@pytest.fixture(autouse=True) -def cleanup(setup_dirs): - yield - shutil.rmtree(setup_dirs) diff --git a/src/code/agent/test/unit/services/workspace/snapshot_saver_test.py b/src/code/agent/test/unit/services/workspace/snapshot_saver_test.py deleted file mode 100644 index 9e27805..0000000 --- a/src/code/agent/test/unit/services/workspace/snapshot_saver_test.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import pytest -import shutil - -import constants -from services.workspace.snapshot_saver import SDSnapshotSaver, ComfyUISnapshotSaver - - -@pytest.fixture -def setup_dirs(tmp_path): - """设置测试环境目录""" - constants.WORK_DIR = str(tmp_path / "work") - constants.SNAPSHOT_DIR = str(tmp_path / "snapshots") - - # 创建必要的目录 - os.makedirs(constants.WORK_DIR, exist_ok=True) - os.makedirs(constants.SNAPSHOT_DIR, exist_ok=True) - os.makedirs(os.path.join(constants.WORK_DIR, "venv")) - os.makedirs(os.path.join(constants.WORK_DIR, "stable-diffusion-webui")) - os.makedirs(os.path.join(constants.WORK_DIR, "comfyui")) - os.makedirs(os.path.join(constants.WORK_DIR, ".cache")) - - # 创建一些测试文件 - with open(os.path.join(constants.WORK_DIR, "venv/test.txt"), "w") as f: - f.write("test venv content") - with open(os.path.join(constants.WORK_DIR, "stable-diffusion-webui/test.txt"), "w") as f: - f.write("test sd content") - with open(os.path.join(constants.WORK_DIR, "comfyui/test.txt"), "w") as f: - f.write("test comfy content") - - return tmp_path - - -class MockTimer: - def __init__(self, name): - self.name = name - self.elapsed = 1.5 - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - -@pytest.fixture -def mock_timer(): - return lambda x: MockTimer(x) - - -@pytest.fixture(autouse=True) -def cleanup(setup_dirs): - yield - shutil.rmtree(setup_dirs) - - -class TestSDSnapshotSaver: - @pytest.fixture - def saver(self, mock_timer): - return SDSnapshotSaver(mock_timer) - - def test_save_snapshot(self, setup_dirs, saver): - """测试保存SD快照基本功能""" - snapshot_name = "test_snapshot" - stage_cost = saver.save(snapshot_name) - - # 验证时间统计 - assert all(key in stage_cost for key in ["time_compress", "time_upload"]) - assert "time_clear" not in stage_cost - - # 验证生成的文件 - snapshot_path = os.path.join(constants.SNAPSHOT_DIR, snapshot_name) - assert os.path.exists(os.path.join(snapshot_path, "venv.tar")) - assert os.path.exists(os.path.join(snapshot_path, "stable-diffusion-webui.zip")) - - def test_save_snapshot_with_cache(self, setup_dirs, saver): - """测试保存带缓存的快照""" - snapshot_name = "test_snapshot" - with open(os.path.join(constants.WORK_DIR, ".cache/test.txt"), "w") as f: - f.write("test cache") - - stage_cost = saver.save(snapshot_name) - - snapshot_path = os.path.join(constants.SNAPSHOT_DIR, snapshot_name) - assert os.path.exists(os.path.join(snapshot_path, ".cache.zip")) - - def test_save_with_remove_old(self, setup_dirs, saver): - """测试保存快照并删除旧快照""" - # 创建旧快照 - old_name = "old_snapshot" - old_path = os.path.join(constants.SNAPSHOT_DIR, old_name) - os.makedirs(old_path) - with open(os.path.join(old_path, "test.txt"), "w") as f: - f.write("old snapshot") - - # 保存新快照并删除旧的 - new_name = "new_snapshot" - stage_cost = saver.save(new_name, remove_old=True, old_snapshot_name=old_name) - - # 验证旧快照被删除 - assert not os.path.exists(old_path) - # 验证清理时间统计 - assert "time_clear" in stage_cost - - -class TestComfyUISnapshotSaver: - @pytest.fixture - def saver(self, mock_timer): - return ComfyUISnapshotSaver(mock_timer) - - def test_save_snapshot(self, setup_dirs, saver): - """测试保存ComfyUI开发版快照""" - snapshot_name = "test_snapshot" - stage_cost = saver.save(snapshot_name) - - # 验证时间统计 - assert all(key in stage_cost for key in ["time_compress", "time_upload"]) - - # 验证生成的文件 - snapshot_path = os.path.join(constants.SNAPSHOT_DIR, snapshot_name) - assert os.path.exists(os.path.join(snapshot_path, "venv.tar")) - assert os.path.exists(os.path.join(snapshot_path, "comfyui.zip")) - - def test_save_snapshot_with_cache(self, setup_dirs, saver): - """测试保存带缓存的ComfyUI开发版快照""" - snapshot_name = "test_snapshot" - # 创建缓存文件 - with open(os.path.join(constants.WORK_DIR, ".cache/test.txt"), "w") as f: - f.write("test cache") - - stage_cost = saver.save(snapshot_name) - - snapshot_path = os.path.join(constants.SNAPSHOT_DIR, snapshot_name) - assert os.path.exists(os.path.join(snapshot_path, ".cache.zip")) diff --git a/src/code/agent/test/unit/utils/__init__.py b/src/code/agent/test/unit/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/code/agent/test/unit/utils/file_ops_test.py b/src/code/agent/test/unit/utils/file_ops_test.py deleted file mode 100644 index 28767dd..0000000 --- a/src/code/agent/test/unit/utils/file_ops_test.py +++ /dev/null @@ -1,555 +0,0 @@ -import os -import shutil -import tarfile -import pytest -import zipfile -from utils.file_ops import copy, move, remove, compress, extract, create_symlink - - -# 创建测试用的临时目录和文件 -@pytest.fixture -def setup_test_files(tmp_path): - # 创建源文件和目录 - source_dir = tmp_path / "source" - source_dir.mkdir() - - # 创建测试文件 - test_file = source_dir / "test.txt" - test_file.write_text("test content") - - # 创建测试子目录 - test_subdir = source_dir / "subdir" - test_subdir.mkdir() - (test_subdir / "subfile.txt").write_text("subfile content") - - # 创建软链接 - symlink_path = source_dir / "symlink.txt" - os.symlink(str(test_file), str(symlink_path)) - - # 在子目录中创建同样的软链接 - subdir_symlink = test_subdir / "symlink.txt" - os.symlink(str(test_file), str(subdir_symlink)) - - return tmp_path - - -def test_copy_file(setup_test_files): - source_file = setup_test_files / "source" / "test.txt" - target_file = setup_test_files / "target" / "test.txt" - - copy(str(source_file), str(target_file)) - - assert target_file.exists() - assert target_file.read_text() == "test content" - - -def test_copy_directory(setup_test_files): - source_dir = setup_test_files / "source" - target_dir = setup_test_files / "target" - - copy(str(source_dir), str(target_dir)) - - assert target_dir.exists() - assert (target_dir / "test.txt").exists() - assert (target_dir / "subdir" / "subfile.txt").exists() - - -def test_copy_symlink(setup_test_files): - source_dir = setup_test_files / "source" - test_file = source_dir / "test.txt" - symlink_file = source_dir / "link" - os.symlink(str(test_file), str(symlink_file)) - - target_dir = setup_test_files / "target" - - copy(str(source_dir), str(target_dir)) - - target_symlink = target_dir / "link" - assert os.path.islink(str(target_symlink)) - assert os.path.realpath(str(target_symlink)) == os.path.realpath(str(test_file)) - - -def test_copy_symlink_file(setup_test_files): - source_dir = setup_test_files / "source" - test_file = source_dir / "test.txt" - symlink_file = source_dir / "link" - os.symlink(str(test_file), str(symlink_file)) - - target_symlink = setup_test_files / "target" / "link" - - copy(str(symlink_file), str(target_symlink)) - - assert os.path.islink(str(target_symlink)) - assert os.readlink(str(target_symlink)) == str(test_file) - - -def test_copy_symlink_to_directory(setup_test_files): - source_dir = setup_test_files / "source" - subdir = source_dir / "subdir" - symlink_dir = source_dir / "link_to_subdir" - os.symlink(str(subdir), str(symlink_dir)) - - target_dir = setup_test_files / "target" - - copy(str(source_dir), str(target_dir)) - - target_symlink = target_dir / "link_to_subdir" - assert os.path.islink(str(target_symlink)) - assert os.path.realpath(str(target_symlink)) == os.path.realpath(str(subdir)) - - -def test_move_file(setup_test_files): - source_file = setup_test_files / "source" / "test.txt" - target_file = setup_test_files / "target" / "test.txt" - - move(str(source_file), str(target_file)) - - assert target_file.exists() - assert not source_file.exists() - assert target_file.read_text() == "test content" - - -def test_move_directory(setup_test_files): - source_dir = setup_test_files / "source" - target_dir = setup_test_files / "target" - - move(str(source_dir), str(target_dir)) - - assert target_dir.exists() - assert not source_dir.exists() - assert (target_dir / "test.txt").exists() - assert (target_dir / "subdir" / "subfile.txt").exists() - - -def test_remove_file(setup_test_files): - test_file = setup_test_files / "source" / "test.txt" - - assert test_file.exists() - remove(str(test_file)) - assert not test_file.exists() - - -def test_remove_directory(setup_test_files): - source_dir = setup_test_files / "source" - - assert source_dir.exists() - remove(str(source_dir)) - assert not source_dir.exists() - - -def test_remove_nonexistent_path(setup_test_files): - nonexistent_path = setup_test_files / "nonexistent" - - # Should print message but not raise exception - remove(str(nonexistent_path)) - - -def test_remove_symlink(setup_test_files): - source_dir = setup_test_files / "source" - test_file = source_dir / "test.txt" - symlink_path = source_dir / "symlink.txt" - - # 确保软链接和原文件都存在 - assert symlink_path.exists() - assert test_file.exists() - assert os.path.islink(str(symlink_path)) - - # 删除软链接 - remove(str(symlink_path)) - - # 验证软链接被删除,但原文件仍然存在 - assert not os.path.lexists(symlink_path) - assert test_file.exists() - - -def test_remove_dangling_symlink(setup_test_files): - source_dir = setup_test_files / "source" - test_file = source_dir / "test.txt" - dangling_symlink = source_dir / "dangling_symlink.txt" - - # 先创建软链接 - os.symlink(str(test_file), str(dangling_symlink)) - - # 删除源文件,使软链接成为悬空链接 - os.remove(str(test_file)) - - # 确认初始状态:软链接存在但是指向的文件不存在 - assert os.path.islink(str(dangling_symlink)) - assert not os.path.exists(str(dangling_symlink)) # 因为目标文件不存在,所以exists返回False - - # 删除悬空的软链接 - remove(str(dangling_symlink)) - - # 验证软链接被成功删除 - assert not os.path.islink(str(dangling_symlink)) - assert not os.path.lexists(str(dangling_symlink)) - - -def test_compress_all_files(setup_test_files): - source_dir = setup_test_files / "source" - tar_file = setup_test_files / "test.tar" - - compress(str(tar_file), str(source_dir)) - - assert tar_file.exists() - - -def test_compress_selected_files(setup_test_files): - source_dir = setup_test_files / "source" - tar_file = setup_test_files / "test.tar" - - compress(str(tar_file), str(source_dir), ["test.txt"]) - - assert tar_file.exists() - - -def test_extract(setup_test_files): - # 首先创建一个tar文件 - source_dir = setup_test_files / "source" - tar_file = setup_test_files / "test.tar" - output_dir = setup_test_files / "extracted" - - compress(str(tar_file), str(source_dir)) - extract(str(tar_file), str(output_dir)) - - assert output_dir.exists() - assert (output_dir / "test.txt").exists() - assert (output_dir / "subdir" / "subfile.txt").exists() - - -def test_extract_to_default_location(setup_test_files): - source_dir = setup_test_files / "source" - tar_file = setup_test_files / "test.tar" - - compress(str(tar_file), str(source_dir)) - extract(str(tar_file)) - - assert tar_file.parent.exists() - assert (tar_file.parent / "test.txt").exists() - assert (tar_file.parent / "subdir" / "subfile.txt").exists() - - -def test_extract_zip(setup_test_files): - # 首先创建一个zip文件 - source_dir = setup_test_files / "source" - zip_file = setup_test_files / "test.zip" - output_dir = setup_test_files / "extracted" - - compress(str(zip_file), str(source_dir)) - extract(str(zip_file), str(output_dir)) - - assert output_dir.exists() - assert (output_dir / "test.txt").exists() - assert (output_dir / "subdir" / "subfile.txt").exists() - - # 验证文件内容 - assert (output_dir / "test.txt").read_text() == "test content" - assert (output_dir / "subdir" / "subfile.txt").read_text() == "subfile content" - - -def test_extract_zip_to_default_location(setup_test_files): - source_dir = setup_test_files / "source" - zip_file = setup_test_files / "test.zip" - - compress(str(zip_file), str(source_dir)) - extract(str(zip_file)) - - assert zip_file.parent.exists() - assert (zip_file.parent / "test.txt").exists() - assert (zip_file.parent / "subdir" / "subfile.txt").exists() - - -def test_extract_with_existing_files(setup_test_files): - # 准备源目录和目标目录 - source_dir = setup_test_files / "source" - zip_file = setup_test_files / "test.zip" - output_dir = setup_test_files / "output" - - # 创建zip文件 - compress(str(zip_file), str(source_dir)) - - # 在解压目标目录下创建同名文件 - output_dir.mkdir() - existing_file = output_dir / "subdir" - existing_file.write_text("existing content") - - # 执行解压操作,会因为解压时subdir文件出现冲突而抛错,extract方法并不会在解压时强制覆盖 - with pytest.raises(Exception): - extract(str(zip_file), output_dir=str(output_dir)) - - -def test_compress_tar_symlink_behavior(setup_test_files): - """测试tar格式压缩时对软链接的处理行为""" - source_dir = setup_test_files / "source" - tar_file = setup_test_files / "test.tar" - output_dir = setup_test_files / "extracted" - - compress(str(tar_file), str(source_dir)) - - # 检查tar文件中的软链接 - with tarfile.open(str(tar_file), 'r') as tar: - # 验证根目录下的软链接 - symlink_info = tar.getmember("symlink.txt") - assert tarfile.SYMTYPE in symlink_info.type - - # 验证子目录中的软链接 - subdir_symlink_info = tar.getmember("subdir/symlink.txt") - assert tarfile.SYMTYPE in subdir_symlink_info.type - - # 解压并验证 - extract(str(tar_file), str(output_dir)) - - # 验证解压后的软链接 - # 根目录软链接 - extracted_symlink = output_dir / "symlink.txt" - assert os.path.islink(str(extracted_symlink)) - assert os.readlink(str(extracted_symlink)) == str(setup_test_files / "source" / "test.txt") - - # 子目录中的软链接 - extracted_subdir_symlink = output_dir / "subdir" / "symlink.txt" - assert os.path.islink(str(extracted_subdir_symlink)) - assert os.readlink(str(extracted_subdir_symlink)) == str(setup_test_files / "source" / "test.txt") - - -def test_compress_all_files_zip(setup_test_files): - source_dir = setup_test_files / "source" - zip_file = setup_test_files / "test.zip" - - compress(str(zip_file), str(source_dir)) - - assert zip_file.exists() - # 验证是否为有效的zip文件 - assert zipfile.is_zipfile(str(zip_file)) - - -def test_compress_selected_files_zip(setup_test_files): - source_dir = setup_test_files / "source" - zip_file = setup_test_files / "test.zip" - - compress(str(zip_file), str(source_dir), ["test.txt"]) - - assert zip_file.exists() - # 验证zip文件中只包含选定的文件 - with zipfile.ZipFile(str(zip_file), 'r') as zf: - assert len(zf.namelist()) == 1 - assert "test.txt" in zf.namelist() - - -def test_compress_zip_symlink_behavior(setup_test_files): - """测试zip格式压缩时对软链接的处理行为""" - source_dir = setup_test_files / "source" - zip_file = setup_test_files / "test.zip" - output_dir = setup_test_files / "extracted" - - compress(str(zip_file), str(source_dir)) - - # 检查zip文件中的内容 - with zipfile.ZipFile(str(zip_file), 'r') as zf: - file_list = zf.namelist() - # 验证仅包含原始文件,不包含软链接及其指向的内容副本 - expected_files = {'test.txt', 'symlink.txt', 'subdir/subfile.txt', 'subdir/symlink.txt'} - assert set(file_list) == expected_files - - # 解压并验证 - extract(str(zip_file), str(output_dir)) - - # 验证解压后的目录结构完整性 - assert (output_dir / "test.txt").exists() - assert (output_dir / "symlink.txt").exists() - assert not os.path.islink(str(output_dir / "symlink.txt")) - assert (output_dir / "subdir" / "subfile.txt").exists() - assert (output_dir / "subdir" / "symlink.txt").exists() - assert not os.path.islink(str(output_dir / "subdir" / "symlink.txt")) - - -def test_compress_invalid_format(setup_test_files): - """测试不支持的压缩格式""" - source_dir = setup_test_files / "source" - invalid_file = setup_test_files / "test.rar" - - with pytest.raises(ValueError) as exc_info: - compress(str(invalid_file), str(source_dir)) - assert "Unsupported archive format" in str(exc_info.value) - - -def test_extract_invalid_format(setup_test_files): - """测试不支持的解压格式""" - invalid_file = setup_test_files / "test.rar" - - # 创建一个假的rar文件 - invalid_file.write_text("fake content") - - with pytest.raises(ValueError) as exc_info: - extract(str(invalid_file)) - assert "Unsupported archive format" in str(exc_info.value) - - -# 错误处理测试 -def test_copy_nonexistent_source(): - with pytest.raises(Exception): - copy("nonexistent_file", "target") - - -def test_move_nonexistent_source(): - with pytest.raises(Exception): - move("nonexistent_file", "target") - - -def test_extract_nonexistent_tar(): - with pytest.raises(Exception): - extract("nonexistent.tar") - - -def test_create_symlink(setup_test_files): - source_file = setup_test_files / "source" / "test.txt" - link_path = setup_test_files / "link.txt" - - create_symlink(str(source_file), str(link_path)) - - assert os.path.islink(str(link_path)) - assert os.readlink(str(link_path)) == str(source_file) - assert link_path.read_text() == "test content" - - -def test_create_symlink_force(setup_test_files): - # 创建一个原始文件 - original_file = setup_test_files / "original.txt" - original_file.write_text("original content") - - # 创建目标文件 - source_file = setup_test_files / "source" / "test.txt" - link_path = setup_test_files / "original.txt" - - # 使用force选项创建符号链接 - create_symlink(str(source_file), str(link_path), force=True) - - assert os.path.islink(str(link_path)) - assert os.readlink(str(link_path)) == str(source_file) - assert link_path.read_text() == "test content" - - -def test_create_symlink_force_existing_symlink(setup_test_files): - # 创建源文件 - source_file1 = setup_test_files / "source1.txt" - source_file1.write_text("source1 content") - - source_file2 = setup_test_files / "source2.txt" - source_file2.write_text("source2 content") - - # 先创建指向source1的软链接 - link_path = setup_test_files / "link.txt" - os.symlink(str(source_file1), str(link_path)) - - # 验证初始软链接 - assert os.path.islink(str(link_path)) - assert os.readlink(str(link_path)) == str(source_file1) - assert link_path.read_text() == "source1 content" - - # 使用force选项创建指向source2的新软链接 - create_symlink(str(source_file2), str(link_path), force=True) - - # 验证软链接已更新 - assert os.path.islink(str(link_path)) - assert os.readlink(str(link_path)) == str(source_file2) - assert link_path.read_text() == "source2 content" - - -def test_create_symlink_force_existing_broken_symlink(setup_test_files): - # 创建新的源文件 - source_file = setup_test_files / "source.txt" - source_file.write_text("source content") - - # 创建一个指向不存在文件的损坏软链接 - link_path = setup_test_files / "broken_link.txt" - nonexistent_path = setup_test_files / "nonexistent.txt" - os.symlink(str(nonexistent_path), str(link_path)) - - # 验证初始软链接状态 - assert os.path.islink(str(link_path)) - assert os.readlink(str(link_path)) == str(nonexistent_path) - assert not os.path.exists(str(link_path)) # 验证链接是损坏的 - assert os.path.lexists(str(link_path)) # 但链接本身存在 - - # 使用force选项创建新的软链接 - create_symlink(str(source_file), str(link_path), force=True) - - # 验证软链接已更新 - assert os.path.islink(str(link_path)) - assert os.readlink(str(link_path)) == str(source_file) - assert os.path.exists(str(link_path)) # 现在链接是有效的 - assert link_path.read_text() == "source content" - - -def test_create_symlink_force_existing_directory(setup_test_files): - # 创建源目录 - source_dir = setup_test_files / "source_dir" - source_dir.mkdir() - (source_dir / "test.txt").write_text("source content") - - # 创建目标目录并添加一些内容 - target_dir = setup_test_files / "target_dir" - target_dir.mkdir() - (target_dir / "existing.txt").write_text("existing content") - - # 使用force选项创建软链接 - create_symlink(str(source_dir), str(target_dir), force=True) - - # 验证目录已被软链接替换 - assert os.path.islink(str(target_dir)) - assert os.readlink(str(target_dir)) == str(source_dir) - assert (target_dir / "test.txt").exists() - assert (target_dir / "test.txt").read_text() == "source content" - # 验证原目录内容已被删除 - assert not os.path.exists(str(target_dir / "existing.txt")) - - -def test_create_symlink_directory(setup_test_files): - source_dir = setup_test_files / "source" - link_path = setup_test_files / "link_dir" - - create_symlink(str(source_dir), str(link_path)) - - assert os.path.islink(str(link_path)) - assert os.readlink(str(link_path)) == str(source_dir) - assert (link_path / "test.txt").exists() - - -def test_create_symlink_nonexistent_source(): - with pytest.raises(Exception): - create_symlink("nonexistent_file", "link_path") - - -def test_create_symlink_existing_without_force(setup_test_files): - source_file = setup_test_files / "source" / "test.txt" - link_path = setup_test_files / "link.txt" - - # 首先创建一个文件 - link_path.write_text("existing content") - - # 尝试创建符号链接但不使用force选项 - with pytest.raises(Exception): - create_symlink(str(source_file), str(link_path), force=False) - - -def test_create_symlink_existing_directory_force(setup_test_files): - source_dir = setup_test_files / "source" - existing_dir = setup_test_files / "existing_dir" - - # 创建一个已存在的目录 - existing_dir.mkdir() - (existing_dir / "existing_file.txt").write_text("existing content") - - # 使用force选项创建符号链接 - create_symlink(str(source_dir), str(existing_dir), force=True) - - assert os.path.islink(str(existing_dir)) - assert os.readlink(str(existing_dir)) == str(source_dir) - assert (existing_dir / "test.txt").exists() - - -# 清理函数 -@pytest.fixture(autouse=True) -def cleanup(setup_test_files): - yield - # 测试后清理临时文件 - shutil.rmtree(setup_test_files)