feat(plugin-system): harden P0 safety with workflow timeout, service ACL, and contract validation

- enforce step timeout/cancellation in workflow engine
- add caller authorization boundary for cross-plugin service calls
- validate params_schema and return_schema at runtime
pull/1496/head
DrSmoothl 2026-02-21 16:11:52 +08:00
parent 6d196454ee
commit 2cb512120b
6 changed files with 199 additions and 9 deletions

View File

@ -39,6 +39,18 @@ def unregister_service(service_name: str, plugin_name: Optional[str] = None) ->
return plugin_service_registry.unregister_service(service_name, plugin_name) return plugin_service_registry.unregister_service(service_name, plugin_name)
async def call_service(
    service_name: str,
    *args: Any,
    plugin_name: Optional[str] = None,
    caller_plugin: Optional[str] = None,
    **kwargs: Any,
) -> Any:
    """Call a plugin service through the shared plugin service registry.

    Args:
        service_name: Name of the service to call.
        *args: Positional arguments forwarded to the registry call.
        plugin_name: Optional name of the plugin that owns the service.
        caller_plugin: Optional name of the plugin making the call; it is
            forwarded unchanged to the registry.
        **kwargs: Keyword arguments forwarded to the registry call.

    Returns:
        Whatever the registry's ``call_service`` coroutine resolves to.
    """
    pending_call = plugin_service_registry.call_service(
        service_name,
        *args,
        plugin_name=plugin_name,
        caller_plugin=caller_plugin,
        **kwargs,
    )
    return await pending_call

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict from typing import Any, Dict, List
@dataclass @dataclass
@ -11,6 +11,8 @@ class PluginServiceInfo:
version: str = "1.0.0" version: str = "1.0.0"
description: str = "" description: str = ""
enabled: bool = True enabled: bool = True
public: bool = False
allowed_callers: List[str] = field(default_factory=list)
params_schema: Dict[str, Any] = field(default_factory=dict) params_schema: Dict[str, Any] = field(default_factory=dict)
return_schema: Dict[str, Any] = field(default_factory=dict) return_schema: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict)

View File

@ -274,6 +274,23 @@ class ComponentRegistry:
logger.error(f"移除组件 {component_name} 时发生错误: {e}") logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False return False
async def remove_components_by_plugin(self, plugin_name: str) -> int:
    """Remove every component registered by the given plugin.

    Args:
        plugin_name: Name of the plugin whose components should be removed.

    Returns:
        The number of components for which ``remove_component`` reported
        success.
    """
    # Collect the (name, type) pairs up front so the removal loop below does
    # not iterate over self._components while removals are in progress.
    targets = []
    for info in self._components.values():
        if info.plugin_name == plugin_name:
            targets.append((info.name, info.component_type))

    removed_count = 0
    for name, kind in targets:
        succeeded = await self.remove_component(name, kind, plugin_name)
        if succeeded:
            removed_count += 1

    # Only log when something was actually removed.
    if removed_count > 0:
        logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}")
    return removed_count
def remove_plugin_registry(self, plugin_name: str) -> bool: def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息 """移除插件注册信息

View File

@ -200,13 +200,43 @@ class PluginManager:
""" """
重载插件模块 重载插件模块
""" """
old_instance = self.loaded_plugins.get(plugin_name)
if not old_instance:
logger.warning(f"插件 {plugin_name} 未加载,无法重载")
return False
if not await self.remove_registered_plugin(plugin_name): if not await self.remove_registered_plugin(plugin_name):
return False return False
if not self.load_registered_plugin_classes(plugin_name)[0]: if not self.load_registered_plugin_classes(plugin_name)[0]:
logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例")
rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance)
if rollback_ok:
logger.info(f"插件 {plugin_name} 已回滚到旧版本实例")
else:
logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用")
return False return False
logger.debug(f"插件 {plugin_name} 重载成功") logger.debug(f"插件 {plugin_name} 重载成功")
return True return True
async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool:
    """Restore the previous plugin instance after a failed reload.

    Clears the components, plugin registry entry and services recorded for
    ``plugin_name``, then asks the old instance to register itself again.

    Args:
        plugin_name: Name of the plugin being rolled back.
        old_instance: The plugin instance that was loaded before the reload.

    Returns:
        True when the old instance re-registered and was stored back into
        ``self.loaded_plugins``; False when re-registration failed or any
        step raised.
    """
    try:
        # Remove whatever is currently registered under this plugin name
        # before re-registering the old instance.
        await component_registry.remove_components_by_plugin(plugin_name)
        component_registry.remove_plugin_registry(plugin_name)
        plugin_service_registry.remove_services_by_plugin(plugin_name)

        registered = old_instance.register_plugin()
        if registered:
            self.loaded_plugins[plugin_name] = old_instance
            return True

        logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败")
        return False
    except Exception as e:
        # Rollback is best effort: report the failure instead of propagating.
        logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True)
        return False
def rescan_plugin_directory(self) -> Tuple[int, int]: def rescan_plugin_directory(self) -> Tuple[int, int]:
""" """
重新扫描插件根目录 重新扫描插件根目录

View File

@ -26,6 +26,9 @@ class PluginServiceRegistry:
if "." in service_info.plugin_name: if "." in service_info.plugin_name:
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False return False
if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]:
logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}")
return False
full_name = service_info.full_name full_name = service_info.full_name
if full_name in self._services: if full_name in self._services:
@ -103,12 +106,28 @@ class PluginServiceRegistry:
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}") logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
return removed_count return removed_count
async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: async def call_service(
self,
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务(支持同步/异步handler""" """调用插件服务(支持同步/异步handler"""
service_info = self.get_service(service_name, plugin_name) service_info = self.get_service(service_name, plugin_name)
if not service_info: if not service_info:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}") raise ValueError(f"插件服务未注册: {target_name}")
if "." not in service_name and plugin_name is None and caller_plugin and service_info.plugin_name != caller_plugin:
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
if not self._is_call_authorized(service_info, caller_plugin):
raise PermissionError(
f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}"
)
if not service_info.enabled: if not service_info.enabled:
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}") raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
@ -116,8 +135,91 @@ class PluginServiceRegistry:
if not handler: if not handler:
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}") raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
self._validate_input_contract(service_info, args, kwargs)
result = handler(*args, **kwargs) result = handler(*args, **kwargs)
return await result if inspect.isawaitable(result) else result resolved_result = await result if inspect.isawaitable(result) else result
self._validate_output_contract(service_info, resolved_result)
return resolved_result
def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool:
"""检查服务调用权限。"""
if caller_plugin is None:
return service_info.public
if caller_plugin == service_info.plugin_name:
return True
if service_info.public:
return True
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
return "*" in allowed_callers or caller_plugin in allowed_callers
def _validate_input_contract(self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
"""校验服务入参契约。"""
schema = service_info.params_schema
if not schema:
return
properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
is_invocation_schema = "args" in properties or "kwargs" in properties
if is_invocation_schema:
payload = {"args": list(args), "kwargs": kwargs}
self._validate_by_schema(payload, schema, path="params")
return
if args:
raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数")
self._validate_by_schema(kwargs, schema, path="params")
def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None:
"""校验服务返回值契约。"""
if not service_info.return_schema:
return
self._validate_by_schema(value, service_info.return_schema, path="return")
def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None:
"""基于简化JSON-Schema校验数据。"""
expected_type = schema.get("type")
if expected_type:
self._validate_type(value, expected_type, path)
enum_values = schema.get("enum")
if enum_values is not None and value not in enum_values:
raise ValueError(f"{path} 不在枚举范围内: {value}")
if expected_type == "object":
properties = schema.get("properties", {})
required = schema.get("required", [])
for field in required:
if field not in value:
raise ValueError(f"{path}.{field} 为必填字段")
for field, field_value in value.items():
if field in properties:
self._validate_by_schema(field_value, properties[field], f"{path}.{field}")
elif schema.get("additionalProperties", True) is False:
raise ValueError(f"{path}.{field} 不允许额外字段")
if expected_type == "array":
if item_schema := schema.get("items"):
for index, item in enumerate(value):
self._validate_by_schema(item, item_schema, f"{path}[{index}]")
def _validate_type(self, value: Any, expected_type: str, path: str) -> None:
"""校验基础类型。"""
type_checkers: Dict[str, Callable[[Any], bool]] = {
"string": lambda item: isinstance(item, str),
"number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool),
"integer": lambda item: isinstance(item, int) and not isinstance(item, bool),
"boolean": lambda item: isinstance(item, bool),
"object": lambda item: isinstance(item, dict),
"array": lambda item: isinstance(item, list),
"null": lambda item: item is None,
}
checker = type_checkers.get(expected_type)
if checker and not checker(value):
raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}")
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]: def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
"""解析服务全名。""" """解析服务全名。"""

View File

@ -1,7 +1,8 @@
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import asyncio
import inspect
import time import time
import uuid import uuid
import inspect
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, MaiMessages from src.plugin_system.base.component_types import EventType, MaiMessages
@ -144,11 +145,19 @@ class WorkflowEngine:
step_timing_key = f"{stage.value}:{step_info.full_name}" step_timing_key = f"{stage.value}:{step_info.full_name}"
step_start = time.perf_counter() step_start = time.perf_counter()
timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None
try: try:
result = handler(context, message) if inspect.iscoroutinefunction(handler):
if inspect.isawaitable(result): coroutine = handler(context, message)
result = await result result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine
else:
if timeout_seconds:
result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds)
else:
result = handler(context, message)
if inspect.isawaitable(result):
result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result
context.timings[step_timing_key] = time.perf_counter() - step_start context.timings[step_timing_key] = time.perf_counter() - step_start
normalized_result = self._normalize_step_result(result) normalized_result = self._normalize_step_result(result)
@ -165,6 +174,24 @@ class WorkflowEngine:
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value) normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
return normalized_result return normalized_result
except asyncio.TimeoutError:
context.timings[step_timing_key] = time.perf_counter() - step_start
timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms"
context.errors.append(f"{step_info.full_name}: {timeout_message}")
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}"
)
return WorkflowStepResult(
status="failed",
return_message=timeout_message,
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.STEP_TIMEOUT.value,
},
)
except Exception as e: except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}") context.errors.append(f"{step_info.full_name}: {e}")