From 2cb512120b80cbe0a82153207720cb669731813a Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Feb 2026 16:11:52 +0800 Subject: [PATCH] feat(plugin-system): harden P0 safety with workflow timeout, service ACL, and contract validation - enforce step timeout/cancellation in workflow engine - add caller authorization boundary for cross-plugin service calls - validate params_schema and return_schema at runtime --- src/plugin_system/apis/plugin_service_api.py | 16 ++- src/plugin_system/base/service_types.py | 4 +- src/plugin_system/core/component_registry.py | 17 +++ src/plugin_system/core/plugin_manager.py | 30 +++++ .../core/plugin_service_registry.py | 106 +++++++++++++++++- src/plugin_system/core/workflow_engine.py | 35 +++++- 6 files changed, 199 insertions(+), 9 deletions(-) diff --git a/src/plugin_system/apis/plugin_service_api.py b/src/plugin_system/apis/plugin_service_api.py index 3e64722b..4c783f04 100644 --- a/src/plugin_system/apis/plugin_service_api.py +++ b/src/plugin_system/apis/plugin_service_api.py @@ -39,6 +39,18 @@ def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> return plugin_service_registry.unregister_service(service_name, plugin_name) -async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: +async def call_service( + service_name: str, + *args: Any, + plugin_name: Optional[str] = None, + caller_plugin: Optional[str] = None, + **kwargs: Any, +) -> Any: """调用插件服务。""" - return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs) + return await plugin_service_registry.call_service( + service_name, + *args, + plugin_name=plugin_name, + caller_plugin=caller_plugin, + **kwargs, + ) diff --git a/src/plugin_system/base/service_types.py b/src/plugin_system/base/service_types.py index 7d23f7ca..0a789311 100644 --- a/src/plugin_system/base/service_types.py +++ b/src/plugin_system/base/service_types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict +from typing import Any, Dict, List @dataclass @@ -11,6 +11,8 @@ class PluginServiceInfo: version: str = "1.0.0" description: str = "" enabled: bool = True + public: bool = False + allowed_callers: List[str] = field(default_factory=list) params_schema: Dict[str, Any] = field(default_factory=dict) return_schema: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index af2f852c..aaa8d70a 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -274,6 +274,23 @@ class ComponentRegistry: logger.error(f"移除组件 {component_name} 时发生错误: {e}") return False + async def remove_components_by_plugin(self, plugin_name: str) -> int: + """移除某插件注册的所有组件。""" + targets = [ + (component_info.name, component_info.component_type) + for component_info in self._components.values() + if component_info.plugin_name == plugin_name + ] + + removed_count = 0 + for component_name, component_type in targets: + if await self.remove_component(component_name, component_type, plugin_name): + removed_count += 1 + + if removed_count: + logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}") + return removed_count + def remove_plugin_registry(self, plugin_name: str) -> bool: """移除插件注册信息 diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 9c8664ba..c2665a66 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -200,13 +200,43 @@ class PluginManager: """ 重载插件模块 """ + old_instance = self.loaded_plugins.get(plugin_name) + if not old_instance: + logger.warning(f"插件 {plugin_name} 未加载,无法重载") + return False + if not await self.remove_registered_plugin(plugin_name): return False + if not self.load_registered_plugin_classes(plugin_name)[0]: + logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例") + rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance) + if rollback_ok: + logger.info(f"插件 {plugin_name} 已回滚到旧版本实例") + else: + logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用") return False + logger.debug(f"插件 {plugin_name} 重载成功") return True + async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool: + """重载失败后回滚旧实例。""" + try: + await component_registry.remove_components_by_plugin(plugin_name) + component_registry.remove_plugin_registry(plugin_name) + plugin_service_registry.remove_services_by_plugin(plugin_name) + + if not old_instance.register_plugin(): + logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败") + return False + + self.loaded_plugins[plugin_name] = old_instance + return True + except Exception as e: + logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True) + return False + def rescan_plugin_directory(self) -> Tuple[int, int]: """ 重新扫描插件根目录 diff --git a/src/plugin_system/core/plugin_service_registry.py b/src/plugin_system/core/plugin_service_registry.py index 1dbf1233..c8bd2bfb 100644 --- a/src/plugin_system/core/plugin_service_registry.py +++ b/src/plugin_system/core/plugin_service_registry.py @@ -26,6 +26,9 @@ class PluginServiceRegistry: if "." in service_info.plugin_name: logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") return False + if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]: + logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}") + return False full_name = service_info.full_name if full_name in self._services: @@ -103,12 +106,28 @@ class PluginServiceRegistry: logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}") return removed_count - async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: + async def call_service( + self, + service_name: str, + *args: Any, + plugin_name: Optional[str] = None, + caller_plugin: Optional[str] = None, + **kwargs: Any, + ) -> Any: """调用插件服务(支持同步/异步handler)。""" service_info = self.get_service(service_name, plugin_name) if not service_info: target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name raise ValueError(f"插件服务未注册: {target_name}") + + if "." not in service_name and plugin_name is None and caller_plugin and service_info.plugin_name != caller_plugin: + raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name") + + if not self._is_call_authorized(service_info, caller_plugin): + raise PermissionError( + f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}" + ) + if not service_info.enabled: raise RuntimeError(f"插件服务已禁用: {service_info.full_name}") @@ -116,8 +135,91 @@ class PluginServiceRegistry: if not handler: raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}") + self._validate_input_contract(service_info, args, kwargs) + result = handler(*args, **kwargs) - return await result if inspect.isawaitable(result) else result + resolved_result = await result if inspect.isawaitable(result) else result + self._validate_output_contract(service_info, resolved_result) + return resolved_result + + def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool: + """检查服务调用权限。""" + if caller_plugin is None: + return service_info.public + if caller_plugin == service_info.plugin_name: + return True + if service_info.public: + return True + allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()} + return "*" in allowed_callers or caller_plugin in allowed_callers + + def _validate_input_contract(self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: + """校验服务入参契约。""" + schema = service_info.params_schema + if not schema: + return + + properties = schema.get("properties", {}) if isinstance(schema, dict) else {} + is_invocation_schema = "args" in properties or "kwargs" in properties + + if is_invocation_schema: + payload = {"args": list(args), "kwargs": kwargs} + self._validate_by_schema(payload, schema, path="params") + return + + if args: + raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数") + self._validate_by_schema(kwargs, schema, path="params") + + def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None: + """校验服务返回值契约。""" + if not service_info.return_schema: + return + self._validate_by_schema(value, service_info.return_schema, path="return") + + def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None: + """基于简化JSON-Schema校验数据。""" + expected_type = schema.get("type") + if expected_type: + self._validate_type(value, expected_type, path) + + enum_values = schema.get("enum") + if enum_values is not None and value not in enum_values: + raise ValueError(f"{path} 不在枚举范围内: {value}") + + if expected_type == "object": + properties = schema.get("properties", {}) + required = schema.get("required", []) + + for field in required: + if field not in value: + raise ValueError(f"{path}.{field} 为必填字段") + + for field, field_value in value.items(): + if field in properties: + self._validate_by_schema(field_value, properties[field], f"{path}.{field}") + elif schema.get("additionalProperties", True) is False: + raise ValueError(f"{path}.{field} 不允许额外字段") + + if expected_type == "array": + if item_schema := schema.get("items"): + for index, item in enumerate(value): + self._validate_by_schema(item, item_schema, f"{path}[{index}]") + + def _validate_type(self, value: Any, expected_type: str, path: str) -> None: + """校验基础类型。""" + type_checkers: Dict[str, Callable[[Any], bool]] = { + "string": lambda item: isinstance(item, str), + "number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool), + "integer": lambda item: isinstance(item, int) and not isinstance(item, bool), + "boolean": lambda item: isinstance(item, bool), + "object": lambda item: isinstance(item, dict), + "array": lambda item: isinstance(item, list), + "null": lambda item: item is None, + } + checker = type_checkers.get(expected_type) + if checker and not checker(value): + raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}") def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]: """解析服务全名。""" diff --git a/src/plugin_system/core/workflow_engine.py b/src/plugin_system/core/workflow_engine.py index 089d0498..937a7a91 100644 --- a/src/plugin_system/core/workflow_engine.py +++ b/src/plugin_system/core/workflow_engine.py @@ -1,7 +1,8 @@ from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union +import asyncio +import inspect import time import uuid -import inspect from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType, MaiMessages @@ -144,11 +145,19 @@ class WorkflowEngine: step_timing_key = f"{stage.value}:{step_info.full_name}" step_start = time.perf_counter() + timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None try: - result = handler(context, message) - if inspect.isawaitable(result): - result = await result + if inspect.iscoroutinefunction(handler): + coroutine = handler(context, message) + result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine + else: + if timeout_seconds: + result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds) + else: + result = handler(context, message) + if inspect.isawaitable(result): + result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result context.timings[step_timing_key] = time.perf_counter() - step_start normalized_result = self._normalize_step_result(result) @@ -165,6 +174,24 @@ class WorkflowEngine: normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value) return normalized_result + except asyncio.TimeoutError: + context.timings[step_timing_key] = time.perf_counter() - step_start + timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms" + context.errors.append(f"{step_info.full_name}: {timeout_message}") + logger.error( + f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}" + ) + return WorkflowStepResult( + status="failed", + return_message=timeout_message, + diagnostics={ + "stage": stage.value, + "step": step_info.full_name, + "trace_id": context.trace_id, + "error_code": WorkflowErrorCode.STEP_TIMEOUT.value, + }, + ) + except Exception as e: context.timings[step_timing_key] = time.perf_counter() - step_start context.errors.append(f"{step_info.full_name}: {e}")