Merge branch 'r-dev' of github.com:Mai-with-u/MaiBot into r-dev

pull/1496/head
UnCLAS-Prommer 2026-02-21 15:55:57 +08:00
commit 6d196454ee
No known key found for this signature in database
16 changed files with 1036 additions and 19 deletions

View File

@ -29,10 +29,17 @@ from .base import (
MaiMessages, MaiMessages,
ToolParamType, ToolParamType,
CustomEventHandlerResult, CustomEventHandlerResult,
PluginServiceInfo,
ReplyContentType, ReplyContentType,
ReplyContent, ReplyContent,
ForwardNode, ForwardNode,
ReplySetModel, ReplySetModel,
WorkflowContext,
WorkflowMessage,
WorkflowStage,
WorkflowStepInfo,
WorkflowStepResult,
WorkflowErrorCode,
) )
# 导入工具模块 # 导入工具模块
@ -55,6 +62,8 @@ from .apis import (
message_api, message_api,
person_api, person_api,
plugin_manage_api, plugin_manage_api,
plugin_service_api,
workflow_api,
send_api, send_api,
register_plugin, register_plugin,
get_logger, get_logger,
@ -85,6 +94,8 @@ __all__ = [
"message_api", "message_api",
"person_api", "person_api",
"plugin_manage_api", "plugin_manage_api",
"plugin_service_api",
"workflow_api",
"send_api", "send_api",
"auto_talk_api", "auto_talk_api",
"register_plugin", "register_plugin",
@ -115,6 +126,13 @@ __all__ = [
"ReplySetModel", "ReplySetModel",
"MaiMessages", "MaiMessages",
"CustomEventHandlerResult", "CustomEventHandlerResult",
"PluginServiceInfo",
"WorkflowContext",
"WorkflowMessage",
"WorkflowStage",
"WorkflowStepInfo",
"WorkflowStepResult",
"WorkflowErrorCode",
# 装饰器 # 装饰器
"register_plugin", "register_plugin",
"ConfigField", "ConfigField",

View File

@ -16,9 +16,11 @@ from src.plugin_system.apis import (
message_api, message_api,
person_api, person_api,
plugin_manage_api, plugin_manage_api,
plugin_service_api,
send_api, send_api,
tool_api, tool_api,
frequency_api, frequency_api,
workflow_api,
) )
from .logging_api import get_logger from .logging_api import get_logger
from .plugin_register_api import register_plugin from .plugin_register_api import register_plugin
@ -35,9 +37,11 @@ __all__ = [
"message_api", "message_api",
"person_api", "person_api",
"plugin_manage_api", "plugin_manage_api",
"plugin_service_api",
"send_api", "send_api",
"get_logger", "get_logger",
"register_plugin", "register_plugin",
"tool_api", "tool_api",
"frequency_api", "frequency_api",
"workflow_api",
] ]

View File

@ -0,0 +1,44 @@
from typing import Any, Callable, Dict, Optional
from src.plugin_system.base.service_types import PluginServiceInfo
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
def register_service(service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
"""注册插件服务。"""
return plugin_service_registry.register_service(service_info, service_handler)
def get_service(service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
"""获取插件服务元信息。"""
return plugin_service_registry.get_service(service_name, plugin_name)
def get_service_handler(service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
"""获取插件服务处理函数。"""
return plugin_service_registry.get_service_handler(service_name, plugin_name)
def list_services(plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
return plugin_service_registry.list_services(plugin_name=plugin_name, enabled_only=enabled_only)
def enable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""启用插件服务。"""
return plugin_service_registry.enable_service(service_name, plugin_name)
def disable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""禁用插件服务。"""
return plugin_service_registry.disable_service(service_name, plugin_name)
def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""注销插件服务。"""
return plugin_service_registry.unregister_service(service_name, plugin_name)
async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
"""调用插件服务。"""
return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs)

View File

@ -0,0 +1,77 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from src.plugin_system.base.component_types import EventType, MaiMessages
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.workflow_engine import workflow_engine
def register_workflow_step(step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
"""注册workflow step。"""
return component_registry.register_workflow_step(step_info, step_handler)
def get_steps_by_stage(stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
"""获取指定阶段的workflow steps。"""
return component_registry.get_steps_by_stage(stage, enabled_only=enabled_only)
def get_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
"""获取workflow step元信息。"""
return component_registry.get_workflow_step(step_name, stage)
def get_workflow_step_handler(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[Callable[..., Any]]:
"""获取workflow step处理函数。"""
return component_registry.get_workflow_step_handler(step_name, stage)
def enable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""启用workflow step。"""
return component_registry.enable_workflow_step(step_name, stage)
def disable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""禁用workflow step。"""
return component_registry.disable_workflow_step(step_name, stage)
def get_execution_trace(trace_id: str) -> Optional[Dict[str, Any]]:
"""按trace_id获取workflow执行路径。"""
return workflow_engine.get_execution_trace(trace_id)
def clear_execution_trace(trace_id: str) -> bool:
"""清理trace执行路径记录。"""
return workflow_engine.clear_execution_trace(trace_id)
async def execute_workflow_message(
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行workflow消息流转。"""
return await events_manager.handle_workflow_message(
message=message,
stream_id=stream_id,
action_usage=action_usage,
context=context,
)
async def publish_event(
event_type: Union[EventType, str],
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Tuple[bool, Optional[MaiMessages]]:
"""发布事件(支持系统事件和自定义字符串事件)。"""
return await events_manager.handle_mai_events(
event_type=event_type,
message=message,
stream_id=stream_id,
action_usage=action_usage,
)

View File

@ -9,6 +9,9 @@ from .base_action import BaseAction
from .base_tool import BaseTool from .base_tool import BaseTool
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler from .base_events_handler import BaseEventHandler
from .service_types import PluginServiceInfo
from .workflow_types import WorkflowContext, WorkflowMessage, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
from .workflow_errors import WorkflowErrorCode
from .component_types import ( from .component_types import (
ComponentType, ComponentType,
ActionActivationType, ActionActivationType,
@ -59,4 +62,11 @@ __all__ = [
"ReplyContent", "ReplyContent",
"ForwardNode", "ForwardNode",
"ReplySetModel", "ReplySetModel",
"PluginServiceInfo",
"WorkflowContext",
"WorkflowMessage",
"WorkflowStage",
"WorkflowStepInfo",
"WorkflowStepResult",
"WorkflowErrorCode",
] ]

View File

@ -1,9 +1,10 @@
from abc import abstractmethod from abc import abstractmethod
from typing import List, Type, Tuple, Union from typing import Any, Callable, List, Type, Tuple, Union
from .plugin_base import PluginBase from .plugin_base import PluginBase
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
from src.plugin_system.base.workflow_types import WorkflowStepInfo
from .base_action import BaseAction from .base_action import BaseAction
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler from .base_events_handler import BaseEventHandler
@ -44,6 +45,13 @@ class BasePlugin(PluginBase):
""" """
raise NotImplementedError("Subclasses must implement this method") raise NotImplementedError("Subclasses must implement this method")
def get_workflow_steps(self) -> List[Tuple[WorkflowStepInfo, Callable[..., Any]]]:
"""获取插件包含的workflow steps。
默认返回空列表子类可按需覆盖
"""
return []
def register_plugin(self) -> bool: def register_plugin(self) -> bool:
"""注册插件及其所有组件""" """注册插件及其所有组件"""
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
@ -69,7 +77,25 @@ class BasePlugin(PluginBase):
# 注册插件 # 注册插件
if component_registry.register_plugin(self.plugin_info): if component_registry.register_plugin(self.plugin_info):
# 注册workflow steps可选
registered_step_count = 0
for step_info, step_handler in self.get_workflow_steps():
if not step_info.plugin_name:
step_info.plugin_name = self.plugin_name
elif step_info.plugin_name != self.plugin_name:
logger.warning(
f"{self.log_prefix} workflow step {step_info.name} 的plugin_name({step_info.plugin_name})与当前插件不一致,已覆盖为 {self.plugin_name}"
)
step_info.plugin_name = self.plugin_name
if component_registry.register_workflow_step(step_info, step_handler):
registered_step_count += 1
else:
logger.warning(f"{self.log_prefix} workflow step {step_info.full_name} 注册失败")
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件") logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
if registered_step_count > 0:
logger.debug(f"{self.log_prefix} workflow steps 注册成功,数量: {registered_step_count}")
return True return True
else: else:
logger.error(f"{self.log_prefix} 插件注册失败") logger.error(f"{self.log_prefix} 插件注册失败")

View File

@ -6,6 +6,7 @@ import toml
import json import json
import shutil import shutil
import datetime import datetime
import re
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
@ -17,7 +18,7 @@ from src.plugin_system.base.config_types import (
ConfigSection, ConfigSection,
ConfigLayout, ConfigLayout,
) )
from src.plugin_system.utils.manifest_utils import ManifestValidator from src.plugin_system.utils.manifest_utils import ManifestValidator, VersionComparator
logger = get_logger("plugin_base") logger = get_logger("plugin_base")
@ -41,7 +42,7 @@ class PluginBase(ABC):
@property @property
@abstractmethod @abstractmethod
def dependencies(self) -> List[str]: def dependencies(self) -> List[Union[str, Dict[str, Any]]]:
return [] # 依赖的其他插件 return [] # 依赖的其他插件
@property @property
@ -104,7 +105,7 @@ class PluginBase(ABC):
enabled=self.enable_plugin, enabled=self.enable_plugin,
is_built_in=False, is_built_in=False,
config_file=self.config_file_name or "", config_file=self.config_file_name or "",
dependencies=self.dependencies.copy(), dependencies=self._get_dependency_names(),
python_dependencies=self.python_dependencies.copy(), python_dependencies=self.python_dependencies.copy(),
# manifest相关信息 # manifest相关信息
manifest_data=self.manifest_data.copy(), manifest_data=self.manifest_data.copy(),
@ -541,13 +542,101 @@ class PluginBase(ABC):
if not self.dependencies: if not self.dependencies:
return True return True
for dep in self.dependencies: for dependency in self.dependencies:
if not component_registry.get_plugin_info(dep): dep_name, version_spec, min_version, max_version = self._parse_dependency(dependency)
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep}") if not dep_name:
logger.warning(f"{self.log_prefix} 跳过无效依赖声明: {dependency}")
continue
dep_plugin_info = component_registry.get_plugin_info(dep_name)
if not dep_plugin_info:
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep_name}")
return False
dep_version = dep_plugin_info.version or "0.0.0"
if version_spec:
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
if not is_ok:
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
return False
if min_version or max_version:
is_ok, msg = VersionComparator.is_version_in_range(dep_version, min_version, max_version)
if not is_ok:
logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} 要求区间[{min_version or '-inf'}, {max_version or '+inf'}], 当前版本={dep_version} ({msg})"
)
return False return False
return True return True
def _get_dependency_names(self) -> List[str]:
"""获取依赖插件名称列表(用于插件信息展示和统计)。"""
dependency_names: List[str] = []
for dependency in self.dependencies:
dep_name, _, _, _ = self._parse_dependency(dependency)
if dep_name:
dependency_names.append(dep_name)
return dependency_names
def _parse_dependency(self, dependency: Any) -> tuple[str, str, str, str]:
"""解析依赖声明。
支持格式
- "plugin_a"
- {"name": "plugin_a", "version": ">=1.2.0,<2.0.0"}
- {"name": "plugin_a", "min_version": "1.2.0", "max_version": "2.0.0"}
"""
if isinstance(dependency, str):
return dependency.strip(), "", "", ""
if isinstance(dependency, dict):
dep_name = str(dependency.get("name", "")).strip()
version_spec = str(dependency.get("version", "")).strip()
min_version = str(dependency.get("min_version", "")).strip()
max_version = str(dependency.get("max_version", "")).strip()
return dep_name, version_spec, min_version, max_version
return "", "", "", ""
def _is_version_spec_satisfied(self, version: str, version_spec: str) -> tuple[bool, str]:
"""检查版本是否满足表达式。
支持==, >=, <=, >, <可用逗号分隔多个条件
示例">=1.2.0,<2.0.0"
"""
normalized_version = VersionComparator.normalize_version(version)
clauses = [clause.strip() for clause in version_spec.split(",") if clause.strip()]
if not clauses:
return True, ""
operators_pattern = r"^(==|>=|<=|>|<)\s*(.+)$"
for clause in clauses:
if not (match := re.match(operators_pattern, clause)):
return False, f"无效版本约束表达式: {clause}"
operator, target_version = match.group(1), VersionComparator.normalize_version(match.group(2))
compare_result = VersionComparator.compare_versions(normalized_version, target_version)
is_satisfied = False
if operator == "==":
is_satisfied = compare_result == 0
elif operator == ">=":
is_satisfied = compare_result >= 0
elif operator == "<=":
is_satisfied = compare_result <= 0
elif operator == ">":
is_satisfied = compare_result > 0
elif operator == "<":
is_satisfied = compare_result < 0
if not is_satisfied:
return False, f"{normalized_version} 不满足约束 {operator}{target_version}"
return True, ""
def get_config(self, key: str, default: Any = None) -> Any: def get_config(self, key: str, default: Any = None) -> Any:
"""获取插件配置值,支持嵌套键访问 """获取插件配置值,支持嵌套键访问

View File

@ -0,0 +1,20 @@
from dataclasses import dataclass, field
from typing import Any, Dict
@dataclass
class PluginServiceInfo:
"""插件服务注册信息"""
name: str
plugin_name: str
version: str = "1.0.0"
description: str = ""
enabled: bool = True
params_schema: Dict[str, Any] = field(default_factory=dict)
return_schema: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def full_name(self) -> str:
return f"{self.plugin_name}.{self.name}"

View File

@ -0,0 +1,14 @@
from enum import Enum
class WorkflowErrorCode(Enum):
"""Workflow统一错误码"""
PLUGIN_NOT_READY = "PLUGIN_NOT_READY"
STEP_TIMEOUT = "STEP_TIMEOUT"
BAD_PAYLOAD = "BAD_PAYLOAD"
DOWNSTREAM_FAILED = "DOWNSTREAM_FAILED"
POLICY_BLOCKED = "POLICY_BLOCKED"
def __str__(self) -> str:
return self.value

View File

@ -0,0 +1,68 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
import time
class WorkflowStage(Enum):
"""Workflow阶段定义MVP固定阶段"""
INGRESS = "ingress"
PRE_PROCESS = "pre_process"
PLAN = "plan"
TOOL_EXECUTE = "tool_execute"
POST_PROCESS = "post_process"
EGRESS = "egress"
def __str__(self) -> str:
return self.value
@dataclass
class WorkflowContext:
"""Workflow上下文"""
trace_id: str
stream_id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
timings: Dict[str, float] = field(default_factory=dict)
errors: List[str] = field(default_factory=list)
@dataclass
class WorkflowMessage:
"""Workflow消息包装"""
msg_type: str
payload: Dict[str, Any] = field(default_factory=dict)
headers: Dict[str, Any] = field(default_factory=dict)
mutable_flags: Dict[str, bool] = field(default_factory=dict)
@dataclass
class WorkflowStepResult:
"""Workflow步骤结果"""
status: Literal["continue", "stop", "failed"] = "continue"
return_message: Optional[str] = None
diagnostics: Dict[str, Any] = field(default_factory=dict)
events: List[Dict[str, Any]] = field(default_factory=list)
created_at: float = field(default_factory=time.time)
@dataclass
class WorkflowStepInfo:
"""Workflow步骤元数据"""
name: str
stage: WorkflowStage
plugin_name: str
description: str = ""
enabled: bool = True
priority: int = 0
timeout_ms: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def full_name(self) -> str:
return f"{self.plugin_name}.{self.name}"

View File

@ -8,10 +8,14 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
from src.plugin_system.core.workflow_engine import workflow_engine
__all__ = [ __all__ = [
"plugin_manager", "plugin_manager",
"component_registry", "component_registry",
"events_manager", "events_manager",
"global_announcement_manager", "global_announcement_manager",
"plugin_service_registry",
"workflow_engine",
] ]

View File

@ -1,6 +1,6 @@
import re import re
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type from typing import Callable, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
@ -16,6 +16,7 @@ from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.workflow_types import WorkflowStage, WorkflowStepInfo
logger = get_logger("component_registry") logger = get_logger("component_registry")
@ -61,6 +62,10 @@ class ComponentRegistry:
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
"""启用的事件处理器 event_handler名 -> event_handler类""" """启用的事件处理器 event_handler名 -> event_handler类"""
# Workflow step注册表
self._workflow_steps: Dict[WorkflowStage, Dict[str, WorkflowStepInfo]] = {stage: {} for stage in WorkflowStage}
self._workflow_step_handlers: Dict[str, Callable[..., Any]] = {}
logger.info("组件注册中心初始化完成") logger.info("组件注册中心初始化完成")
# == 注册方法 == # == 注册方法 ==
@ -282,9 +287,116 @@ class ComponentRegistry:
logger.warning(f"插件 {plugin_name} 未注册,无法移除") logger.warning(f"插件 {plugin_name} 未注册,无法移除")
return False return False
del self._plugins[plugin_name] del self._plugins[plugin_name]
self.remove_workflow_steps_by_plugin(plugin_name)
logger.info(f"插件 {plugin_name} 已移除") logger.info(f"插件 {plugin_name} 已移除")
return True return True
# === Workflow step 注册与查询 ===
def register_workflow_step(self, step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
"""注册workflow步骤。"""
if not step_info.name or not step_info.plugin_name:
logger.error("workflow step 注册失败: step名称或插件名称为空")
return False
if "." in step_info.name:
logger.error(f"workflow step 名称 '{step_info.name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in step_info.plugin_name:
logger.error(f"workflow step 所属插件名称 '{step_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
full_name = step_info.full_name
stage_registry = self._workflow_steps.get(step_info.stage)
if stage_registry is None:
logger.error(f"workflow step 注册失败: 未知阶段 {step_info.stage}")
return False
if full_name in stage_registry:
logger.warning(f"workflow step 已存在,跳过注册: {full_name}")
return False
stage_registry[full_name] = step_info
self._workflow_step_handlers[full_name] = step_handler
logger.debug(f"已注册workflow step: {full_name} @ {step_info.stage}")
return True
def get_steps_by_stage(self, stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
"""获取某阶段的workflow步骤。"""
steps = self._workflow_steps.get(stage, {})
if enabled_only:
return {name: info for name, info in steps.items() if info.enabled}
return steps.copy()
def get_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
"""获取workflow step信息。
step_name支持两种
- full_name: plugin_name.step_name
- short_name: step_name若有冲突返回第一个并告警
"""
candidates: List[WorkflowStepInfo] = []
target_stages = [stage] if stage else list(WorkflowStage)
for current_stage in target_stages:
current_steps = self._workflow_steps.get(current_stage, {})
if "." in step_name:
if step_info := current_steps.get(step_name):
return step_info
continue
candidates.extend([step_info for step_info in current_steps.values() if step_info.name == step_name])
if len(candidates) == 1:
return candidates[0]
if len(candidates) > 1:
logger.warning(f"workflow step 名称 '{step_name}' 存在多义,使用第一个匹配: {candidates[0].full_name}")
return candidates[0]
return None
def get_workflow_step_handler(
self, step_name: str, stage: Optional[WorkflowStage] = None
) -> Optional[Callable[..., Any]]:
"""获取workflow step处理函数。"""
if "." in step_name:
return self._workflow_step_handlers.get(step_name)
if step_info := self.get_workflow_step(step_name, stage):
return self._workflow_step_handlers.get(step_info.full_name)
return None
def enable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""启用workflow step。"""
step_info = self.get_workflow_step(step_name, stage)
if not step_info:
logger.warning(f"workflow step 未注册,无法启用: {step_name}")
return False
step_info.enabled = True
logger.info(f"workflow step 已启用: {step_info.full_name}")
return True
def disable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""禁用workflow step。"""
step_info = self.get_workflow_step(step_name, stage)
if not step_info:
logger.warning(f"workflow step 未注册,无法禁用: {step_name}")
return False
step_info.enabled = False
logger.info(f"workflow step 已禁用: {step_info.full_name}")
return True
def remove_workflow_steps_by_plugin(self, plugin_name: str) -> int:
"""移除某插件注册的所有workflow step。"""
removed_count = 0
for stage in WorkflowStage:
stage_registry = self._workflow_steps.get(stage, {})
target_names = [name for name, info in stage_registry.items() if info.plugin_name == plugin_name]
for full_name in target_names:
stage_registry.pop(full_name, None)
self._workflow_step_handlers.pop(full_name, None)
removed_count += 1
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的 workflow step 数量: {removed_count}")
return removed_count
# === 组件全局启用/禁用方法 === # === 组件全局启用/禁用方法 ===
def enable_component(self, component_name: str, component_type: ComponentType) -> bool: def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
@ -602,6 +714,12 @@ class ComponentRegistry:
tool_components += 1 tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER: elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1 events_handlers += 1
workflow_step_count = sum(len(steps) for steps in self._workflow_steps.values())
enabled_workflow_step_count = sum(
len([step for step in steps.values() if step.enabled]) for steps in self._workflow_steps.values()
)
return { return {
"action_components": action_components, "action_components": action_components,
"command_components": command_components, "command_components": command_components,
@ -614,6 +732,11 @@ class ComponentRegistry:
}, },
"enabled_components": len([c for c in self._components.values() if c.enabled]), "enabled_components": len([c for c in self._components.values() if c.enabled]),
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
"workflow_steps": workflow_step_count,
"enabled_workflow_steps": enabled_workflow_step_count,
"workflow_steps_by_stage": {
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
},
} }

View File

@ -7,7 +7,9 @@ from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStepResult
from .global_announcement_manager import global_announcement_manager from .global_announcement_manager import global_announcement_manager
from .workflow_engine import workflow_engine
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.common.data_models.llm_data_model import LLMGenerationDataModel
@ -57,6 +59,10 @@ class EventsManager:
return False return False
if handler_info.event_type not in self._history_enable_map: if handler_info.event_type not in self._history_enable_map:
if isinstance(handler_info.event_type, str):
self.register_event(handler_info.event_type, enable_history_result=False)
logger.info(f"自动注册自定义事件类型: {handler_info.event_type}")
else:
logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}") logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}")
return False return False
@ -65,8 +71,8 @@ class EventsManager:
async def handle_mai_events( async def handle_mai_events(
self, self,
event_type: EventType, event_type: EventType | str,
message: Optional[MessageRecv | MessageSending] = None, message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None, llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None, llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,
@ -121,6 +127,37 @@ class EventsManager:
return continue_flag, modified_message return continue_flag, modified_message
async def handle_workflow_message(
self,
message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行线性workflow消息流转MVP兼容入口"""
initial_message = self._prepare_message(EventType.ON_MESSAGE_PRE_PROCESS, message=message, stream_id=stream_id)
async def _dispatch(
event_type: EventType | str,
workflow_message: Optional[MaiMessages],
workflow_stream_id: Optional[str],
workflow_action_usage: Optional[List[str]],
) -> Tuple[bool, Optional[MaiMessages]]:
return await self.handle_mai_events(
event_type=event_type,
message=workflow_message,
stream_id=workflow_stream_id,
action_usage=workflow_action_usage,
)
return await workflow_engine.execute_linear(
dispatch_event=_dispatch,
message=initial_message,
stream_id=stream_id,
action_usage=action_usage,
context=context,
)
async def cancel_handler_tasks(self, handler_name: str) -> None: async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, []) tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]: if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
@ -294,14 +331,17 @@ class EventsManager:
def _prepare_message( def _prepare_message(
self, self,
event_type: EventType, event_type: EventType | str,
message: Optional[MessageRecv | MessageSending] = None, message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None, llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None, llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None, action_usage: Optional[List[str]] = None,
) -> Optional[MaiMessages]: ) -> Optional[MaiMessages]:
"""根据事件类型和输入,准备和转换消息对象。""" """根据事件类型和输入,准备和转换消息对象。"""
if isinstance(message, MaiMessages):
return message.deepcopy()
if message: if message:
return self._transform_event_message(message, llm_prompt, llm_response) return self._transform_event_message(message, llm_prompt, llm_response)

View File

@ -1,9 +1,9 @@
import os
import traceback
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Type
from collections import deque
import os
import traceback
from src.common.logger import get_logger from src.common.logger import get_logger
@ -11,6 +11,7 @@ from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.utils.manifest_utils import VersionComparator from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry from .component_registry import component_registry
from .plugin_service_registry import plugin_service_registry
logger = get_logger("plugin_manager") logger = get_logger("plugin_manager")
@ -70,10 +71,36 @@ class PluginManager:
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}") logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
# 第二阶段前:根据依赖关系决定插件加载顺序
dependency_graph, missing_dependencies = self._build_plugin_dependency_graph()
sorted_plugins, cycle_plugins = self._resolve_plugin_load_order(dependency_graph)
total_registered = 0 total_registered = 0
total_failed_registration = 0 total_failed_registration = 0
for plugin_name in self.plugin_classes.keys(): # 先处理缺失依赖
for plugin_name, missing in missing_dependencies.items():
if not missing:
continue
missing_dep_names = ", ".join(sorted(missing))
self.failed_plugins[plugin_name] = f"缺少依赖插件: {missing_dep_names}"
logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少依赖插件: {missing_dep_names}")
total_failed_registration += 1
# 再处理循环依赖
for plugin_name in sorted(cycle_plugins):
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
continue
self.failed_plugins[plugin_name] = "检测到循环依赖"
logger.error(f"❌ 插件加载失败: {plugin_name} - 检测到循环依赖")
total_failed_registration += 1
# 最后按拓扑序加载可加载插件
for plugin_name in sorted_plugins:
if plugin_name in cycle_plugins:
continue
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
continue
load_status, count = self.load_registered_plugin_classes(plugin_name) load_status, count = self.load_registered_plugin_classes(plugin_name)
if load_status: if load_status:
total_registered += 1 total_registered += 1
@ -165,6 +192,7 @@ class PluginManager:
for component in plugin_info.components: for component in plugin_info.components:
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name) success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
success &= component_registry.remove_plugin_registry(plugin_name) success &= component_registry.remove_plugin_registry(plugin_name)
plugin_service_registry.remove_services_by_plugin(plugin_name)
del self.loaded_plugins[plugin_name] del self.loaded_plugins[plugin_name]
return success return success
@ -313,6 +341,89 @@ class PluginManager:
self.failed_plugins[module_name] = error_msg self.failed_plugins[module_name] = error_msg
return False return False
# == 依赖解析与加载顺序 ==
def _extract_declared_dependencies(self, plugin_name: str, plugin_class: Type[PluginBase]) -> Set[str]:
"""提取插件声明的依赖。
兼容声明格式
- list[str]
- list[dict]其中dict至少包含name键
"""
dependencies: Set[str] = set()
raw_dependencies = getattr(plugin_class, "dependencies", [])
# 兼容错误声明
if isinstance(raw_dependencies, property):
logger.warning(f"插件 {plugin_name} 的 dependencies 未声明为类属性,将按无依赖处理")
return dependencies
if not isinstance(raw_dependencies, list):
logger.warning(f"插件 {plugin_name} 的 dependencies 不是列表,将按无依赖处理")
return dependencies
for dependency in raw_dependencies:
dependency_name = ""
if isinstance(dependency, str):
dependency_name = dependency.strip()
elif isinstance(dependency, dict):
dependency_name = str(dependency.get("name", "")).strip()
else:
logger.warning(f"插件 {plugin_name} 包含不支持的依赖声明类型: {type(dependency)}")
continue
if not dependency_name:
continue
if dependency_name == plugin_name:
logger.warning(f"插件 {plugin_name} 声明了对自身的依赖,已忽略")
continue
dependencies.add(dependency_name)
return dependencies
def _build_plugin_dependency_graph(self) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
"""构建依赖图并返回缺失依赖映射。"""
plugin_names = set(self.plugin_classes.keys())
dependency_graph: Dict[str, Set[str]] = {name: set() for name in plugin_names}
missing_dependencies: Dict[str, Set[str]] = {name: set() for name in plugin_names}
for plugin_name, plugin_class in self.plugin_classes.items():
declared_dependencies = self._extract_declared_dependencies(plugin_name, plugin_class)
for dependency_name in declared_dependencies:
if dependency_name in plugin_names:
dependency_graph[plugin_name].add(dependency_name)
else:
missing_dependencies[plugin_name].add(dependency_name)
return dependency_graph, missing_dependencies
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
"""根据依赖图计算加载顺序,并检测循环依赖。"""
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
for plugin_name, dependencies in dependency_graph.items():
for dependency_name in dependencies:
reverse_graph[dependency_name].add(plugin_name)
zero_indegree_queue = deque(sorted([name for name, degree in indegree.items() if degree == 0]))
load_order: List[str] = []
while zero_indegree_queue:
current_plugin = zero_indegree_queue.popleft()
load_order.append(current_plugin)
for dependent_plugin in sorted(reverse_graph[current_plugin]):
indegree[dependent_plugin] -= 1
if indegree[dependent_plugin] == 0:
zero_indegree_queue.append(dependent_plugin)
cycle_plugins = {name for name, degree in indegree.items() if degree > 0}
if cycle_plugins:
logger.error(f"检测到循环依赖插件: {', '.join(sorted(cycle_plugins))}")
return load_order, cycle_plugins
# == 兼容性检查 == # == 兼容性检查 ==
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:

View File

@ -0,0 +1,140 @@
from typing import Any, Callable, Dict, Optional
import inspect
from src.common.logger import get_logger
from src.plugin_system.base.service_types import PluginServiceInfo
logger = get_logger("plugin_service_registry")
class PluginServiceRegistry:
"""插件服务注册中心"""
def __init__(self):
self._services: Dict[str, PluginServiceInfo] = {}
self._service_handlers: Dict[str, Callable[..., Any]] = {}
logger.info("插件服务注册中心初始化完成")
def register_service(self, service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
"""注册插件服务。"""
if not service_info.name or not service_info.plugin_name:
logger.error("插件服务注册失败: service名称或插件名称为空")
return False
if "." in service_info.name:
logger.error(f"插件服务名称 '{service_info.name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in service_info.plugin_name:
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
full_name = service_info.full_name
if full_name in self._services:
logger.warning(f"插件服务已存在,拒绝重复注册: {full_name}")
return False
self._services[full_name] = service_info
self._service_handlers[full_name] = service_handler
logger.debug(f"已注册插件服务: {full_name} (version={service_info.version})")
return True
def get_service(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
"""获取插件服务元信息。
service_name支持
- full_name: plugin_name.service_name
- short_name: service_name当唯一时可解析
"""
full_name = self._resolve_full_name(service_name, plugin_name)
return self._services.get(full_name) if full_name else None
def get_service_handler(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
"""获取插件服务处理函数。"""
full_name = self._resolve_full_name(service_name, plugin_name)
return self._service_handlers.get(full_name) if full_name else None
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
services = self._services.copy()
if plugin_name:
services = {name: info for name, info in services.items() if info.plugin_name == plugin_name}
if enabled_only:
services = {name: info for name, info in services.items() if info.enabled}
return services
def enable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""启用插件服务。"""
if not (service_info := self.get_service(service_name, plugin_name)):
logger.warning(f"插件服务未注册,无法启用: {service_name}")
return False
service_info.enabled = True
logger.info(f"插件服务已启用: {service_info.full_name}")
return True
def disable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""禁用插件服务。"""
if not (service_info := self.get_service(service_name, plugin_name)):
logger.warning(f"插件服务未注册,无法禁用: {service_name}")
return False
service_info.enabled = False
logger.info(f"插件服务已禁用: {service_info.full_name}")
return True
def unregister_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""注销单个插件服务。"""
full_name = self._resolve_full_name(service_name, plugin_name)
if not full_name:
logger.warning(f"插件服务未注册,无法注销: {service_name}")
return False
self._services.pop(full_name, None)
self._service_handlers.pop(full_name, None)
logger.info(f"插件服务已注销: {full_name}")
return True
def remove_services_by_plugin(self, plugin_name: str) -> int:
"""移除某插件的所有注册服务。"""
target_names = [full_name for full_name, info in self._services.items() if info.plugin_name == plugin_name]
for full_name in target_names:
self._services.pop(full_name, None)
self._service_handlers.pop(full_name, None)
removed_count = len(target_names)
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
return removed_count
async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
"""调用插件服务(支持同步/异步handler"""
service_info = self.get_service(service_name, plugin_name)
if not service_info:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}")
if not service_info.enabled:
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
handler = self.get_service_handler(service_name, plugin_name)
if not handler:
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
result = handler(*args, **kwargs)
return await result if inspect.isawaitable(result) else result
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
"""解析服务全名。"""
if "." in service_name:
return service_name if service_name in self._services else None
if plugin_name:
full_name = f"{plugin_name}.{service_name}"
return full_name if full_name in self._services else None
candidates = [full_name for full_name, info in self._services.items() if info.name == service_name]
if len(candidates) == 1:
return candidates[0]
if len(candidates) > 1:
logger.warning(f"插件服务名称 '{service_name}' 存在多义请传入plugin_name或使用完整服务名")
return None
return None
plugin_service_registry = PluginServiceRegistry()

View File

@ -0,0 +1,229 @@
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import time
import uuid
import inspect
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, MaiMessages
from src.plugin_system.base.workflow_errors import WorkflowErrorCode
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepResult
logger = get_logger("workflow_engine")
class WorkflowEngine:
"""线性Workflow执行器MVP"""
STAGE_EVENT_SEQUENCE: List[Tuple[WorkflowStage, Union[EventType, str]]] = [
(WorkflowStage.INGRESS, "workflow.ingress"),
(WorkflowStage.PRE_PROCESS, EventType.ON_MESSAGE_PRE_PROCESS),
(WorkflowStage.PLAN, EventType.ON_PLAN),
(WorkflowStage.TOOL_EXECUTE, "workflow.tool_execute"),
(WorkflowStage.POST_PROCESS, EventType.POST_SEND_PRE_PROCESS),
(WorkflowStage.EGRESS, "workflow.egress"),
]
def __init__(self):
self._execution_history: dict[str, dict[str, Any]] = {}
async def execute_linear(
self,
dispatch_event: Callable[
[Union[EventType, str], Optional[MaiMessages], Optional[str], Optional[List[str]]],
Awaitable[Tuple[bool, Optional[MaiMessages]]],
],
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行线性workflow。"""
workflow_context = context or WorkflowContext(trace_id=uuid.uuid4().hex, stream_id=stream_id)
current_message = message.deepcopy() if message else None
self._execution_history[workflow_context.trace_id] = {
"trace_id": workflow_context.trace_id,
"stream_id": workflow_context.stream_id,
"stages": [],
"errors": [],
"status": "running",
}
for stage, event_type in self.STAGE_EVENT_SEQUENCE:
stage_key = str(stage)
stage_start = time.perf_counter()
try:
should_continue, modified_message = await dispatch_event(
event_type,
current_message,
workflow_context.stream_id,
action_usage,
)
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
self._execution_history[workflow_context.trace_id]["stages"].append(
{
"stage": stage_key,
"event_type": str(event_type),
"event_continue": should_continue,
"event_cost": workflow_context.timings[stage_key],
}
)
if modified_message:
current_message = modified_message
if not should_continue:
logger.info(f"[trace_id={workflow_context.trace_id}] Workflow在阶段 {stage_key} 被中断")
return (
WorkflowStepResult(
status="stop",
return_message=f"workflow stopped at stage {stage_key}",
diagnostics={
"stage": stage_key,
"trace_id": workflow_context.trace_id,
"error_code": WorkflowErrorCode.POLICY_BLOCKED.value,
},
),
current_message,
workflow_context,
)
step_result = await self._execute_registered_steps(stage, workflow_context, current_message)
if step_result.status in ["stop", "failed"]:
self._execution_history[workflow_context.trace_id]["status"] = step_result.status
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return step_result, current_message, workflow_context
except Exception as e:
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
workflow_context.errors.append(f"{stage_key}: {e}")
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
self._execution_history[workflow_context.trace_id]["status"] = "failed"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
WorkflowStepResult(
status="failed",
return_message=str(e),
diagnostics={
"stage": stage_key,
"trace_id": workflow_context.trace_id,
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
},
),
current_message,
workflow_context,
)
self._execution_history[workflow_context.trace_id]["status"] = "continue"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
WorkflowStepResult(
status="continue",
return_message="workflow completed",
diagnostics={"trace_id": workflow_context.trace_id},
),
current_message,
workflow_context,
)
async def _execute_registered_steps(
self,
stage: WorkflowStage,
context: WorkflowContext,
message: Optional[MaiMessages],
) -> WorkflowStepResult:
"""执行指定阶段已注册的workflow步骤。"""
from src.plugin_system.core.component_registry import component_registry
stage_steps = component_registry.get_steps_by_stage(stage, enabled_only=True)
sorted_steps = sorted(stage_steps.values(), key=lambda step_info: step_info.priority, reverse=True)
for step_info in sorted_steps:
handler = component_registry.get_workflow_step_handler(step_info.full_name, stage)
if not handler:
context.errors.append(f"{step_info.full_name}: handler not found")
continue
step_timing_key = f"{stage.value}:{step_info.full_name}"
step_start = time.perf_counter()
try:
result = handler(context, message)
if inspect.isawaitable(result):
result = await result
context.timings[step_timing_key] = time.perf_counter() - step_start
normalized_result = self._normalize_step_result(result)
if normalized_result.status == "continue":
continue
normalized_result.diagnostics.setdefault("stage", stage.value)
normalized_result.diagnostics.setdefault("step", step_info.full_name)
normalized_result.diagnostics.setdefault("trace_id", context.trace_id)
if normalized_result.status == "failed":
context.errors.append(
f"{step_info.full_name}: {normalized_result.return_message or 'workflow step failed'}"
)
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
return normalized_result
except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}")
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
return WorkflowStepResult(
status="failed",
return_message=str(e),
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
},
)
return WorkflowStepResult(status="continue", diagnostics={"stage": stage.value, "trace_id": context.trace_id})
def _normalize_step_result(self, result: Any) -> WorkflowStepResult:
"""归一化workflow step返回值。"""
if isinstance(result, WorkflowStepResult):
return result
if isinstance(result, bool):
if result:
return WorkflowStepResult(status="continue")
return WorkflowStepResult(
status="failed",
diagnostics={"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value},
)
if result is None:
return WorkflowStepResult(status="continue")
if isinstance(result, str):
return WorkflowStepResult(status="continue", return_message=result)
if isinstance(result, dict):
status = result.get("status", "continue")
if status not in ["continue", "stop", "failed"]:
status = "failed"
return WorkflowStepResult(
status=status,
return_message=result.get("return_message"),
diagnostics=result.get("diagnostics", {}),
events=result.get("events", []),
)
return WorkflowStepResult(
status="failed",
return_message=f"unsupported step result type: {type(result)}",
diagnostics={"error_code": WorkflowErrorCode.BAD_PAYLOAD.value},
)
def get_execution_trace(self, trace_id: str) -> Optional[dict[str, Any]]:
"""按trace_id获取workflow执行路径。"""
trace = self._execution_history.get(trace_id)
return trace.copy() if trace else None
def clear_execution_trace(self, trace_id: str) -> bool:
"""清理trace执行记录。"""
if trace_id in self._execution_history:
del self._execution_history[trace_id]
return True
return False
workflow_engine = WorkflowEngine()