diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 36f4ff5b..2c33726e 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -29,10 +29,17 @@ from .base import ( MaiMessages, ToolParamType, CustomEventHandlerResult, + PluginServiceInfo, ReplyContentType, ReplyContent, ForwardNode, ReplySetModel, + WorkflowContext, + WorkflowMessage, + WorkflowStage, + WorkflowStepInfo, + WorkflowStepResult, + WorkflowErrorCode, ) # 导入工具模块 @@ -55,6 +62,8 @@ from .apis import ( message_api, person_api, plugin_manage_api, + plugin_service_api, + workflow_api, send_api, register_plugin, get_logger, @@ -85,6 +94,8 @@ __all__ = [ "message_api", "person_api", "plugin_manage_api", + "plugin_service_api", + "workflow_api", "send_api", "auto_talk_api", "register_plugin", @@ -115,6 +126,13 @@ __all__ = [ "ReplySetModel", "MaiMessages", "CustomEventHandlerResult", + "PluginServiceInfo", + "WorkflowContext", + "WorkflowMessage", + "WorkflowStage", + "WorkflowStepInfo", + "WorkflowStepResult", + "WorkflowErrorCode", # 装饰器 "register_plugin", "ConfigField", diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 036c077e..75e64af1 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -16,9 +16,11 @@ from src.plugin_system.apis import ( message_api, person_api, plugin_manage_api, + plugin_service_api, send_api, tool_api, frequency_api, + workflow_api, ) from .logging_api import get_logger from .plugin_register_api import register_plugin @@ -35,9 +37,11 @@ __all__ = [ "message_api", "person_api", "plugin_manage_api", + "plugin_service_api", "send_api", "get_logger", "register_plugin", "tool_api", "frequency_api", + "workflow_api", ] diff --git a/src/plugin_system/apis/plugin_service_api.py b/src/plugin_system/apis/plugin_service_api.py new file mode 100644 index 00000000..3e64722b --- /dev/null +++ b/src/plugin_system/apis/plugin_service_api.py @@ -0,0 +1,44 @@ +from typing import Any, Callable, Dict, Optional + +from src.plugin_system.base.service_types import PluginServiceInfo +from src.plugin_system.core.plugin_service_registry import plugin_service_registry + + +def register_service(service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool: + """注册插件服务。""" + return plugin_service_registry.register_service(service_info, service_handler) + + +def get_service(service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]: + """获取插件服务元信息。""" + return plugin_service_registry.get_service(service_name, plugin_name) + + +def get_service_handler(service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]: + """获取插件服务处理函数。""" + return plugin_service_registry.get_service_handler(service_name, plugin_name) + + +def list_services(plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]: + """列出插件服务。""" + return plugin_service_registry.list_services(plugin_name=plugin_name, enabled_only=enabled_only) + + +def enable_service(service_name: str, plugin_name: Optional[str] = None) -> bool: + """启用插件服务。""" + return plugin_service_registry.enable_service(service_name, plugin_name) + + +def disable_service(service_name: str, plugin_name: Optional[str] = None) -> bool: + """禁用插件服务。""" + return plugin_service_registry.disable_service(service_name, plugin_name) + + +def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> bool: + """注销插件服务。""" + return plugin_service_registry.unregister_service(service_name, plugin_name) + + +async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: + """调用插件服务。""" + return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs) diff --git a/src/plugin_system/apis/workflow_api.py b/src/plugin_system/apis/workflow_api.py new file mode 100644 index 00000000..c7d45e88 --- /dev/null +++ b/src/plugin_system/apis/workflow_api.py @@ -0,0 +1,77 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from src.plugin_system.base.component_types import EventType, MaiMessages +from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepInfo, WorkflowStepResult +from src.plugin_system.core.component_registry import component_registry +from src.plugin_system.core.events_manager import events_manager +from src.plugin_system.core.workflow_engine import workflow_engine + + +def register_workflow_step(step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool: + """注册workflow step。""" + return component_registry.register_workflow_step(step_info, step_handler) + + +def get_steps_by_stage(stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]: + """获取指定阶段的workflow steps。""" + return component_registry.get_steps_by_stage(stage, enabled_only=enabled_only) + + +def get_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]: + """获取workflow step元信息。""" + return component_registry.get_workflow_step(step_name, stage) + + +def get_workflow_step_handler(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[Callable[..., Any]]: + """获取workflow step处理函数。""" + return component_registry.get_workflow_step_handler(step_name, stage) + + +def enable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool: + """启用workflow step。""" + return component_registry.enable_workflow_step(step_name, stage) + + +def disable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool: + """禁用workflow step。""" + return component_registry.disable_workflow_step(step_name, stage) + + +def get_execution_trace(trace_id: str) -> Optional[Dict[str, Any]]: + """按trace_id获取workflow执行路径。""" + return workflow_engine.get_execution_trace(trace_id) + + +def clear_execution_trace(trace_id: str) -> bool: + """清理trace执行路径记录。""" + return workflow_engine.clear_execution_trace(trace_id) + + +async def execute_workflow_message( + message: Optional[MaiMessages] = None, + stream_id: Optional[str] = None, + action_usage: Optional[List[str]] = None, + context: Optional[WorkflowContext] = None, +) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]: + """执行workflow消息流转。""" + return await events_manager.handle_workflow_message( + message=message, + stream_id=stream_id, + action_usage=action_usage, + context=context, + ) + + +async def publish_event( + event_type: Union[EventType, str], + message: Optional[MaiMessages] = None, + stream_id: Optional[str] = None, + action_usage: Optional[List[str]] = None, +) -> Tuple[bool, Optional[MaiMessages]]: + """发布事件(支持系统事件和自定义字符串事件)。""" + return await events_manager.handle_mai_events( + event_type=event_type, + message=message, + stream_id=stream_id, + action_usage=action_usage, + ) diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index 9f930b5a..7326fcec 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -9,6 +9,9 @@ from .base_action import BaseAction from .base_tool import BaseTool from .base_command import BaseCommand from .base_events_handler import BaseEventHandler +from .service_types import PluginServiceInfo +from .workflow_types import WorkflowContext, WorkflowMessage, WorkflowStage, WorkflowStepInfo, WorkflowStepResult +from .workflow_errors import WorkflowErrorCode from .component_types import ( ComponentType, ActionActivationType, @@ -59,4 +62,11 @@ __all__ = [ "ReplyContent", "ForwardNode", "ReplySetModel", + "PluginServiceInfo", + "WorkflowContext", + "WorkflowMessage", + "WorkflowStage", + "WorkflowStepInfo", + "WorkflowStepResult", + "WorkflowErrorCode", ] diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index ea28c514..6b278ef4 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -1,9 +1,10 @@ from abc import abstractmethod -from typing import List, Type, Tuple, Union +from typing import Any, Callable, List, Type, Tuple, Union from .plugin_base import PluginBase from src.common.logger import get_logger from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo +from src.plugin_system.base.workflow_types import WorkflowStepInfo from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler @@ -44,6 +45,13 @@ class BasePlugin(PluginBase): """ raise NotImplementedError("Subclasses must implement this method") + def get_workflow_steps(self) -> List[Tuple[WorkflowStepInfo, Callable[..., Any]]]: + """获取插件包含的workflow steps。 + + 默认返回空列表,子类可按需覆盖。 + """ + return [] + def register_plugin(self) -> bool: """注册插件及其所有组件""" from src.plugin_system.core.component_registry import component_registry @@ -69,7 +77,25 @@ class BasePlugin(PluginBase): # 注册插件 if component_registry.register_plugin(self.plugin_info): + # 注册workflow steps(可选) + registered_step_count = 0 + for step_info, step_handler in self.get_workflow_steps(): + if not step_info.plugin_name: + step_info.plugin_name = self.plugin_name + elif step_info.plugin_name != self.plugin_name: + logger.warning( + f"{self.log_prefix} workflow step {step_info.name} 的plugin_name({step_info.plugin_name})与当前插件不一致,已覆盖为 {self.plugin_name}" + ) + step_info.plugin_name = self.plugin_name + + if component_registry.register_workflow_step(step_info, step_handler): + registered_step_count += 1 + else: + logger.warning(f"{self.log_prefix} workflow step {step_info.full_name} 注册失败") + logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件") + if registered_step_count > 0: + logger.debug(f"{self.log_prefix} workflow steps 注册成功,数量: {registered_step_count}") return True else: logger.error(f"{self.log_prefix} 插件注册失败") diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index f21c2a40..521ca137 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -6,6 +6,7 @@ import toml import json import shutil import datetime +import re from src.common.logger import get_logger from src.plugin_system.base.component_types import ( @@ -17,7 +18,7 @@ from src.plugin_system.base.config_types import ( ConfigSection, ConfigLayout, ) -from src.plugin_system.utils.manifest_utils import ManifestValidator +from src.plugin_system.utils.manifest_utils import ManifestValidator, VersionComparator logger = get_logger("plugin_base") @@ -41,7 +42,7 @@ class PluginBase(ABC): @property @abstractmethod - def dependencies(self) -> List[str]: + def dependencies(self) -> List[Union[str, Dict[str, Any]]]: return [] # 依赖的其他插件 @property @@ -104,7 +105,7 @@ class PluginBase(ABC): enabled=self.enable_plugin, is_built_in=False, config_file=self.config_file_name or "", - dependencies=self.dependencies.copy(), + dependencies=self._get_dependency_names(), python_dependencies=self.python_dependencies.copy(), # manifest相关信息 manifest_data=self.manifest_data.copy(), @@ -541,13 +542,101 @@ class PluginBase(ABC): if not self.dependencies: return True - for dep in self.dependencies: - if not component_registry.get_plugin_info(dep): - logger.error(f"{self.log_prefix} 缺少依赖插件: {dep}") + for dependency in self.dependencies: + dep_name, version_spec, min_version, max_version = self._parse_dependency(dependency) + if not dep_name: + logger.warning(f"{self.log_prefix} 跳过无效依赖声明: {dependency}") + continue + + dep_plugin_info = component_registry.get_plugin_info(dep_name) + if not dep_plugin_info: + logger.error(f"{self.log_prefix} 缺少依赖插件: {dep_name}") return False + dep_version = dep_plugin_info.version or "0.0.0" + + if version_spec: + is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec) + if not is_ok: + logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})") + return False + + if min_version or max_version: + is_ok, msg = VersionComparator.is_version_in_range(dep_version, min_version, max_version) + if not is_ok: + logger.error( + f"{self.log_prefix} 依赖插件版本不满足: {dep_name} 要求区间[{min_version or '-inf'}, {max_version or '+inf'}], 当前版本={dep_version} ({msg})" + ) + return False + return True + def _get_dependency_names(self) -> List[str]: + """获取依赖插件名称列表(用于插件信息展示和统计)。""" + dependency_names: List[str] = [] + for dependency in self.dependencies: + dep_name, _, _, _ = self._parse_dependency(dependency) + if dep_name: + dependency_names.append(dep_name) + return dependency_names + + def _parse_dependency(self, dependency: Any) -> tuple[str, str, str, str]: + """解析依赖声明。 + + 支持格式: + - "plugin_a" + - {"name": "plugin_a", "version": ">=1.2.0,<2.0.0"} + - {"name": "plugin_a", "min_version": "1.2.0", "max_version": "2.0.0"} + """ + if isinstance(dependency, str): + return dependency.strip(), "", "", "" + + if isinstance(dependency, dict): + dep_name = str(dependency.get("name", "")).strip() + version_spec = str(dependency.get("version", "")).strip() + min_version = str(dependency.get("min_version", "")).strip() + max_version = str(dependency.get("max_version", "")).strip() + return dep_name, version_spec, min_version, max_version + + return "", "", "", "" + + def _is_version_spec_satisfied(self, version: str, version_spec: str) -> tuple[bool, str]: + """检查版本是否满足表达式。 + + 支持:==, >=, <=, >, <,可用逗号分隔多个条件。 + 示例:">=1.2.0,<2.0.0" + """ + normalized_version = VersionComparator.normalize_version(version) + clauses = [clause.strip() for clause in version_spec.split(",") if clause.strip()] + if not clauses: + return True, "" + + operators_pattern = r"^(==|>=|<=|>|<)\s*(.+)$" + + for clause in clauses: + if not (match := re.match(operators_pattern, clause)): + return False, f"无效版本约束表达式: {clause}" + + operator, target_version = match.group(1), VersionComparator.normalize_version(match.group(2)) + compare_result = VersionComparator.compare_versions(normalized_version, target_version) + + is_satisfied = False + if operator == "==": + is_satisfied = compare_result == 0 + elif operator == ">=": + is_satisfied = compare_result >= 0 + elif operator == "<=": + is_satisfied = compare_result <= 0 + elif operator == ">": + is_satisfied = compare_result > 0 + elif operator == "<": + is_satisfied = compare_result < 0 + + if not is_satisfied: + return False, f"{normalized_version} 不满足约束 {operator}{target_version}" + + return True, "" + def get_config(self, key: str, default: Any = None) -> Any: """获取插件配置值,支持嵌套键访问 diff --git a/src/plugin_system/base/service_types.py b/src/plugin_system/base/service_types.py new file mode 100644 index 00000000..7d23f7ca --- /dev/null +++ b/src/plugin_system/base/service_types.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field +from typing import Any, Dict + + +@dataclass +class PluginServiceInfo: + """插件服务注册信息""" + + name: str + plugin_name: str + version: str = "1.0.0" + description: str = "" + enabled: bool = True + params_schema: Dict[str, Any] = field(default_factory=dict) + return_schema: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def full_name(self) -> str: + return f"{self.plugin_name}.{self.name}" diff --git a/src/plugin_system/base/workflow_errors.py b/src/plugin_system/base/workflow_errors.py new file mode 100644 index 00000000..0f4f74d0 --- /dev/null +++ b/src/plugin_system/base/workflow_errors.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class WorkflowErrorCode(Enum): + """Workflow统一错误码""" + + PLUGIN_NOT_READY = "PLUGIN_NOT_READY" + STEP_TIMEOUT = "STEP_TIMEOUT" + BAD_PAYLOAD = "BAD_PAYLOAD" + DOWNSTREAM_FAILED = "DOWNSTREAM_FAILED" + POLICY_BLOCKED = "POLICY_BLOCKED" + + def __str__(self) -> str: + return self.value diff --git a/src/plugin_system/base/workflow_types.py b/src/plugin_system/base/workflow_types.py new file mode 100644 index 00000000..bedcec8d --- /dev/null +++ b/src/plugin_system/base/workflow_types.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Literal, Optional +import time + + +class WorkflowStage(Enum): + """Workflow阶段定义(MVP固定阶段)""" + + INGRESS = "ingress" + PRE_PROCESS = "pre_process" + PLAN = "plan" + TOOL_EXECUTE = "tool_execute" + POST_PROCESS = "post_process" + EGRESS = "egress" + + def __str__(self) -> str: + return self.value + + +@dataclass +class WorkflowContext: + """Workflow上下文""" + + trace_id: str + stream_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + timings: Dict[str, float] = field(default_factory=dict) + errors: List[str] = field(default_factory=list) + + +@dataclass +class WorkflowMessage: + """Workflow消息包装""" + + msg_type: str + payload: Dict[str, Any] = field(default_factory=dict) + headers: Dict[str, Any] = field(default_factory=dict) + mutable_flags: Dict[str, bool] = field(default_factory=dict) + + +@dataclass +class WorkflowStepResult: + """Workflow步骤结果""" + + status: Literal["continue", "stop", "failed"] = "continue" + return_message: Optional[str] = None + diagnostics: Dict[str, Any] = field(default_factory=dict) + events: List[Dict[str, Any]] = field(default_factory=list) + created_at: float = field(default_factory=time.time) + + +@dataclass +class WorkflowStepInfo: + """Workflow步骤元数据""" + + name: str + stage: WorkflowStage + plugin_name: str + description: str = "" + enabled: bool = True + priority: int = 0 + timeout_ms: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def full_name(self) -> str: + return f"{self.plugin_name}.{self.name}" diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index eb794a30..01d97646 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -8,10 +8,14 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.plugin_system.core.plugin_service_registry import plugin_service_registry +from src.plugin_system.core.workflow_engine import workflow_engine __all__ = [ "plugin_manager", "component_registry", "events_manager", "global_announcement_manager", + "plugin_service_registry", + "workflow_engine", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 19fda27e..af2f852c 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,6 +1,6 @@ import re -from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type +from typing import Callable, Dict, List, Optional, Any, Pattern, Tuple, Union, Type from src.common.logger import get_logger from src.plugin_system.base.component_types import ( @@ -16,6 +16,7 @@ from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_events_handler import BaseEventHandler +from src.plugin_system.base.workflow_types import WorkflowStage, WorkflowStepInfo logger = get_logger("component_registry") @@ -61,6 +62,10 @@ class ComponentRegistry: self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} """启用的事件处理器 event_handler名 -> event_handler类""" + # Workflow step注册表 + self._workflow_steps: Dict[WorkflowStage, Dict[str, WorkflowStepInfo]] = {stage: {} for stage in WorkflowStage} + self._workflow_step_handlers: Dict[str, Callable[..., Any]] = {} + logger.info("组件注册中心初始化完成") # == 注册方法 == @@ -282,9 +287,116 @@ class ComponentRegistry: logger.warning(f"插件 {plugin_name} 未注册,无法移除") return False del self._plugins[plugin_name] + self.remove_workflow_steps_by_plugin(plugin_name) logger.info(f"插件 {plugin_name} 已移除") return True + # === Workflow step 注册与查询 === + + def register_workflow_step(self, step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool: + """注册workflow步骤。""" + if not step_info.name or not step_info.plugin_name: + logger.error("workflow step 注册失败: step名称或插件名称为空") + return False + if "." in step_info.name: + logger.error(f"workflow step 名称 '{step_info.name}' 包含非法字符 '.',请使用下划线替代") + return False + if "." in step_info.plugin_name: + logger.error(f"workflow step 所属插件名称 '{step_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") + return False + + full_name = step_info.full_name + stage_registry = self._workflow_steps.get(step_info.stage) + if stage_registry is None: + logger.error(f"workflow step 注册失败: 未知阶段 {step_info.stage}") + return False + if full_name in stage_registry: + logger.warning(f"workflow step 已存在,跳过注册: {full_name}") + return False + + stage_registry[full_name] = step_info + self._workflow_step_handlers[full_name] = step_handler + logger.debug(f"已注册workflow step: {full_name} @ {step_info.stage}") + return True + + def get_steps_by_stage(self, stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]: + """获取某阶段的workflow步骤。""" + steps = self._workflow_steps.get(stage, {}) + if enabled_only: + return {name: info for name, info in steps.items() if info.enabled} + return steps.copy() + + def get_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]: + """获取workflow step信息。 + + step_name支持两种: + - full_name: plugin_name.step_name + - short_name: step_name(若有冲突返回第一个并告警) + """ + candidates: List[WorkflowStepInfo] = [] + + target_stages = [stage] if stage else list(WorkflowStage) + for current_stage in target_stages: + current_steps = self._workflow_steps.get(current_stage, {}) + if "." in step_name: + if step_info := current_steps.get(step_name): + return step_info + continue + candidates.extend([step_info for step_info in current_steps.values() if step_info.name == step_name]) + + if len(candidates) == 1: + return candidates[0] + if len(candidates) > 1: + logger.warning(f"workflow step 名称 '{step_name}' 存在多义,使用第一个匹配: {candidates[0].full_name}") + return candidates[0] + return None + + def get_workflow_step_handler( + self, step_name: str, stage: Optional[WorkflowStage] = None + ) -> Optional[Callable[..., Any]]: + """获取workflow step处理函数。""" + if "." in step_name: + return self._workflow_step_handlers.get(step_name) + + if step_info := self.get_workflow_step(step_name, stage): + return self._workflow_step_handlers.get(step_info.full_name) + return None + + def enable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool: + """启用workflow step。""" + step_info = self.get_workflow_step(step_name, stage) + if not step_info: + logger.warning(f"workflow step 未注册,无法启用: {step_name}") + return False + step_info.enabled = True + logger.info(f"workflow step 已启用: {step_info.full_name}") + return True + + def disable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool: + """禁用workflow step。""" + step_info = self.get_workflow_step(step_name, stage) + if not step_info: + logger.warning(f"workflow step 未注册,无法禁用: {step_name}") + return False + step_info.enabled = False + logger.info(f"workflow step 已禁用: {step_info.full_name}") + return True + + def remove_workflow_steps_by_plugin(self, plugin_name: str) -> int: + """移除某插件注册的所有workflow step。""" + removed_count = 0 + for stage in WorkflowStage: + stage_registry = self._workflow_steps.get(stage, {}) + target_names = [name for name, info in stage_registry.items() if info.plugin_name == plugin_name] + for full_name in target_names: + stage_registry.pop(full_name, None) + self._workflow_step_handlers.pop(full_name, None) + removed_count += 1 + + if removed_count: + logger.info(f"已移除插件 {plugin_name} 的 workflow step 数量: {removed_count}") + return removed_count + # === 组件全局启用/禁用方法 === def enable_component(self, component_name: str, component_type: ComponentType) -> bool: @@ -602,6 +714,12 @@ class ComponentRegistry: tool_components += 1 elif component.component_type == ComponentType.EVENT_HANDLER: events_handlers += 1 + + workflow_step_count = sum(len(steps) for steps in self._workflow_steps.values()) + enabled_workflow_step_count = sum( + len([step for step in steps.values() if step.enabled]) for steps in self._workflow_steps.values() + ) + return { "action_components": action_components, "command_components": command_components, @@ -614,6 +732,11 @@ class ComponentRegistry: }, "enabled_components": len([c for c in self._components.values() if c.enabled]), "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), + "workflow_steps": workflow_step_count, + "enabled_workflow_steps": enabled_workflow_step_count, + "workflow_steps_by_stage": { + stage.value: len(steps) for stage, steps in self._workflow_steps.items() + }, } diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index 3fe62937..a81cc462 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -7,7 +7,9 @@ from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult from src.plugin_system.base.base_events_handler import BaseEventHandler +from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStepResult from .global_announcement_manager import global_announcement_manager +from .workflow_engine import workflow_engine if TYPE_CHECKING: from src.common.data_models.llm_data_model import LLMGenerationDataModel @@ -57,16 +59,20 @@ class EventsManager: return False if handler_info.event_type not in self._history_enable_map: - logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}") - return False + if isinstance(handler_info.event_type, str): + self.register_event(handler_info.event_type, enable_history_result=False) + logger.info(f"自动注册自定义事件类型: {handler_info.event_type}") + else: + logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}") + return False self._handler_mapping[handler_name] = handler_class return self._insert_event_handler(handler_class, handler_info) async def handle_mai_events( self, - event_type: EventType, - message: Optional[MessageRecv | MessageSending] = None, + event_type: EventType | str, + message: Optional[MessageRecv | MessageSending | MaiMessages] = None, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None, stream_id: Optional[str] = None, @@ -121,6 +127,37 @@ class EventsManager: return continue_flag, modified_message + async def handle_workflow_message( + self, + message: Optional[MessageRecv | MessageSending | MaiMessages] = None, + stream_id: Optional[str] = None, + action_usage: Optional[List[str]] = None, + context: Optional[WorkflowContext] = None, + ) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]: + """执行线性workflow消息流转(MVP兼容入口)。""" + initial_message = self._prepare_message(EventType.ON_MESSAGE_PRE_PROCESS, message=message, stream_id=stream_id) + + async def _dispatch( + event_type: EventType | str, + workflow_message: Optional[MaiMessages], + workflow_stream_id: Optional[str], + workflow_action_usage: Optional[List[str]], + ) -> Tuple[bool, Optional[MaiMessages]]: + return await self.handle_mai_events( + event_type=event_type, + message=workflow_message, + stream_id=workflow_stream_id, + action_usage=workflow_action_usage, + ) + + return await workflow_engine.execute_linear( + dispatch_event=_dispatch, + message=initial_message, + stream_id=stream_id, + action_usage=action_usage, + context=context, + ) + async def cancel_handler_tasks(self, handler_name: str) -> None: tasks_to_be_cancelled = self._handler_tasks.get(handler_name, []) if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]: @@ -294,14 +331,17 @@ class EventsManager: def _prepare_message( self, - event_type: EventType, - message: Optional[MessageRecv | MessageSending] = None, + event_type: EventType | str, + message: Optional[MessageRecv | MessageSending | MaiMessages] = None, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None, stream_id: Optional[str] = None, action_usage: Optional[List[str]] = None, ) -> Optional[MaiMessages]: """根据事件类型和输入,准备和转换消息对象。""" + if isinstance(message, MaiMessages): + return message.deepcopy() + if message: return self._transform_event_message(message, llm_prompt, llm_response) diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 122a9ea2..9c8664ba 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,9 +1,9 @@ -import os -import traceback - -from typing import Dict, List, Optional, Tuple, Type, Any from importlib.util import spec_from_file_location, module_from_spec from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Type +from collections import deque +import os +import traceback from src.common.logger import get_logger @@ -11,6 +11,7 @@ from src.plugin_system.base.plugin_base import PluginBase from src.plugin_system.base.component_types import ComponentType from src.plugin_system.utils.manifest_utils import VersionComparator from .component_registry import component_registry +from .plugin_service_registry import plugin_service_registry logger = get_logger("plugin_manager") @@ -70,10 +71,36 @@ class PluginManager: logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}") + # 第二阶段前:根据依赖关系决定插件加载顺序 + dependency_graph, missing_dependencies = self._build_plugin_dependency_graph() + sorted_plugins, cycle_plugins = self._resolve_plugin_load_order(dependency_graph) + total_registered = 0 total_failed_registration = 0 - for plugin_name in self.plugin_classes.keys(): + # 先处理缺失依赖 + for plugin_name, missing in missing_dependencies.items(): + if not missing: + continue + missing_dep_names = ", ".join(sorted(missing)) + self.failed_plugins[plugin_name] = f"缺少依赖插件: {missing_dep_names}" + logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少依赖插件: {missing_dep_names}") + total_failed_registration += 1 + + # 再处理循环依赖 + for plugin_name in sorted(cycle_plugins): + if plugin_name in missing_dependencies and missing_dependencies[plugin_name]: + continue + self.failed_plugins[plugin_name] = "检测到循环依赖" + logger.error(f"❌ 插件加载失败: {plugin_name} - 检测到循环依赖") + total_failed_registration += 1 + + # 最后按拓扑序加载可加载插件 + for plugin_name in sorted_plugins: + if plugin_name in cycle_plugins: + continue + if plugin_name in missing_dependencies and missing_dependencies[plugin_name]: + continue load_status, count = self.load_registered_plugin_classes(plugin_name) if load_status: total_registered += 1 @@ -165,6 +192,7 @@ class PluginManager: for component in plugin_info.components: success &= await component_registry.remove_component(component.name, component.component_type, plugin_name) success &= component_registry.remove_plugin_registry(plugin_name) + plugin_service_registry.remove_services_by_plugin(plugin_name) del self.loaded_plugins[plugin_name] return success @@ -313,6 +341,89 @@ class PluginManager: self.failed_plugins[module_name] = error_msg return False + # == 依赖解析与加载顺序 == + + def _extract_declared_dependencies(self, plugin_name: str, plugin_class: Type[PluginBase]) -> Set[str]: + """提取插件声明的依赖。 + + 兼容声明格式: + - list[str] + - list[dict],其中dict至少包含name键 + """ + dependencies: Set[str] = set() + raw_dependencies = getattr(plugin_class, "dependencies", []) + + # 兼容错误声明 + if isinstance(raw_dependencies, property): + logger.warning(f"插件 {plugin_name} 的 dependencies 未声明为类属性,将按无依赖处理") + return dependencies + if not isinstance(raw_dependencies, list): + logger.warning(f"插件 {plugin_name} 的 dependencies 不是列表,将按无依赖处理") + return dependencies + + for dependency in raw_dependencies: + dependency_name = "" + if isinstance(dependency, str): + dependency_name = dependency.strip() + elif isinstance(dependency, dict): + dependency_name = str(dependency.get("name", "")).strip() + else: + logger.warning(f"插件 {plugin_name} 包含不支持的依赖声明类型: {type(dependency)}") + continue + + if not dependency_name: + continue + if dependency_name == plugin_name: + logger.warning(f"插件 {plugin_name} 声明了对自身的依赖,已忽略") + continue + + dependencies.add(dependency_name) + + return dependencies + + def _build_plugin_dependency_graph(self) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]: + """构建依赖图并返回缺失依赖映射。""" + plugin_names = set(self.plugin_classes.keys()) + dependency_graph: Dict[str, Set[str]] = {name: set() for name in plugin_names} + missing_dependencies: Dict[str, Set[str]] = {name: set() for name in plugin_names} + + for plugin_name, plugin_class in self.plugin_classes.items(): + declared_dependencies = self._extract_declared_dependencies(plugin_name, plugin_class) + for dependency_name in declared_dependencies: + if dependency_name in plugin_names: + dependency_graph[plugin_name].add(dependency_name) + else: + missing_dependencies[plugin_name].add(dependency_name) + + return dependency_graph, missing_dependencies + + def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]: + """根据依赖图计算加载顺序,并检测循环依赖。""" + indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()} + reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph} + + for plugin_name, dependencies in dependency_graph.items(): + for dependency_name in dependencies: + reverse_graph[dependency_name].add(plugin_name) + + zero_indegree_queue = deque(sorted([name for name, degree in indegree.items() if degree == 0])) + load_order: List[str] = [] + + while zero_indegree_queue: + current_plugin = zero_indegree_queue.popleft() + load_order.append(current_plugin) + + for dependent_plugin in sorted(reverse_graph[current_plugin]): + indegree[dependent_plugin] -= 1 + if indegree[dependent_plugin] == 0: + zero_indegree_queue.append(dependent_plugin) + + cycle_plugins = {name for name, degree in indegree.items() if degree > 0} + if cycle_plugins: + logger.error(f"检测到循环依赖插件: {', '.join(sorted(cycle_plugins))}") + + return load_order, cycle_plugins + # == 兼容性检查 == def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: diff --git a/src/plugin_system/core/plugin_service_registry.py b/src/plugin_system/core/plugin_service_registry.py new file mode 100644 index 00000000..1dbf1233 --- /dev/null +++ b/src/plugin_system/core/plugin_service_registry.py @@ -0,0 +1,140 @@ +from typing import Any, Callable, Dict, Optional +import inspect + +from src.common.logger import get_logger +from src.plugin_system.base.service_types import PluginServiceInfo + +logger = get_logger("plugin_service_registry") + + +class PluginServiceRegistry: + """插件服务注册中心""" + + def __init__(self): + self._services: Dict[str, PluginServiceInfo] = {} + self._service_handlers: Dict[str, Callable[..., Any]] = {} + logger.info("插件服务注册中心初始化完成") + + def register_service(self, service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool: + """注册插件服务。""" + if not service_info.name or not service_info.plugin_name: + logger.error("插件服务注册失败: service名称或插件名称为空") + return False + if "." in service_info.name: + logger.error(f"插件服务名称 '{service_info.name}' 包含非法字符 '.',请使用下划线替代") + return False + if "." in service_info.plugin_name: + logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") + return False + + full_name = service_info.full_name + if full_name in self._services: + logger.warning(f"插件服务已存在,拒绝重复注册: {full_name}") + return False + + self._services[full_name] = service_info + self._service_handlers[full_name] = service_handler + logger.debug(f"已注册插件服务: {full_name} (version={service_info.version})") + return True + + def get_service(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]: + """获取插件服务元信息。 + + service_name支持: + - full_name: plugin_name.service_name + - short_name: service_name(当唯一时可解析) + """ + full_name = self._resolve_full_name(service_name, plugin_name) + return self._services.get(full_name) if full_name else None + + def get_service_handler(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]: + """获取插件服务处理函数。""" + full_name = self._resolve_full_name(service_name, plugin_name) + return self._service_handlers.get(full_name) if full_name else None + + def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]: + """列出插件服务。""" + services = self._services.copy() + if plugin_name: + services = {name: info for name, info in services.items() if info.plugin_name == plugin_name} + if enabled_only: + services = {name: info for name, info in services.items() if info.enabled} + return services + + def enable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool: + """启用插件服务。""" + if not (service_info := self.get_service(service_name, plugin_name)): + logger.warning(f"插件服务未注册,无法启用: {service_name}") + return False + service_info.enabled = True + logger.info(f"插件服务已启用: {service_info.full_name}") + return True + + def disable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool: + """禁用插件服务。""" + if not (service_info := self.get_service(service_name, plugin_name)): + logger.warning(f"插件服务未注册,无法禁用: {service_name}") + return False + service_info.enabled = False + logger.info(f"插件服务已禁用: {service_info.full_name}") + return True + + def unregister_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool: + """注销单个插件服务。""" + full_name = self._resolve_full_name(service_name, plugin_name) + if not full_name: + logger.warning(f"插件服务未注册,无法注销: {service_name}") + return False + + self._services.pop(full_name, None) + self._service_handlers.pop(full_name, None) + logger.info(f"插件服务已注销: {full_name}") + return True + + def remove_services_by_plugin(self, plugin_name: str) -> int: + """移除某插件的所有注册服务。""" + target_names = [full_name for full_name, info in self._services.items() if info.plugin_name == plugin_name] + for full_name in target_names: + self._services.pop(full_name, None) + self._service_handlers.pop(full_name, None) + + removed_count = len(target_names) + if removed_count: + logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}") + return removed_count + + async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: + """调用插件服务(支持同步/异步handler)。""" + service_info = self.get_service(service_name, plugin_name) + if not service_info: + target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name + raise ValueError(f"插件服务未注册: {target_name}") + if not service_info.enabled: + raise RuntimeError(f"插件服务已禁用: {service_info.full_name}") + + handler = self.get_service_handler(service_name, plugin_name) + if not handler: + raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}") + + result = handler(*args, **kwargs) + return await result if inspect.isawaitable(result) else result + + def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]: + """解析服务全名。""" + if "." in service_name: + return service_name if service_name in self._services else None + + if plugin_name: + full_name = f"{plugin_name}.{service_name}" + return full_name if full_name in self._services else None + + candidates = [full_name for full_name, info in self._services.items() if info.name == service_name] + if len(candidates) == 1: + return candidates[0] + if len(candidates) > 1: + logger.warning(f"插件服务名称 '{service_name}' 存在多义,请传入plugin_name或使用完整服务名") + return None + return None + + +plugin_service_registry = PluginServiceRegistry() diff --git a/src/plugin_system/core/workflow_engine.py b/src/plugin_system/core/workflow_engine.py new file mode 100644 index 00000000..089d0498 --- /dev/null +++ b/src/plugin_system/core/workflow_engine.py @@ -0,0 +1,229 @@ +from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union +import time +import uuid +import inspect + +from src.common.logger import get_logger +from src.plugin_system.base.component_types import EventType, MaiMessages +from src.plugin_system.base.workflow_errors import WorkflowErrorCode +from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepResult + +logger = get_logger("workflow_engine") + + +class WorkflowEngine: + """线性Workflow执行器(MVP)""" + + STAGE_EVENT_SEQUENCE: List[Tuple[WorkflowStage, Union[EventType, str]]] = [ + (WorkflowStage.INGRESS, "workflow.ingress"), + (WorkflowStage.PRE_PROCESS, EventType.ON_MESSAGE_PRE_PROCESS), + (WorkflowStage.PLAN, EventType.ON_PLAN), + (WorkflowStage.TOOL_EXECUTE, "workflow.tool_execute"), + (WorkflowStage.POST_PROCESS, EventType.POST_SEND_PRE_PROCESS), + (WorkflowStage.EGRESS, "workflow.egress"), + ] + + def __init__(self): + self._execution_history: dict[str, dict[str, Any]] = {} + + async def execute_linear( + self, + dispatch_event: Callable[ + [Union[EventType, str], Optional[MaiMessages], Optional[str], Optional[List[str]]], + Awaitable[Tuple[bool, Optional[MaiMessages]]], + ], + message: Optional[MaiMessages] = None, + stream_id: Optional[str] = None, + action_usage: Optional[List[str]] = None, + context: Optional[WorkflowContext] = None, + ) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]: + """执行线性workflow。""" + workflow_context = context or WorkflowContext(trace_id=uuid.uuid4().hex, stream_id=stream_id) + current_message = message.deepcopy() if message else None + self._execution_history[workflow_context.trace_id] = { + "trace_id": workflow_context.trace_id, + "stream_id": workflow_context.stream_id, + "stages": [], + "errors": [], + "status": "running", + } + + for stage, event_type in self.STAGE_EVENT_SEQUENCE: + stage_key = str(stage) + stage_start = time.perf_counter() + try: + should_continue, modified_message = await dispatch_event( + event_type, + current_message, + workflow_context.stream_id, + action_usage, + ) + workflow_context.timings[stage_key] = time.perf_counter() - stage_start + self._execution_history[workflow_context.trace_id]["stages"].append( + { + "stage": stage_key, + "event_type": str(event_type), + "event_continue": should_continue, + "event_cost": workflow_context.timings[stage_key], + } + ) + + if modified_message: + current_message = modified_message + + if not should_continue: + logger.info(f"[trace_id={workflow_context.trace_id}] Workflow在阶段 {stage_key} 被中断") + return ( + WorkflowStepResult( + status="stop", + return_message=f"workflow stopped at stage {stage_key}", + diagnostics={ + "stage": stage_key, + "trace_id": workflow_context.trace_id, + "error_code": WorkflowErrorCode.POLICY_BLOCKED.value, + }, + ), + current_message, + workflow_context, + ) + + step_result = await self._execute_registered_steps(stage, workflow_context, current_message) + if step_result.status in ["stop", "failed"]: + self._execution_history[workflow_context.trace_id]["status"] = step_result.status + self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() + return step_result, current_message, workflow_context + except Exception as e: + workflow_context.timings[stage_key] = time.perf_counter() - stage_start + workflow_context.errors.append(f"{stage_key}: {e}") + logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True) + self._execution_history[workflow_context.trace_id]["status"] = "failed" + self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() + return ( + WorkflowStepResult( + status="failed", + return_message=str(e), + diagnostics={ + "stage": stage_key, + "trace_id": workflow_context.trace_id, + "error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value, + }, + ), + current_message, + workflow_context, + ) + + self._execution_history[workflow_context.trace_id]["status"] = "continue" + self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() + return ( + WorkflowStepResult( + status="continue", + return_message="workflow completed", + diagnostics={"trace_id": workflow_context.trace_id}, + ), + current_message, + workflow_context, + ) + + async def _execute_registered_steps( + self, + stage: WorkflowStage, + context: WorkflowContext, + message: Optional[MaiMessages], + ) -> WorkflowStepResult: + """执行指定阶段已注册的workflow步骤。""" + from src.plugin_system.core.component_registry import component_registry + + stage_steps = component_registry.get_steps_by_stage(stage, enabled_only=True) + sorted_steps = sorted(stage_steps.values(), key=lambda step_info: step_info.priority, reverse=True) + + for step_info in sorted_steps: + handler = component_registry.get_workflow_step_handler(step_info.full_name, stage) + if not handler: + context.errors.append(f"{step_info.full_name}: handler not found") + continue + + step_timing_key = f"{stage.value}:{step_info.full_name}" + step_start = time.perf_counter() + + try: + result = handler(context, message) + if inspect.isawaitable(result): + result = await result + context.timings[step_timing_key] = time.perf_counter() - step_start + + normalized_result = self._normalize_step_result(result) + if normalized_result.status == "continue": + continue + + normalized_result.diagnostics.setdefault("stage", stage.value) + normalized_result.diagnostics.setdefault("step", step_info.full_name) + normalized_result.diagnostics.setdefault("trace_id", context.trace_id) + if normalized_result.status == "failed": + context.errors.append( + f"{step_info.full_name}: {normalized_result.return_message or 'workflow step failed'}" + ) + normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value) + return normalized_result + + except Exception as e: + context.timings[step_timing_key] = time.perf_counter() - step_start + context.errors.append(f"{step_info.full_name}: {e}") + logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True) + return WorkflowStepResult( + status="failed", + return_message=str(e), + diagnostics={ + "stage": stage.value, + "step": step_info.full_name, + "trace_id": context.trace_id, + "error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value, + }, + ) + + return WorkflowStepResult(status="continue", diagnostics={"stage": stage.value, "trace_id": context.trace_id}) + + def _normalize_step_result(self, result: Any) -> WorkflowStepResult: + """归一化workflow step返回值。""" + if isinstance(result, WorkflowStepResult): + return result + if isinstance(result, bool): + if result: + return WorkflowStepResult(status="continue") + return WorkflowStepResult( + status="failed", + diagnostics={"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value}, + ) + if result is None: + return WorkflowStepResult(status="continue") + if isinstance(result, str): + return WorkflowStepResult(status="continue", return_message=result) + if isinstance(result, dict): + status = result.get("status", "continue") + if status not in ["continue", "stop", "failed"]: + status = "failed" + return WorkflowStepResult( + status=status, + return_message=result.get("return_message"), + diagnostics=result.get("diagnostics", {}), + events=result.get("events", []), + ) + return WorkflowStepResult( + status="failed", + return_message=f"unsupported step result type: {type(result)}", + diagnostics={"error_code": WorkflowErrorCode.BAD_PAYLOAD.value}, + ) + + def get_execution_trace(self, trace_id: str) -> Optional[dict[str, Any]]: + """按trace_id获取workflow执行路径。""" + trace = self._execution_history.get(trace_id) + return trace.copy() if trace else None + + def clear_execution_trace(self, trace_id: str) -> bool: + """清理trace执行记录。""" + if trace_id in self._execution_history: + del self._execution_history[trace_id] + return True + return False + + +workflow_engine = WorkflowEngine()