mirror of https://github.com/Mai-with-u/MaiBot.git
feat(plugin-system): add workflow pipeline and cross-plugin service registry
parent
6bcd7cbebb
commit
6fcc53a22b
|
|
@ -29,10 +29,17 @@ from .base import (
|
|||
MaiMessages,
|
||||
ToolParamType,
|
||||
CustomEventHandlerResult,
|
||||
PluginServiceInfo,
|
||||
ReplyContentType,
|
||||
ReplyContent,
|
||||
ForwardNode,
|
||||
ReplySetModel,
|
||||
WorkflowContext,
|
||||
WorkflowMessage,
|
||||
WorkflowStage,
|
||||
WorkflowStepInfo,
|
||||
WorkflowStepResult,
|
||||
WorkflowErrorCode,
|
||||
)
|
||||
|
||||
# 导入工具模块
|
||||
|
|
@ -55,6 +62,8 @@ from .apis import (
|
|||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
plugin_service_api,
|
||||
workflow_api,
|
||||
send_api,
|
||||
register_plugin,
|
||||
get_logger,
|
||||
|
|
@ -85,6 +94,8 @@ __all__ = [
|
|||
"message_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"plugin_service_api",
|
||||
"workflow_api",
|
||||
"send_api",
|
||||
"auto_talk_api",
|
||||
"register_plugin",
|
||||
|
|
@ -115,6 +126,13 @@ __all__ = [
|
|||
"ReplySetModel",
|
||||
"MaiMessages",
|
||||
"CustomEventHandlerResult",
|
||||
"PluginServiceInfo",
|
||||
"WorkflowContext",
|
||||
"WorkflowMessage",
|
||||
"WorkflowStage",
|
||||
"WorkflowStepInfo",
|
||||
"WorkflowStepResult",
|
||||
"WorkflowErrorCode",
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
|
|
|
|||
|
|
@ -16,9 +16,11 @@ from src.plugin_system.apis import (
|
|||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
plugin_service_api,
|
||||
send_api,
|
||||
tool_api,
|
||||
frequency_api,
|
||||
workflow_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
|
|
@ -35,9 +37,11 @@ __all__ = [
|
|||
"message_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"plugin_service_api",
|
||||
"send_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
"frequency_api",
|
||||
"workflow_api",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from src.plugin_system.base.service_types import PluginServiceInfo
|
||||
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
|
||||
|
||||
|
||||
def register_service(service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
|
||||
"""注册插件服务。"""
|
||||
return plugin_service_registry.register_service(service_info, service_handler)
|
||||
|
||||
|
||||
def get_service(service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
|
||||
"""获取插件服务元信息。"""
|
||||
return plugin_service_registry.get_service(service_name, plugin_name)
|
||||
|
||||
|
||||
def get_service_handler(service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
|
||||
"""获取插件服务处理函数。"""
|
||||
return plugin_service_registry.get_service_handler(service_name, plugin_name)
|
||||
|
||||
|
||||
def list_services(plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
|
||||
"""列出插件服务。"""
|
||||
return plugin_service_registry.list_services(plugin_name=plugin_name, enabled_only=enabled_only)
|
||||
|
||||
|
||||
def enable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
|
||||
"""启用插件服务。"""
|
||||
return plugin_service_registry.enable_service(service_name, plugin_name)
|
||||
|
||||
|
||||
def disable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
|
||||
"""禁用插件服务。"""
|
||||
return plugin_service_registry.disable_service(service_name, plugin_name)
|
||||
|
||||
|
||||
def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
|
||||
"""注销插件服务。"""
|
||||
return plugin_service_registry.unregister_service(service_name, plugin_name)
|
||||
|
||||
|
||||
async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
|
||||
"""调用插件服务。"""
|
||||
return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs)
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from src.plugin_system.base.component_types import EventType, MaiMessages
|
||||
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.workflow_engine import workflow_engine
|
||||
|
||||
|
||||
def register_workflow_step(step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
|
||||
"""注册workflow step。"""
|
||||
return component_registry.register_workflow_step(step_info, step_handler)
|
||||
|
||||
|
||||
def get_steps_by_stage(stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
|
||||
"""获取指定阶段的workflow steps。"""
|
||||
return component_registry.get_steps_by_stage(stage, enabled_only=enabled_only)
|
||||
|
||||
|
||||
def get_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
|
||||
"""获取workflow step元信息。"""
|
||||
return component_registry.get_workflow_step(step_name, stage)
|
||||
|
||||
|
||||
def get_workflow_step_handler(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[Callable[..., Any]]:
|
||||
"""获取workflow step处理函数。"""
|
||||
return component_registry.get_workflow_step_handler(step_name, stage)
|
||||
|
||||
|
||||
def enable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
||||
"""启用workflow step。"""
|
||||
return component_registry.enable_workflow_step(step_name, stage)
|
||||
|
||||
|
||||
def disable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
||||
"""禁用workflow step。"""
|
||||
return component_registry.disable_workflow_step(step_name, stage)
|
||||
|
||||
|
||||
def get_execution_trace(trace_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""按trace_id获取workflow执行路径。"""
|
||||
return workflow_engine.get_execution_trace(trace_id)
|
||||
|
||||
|
||||
def clear_execution_trace(trace_id: str) -> bool:
|
||||
"""清理trace执行路径记录。"""
|
||||
return workflow_engine.clear_execution_trace(trace_id)
|
||||
|
||||
|
||||
async def execute_workflow_message(
|
||||
message: Optional[MaiMessages] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
|
||||
"""执行workflow消息流转。"""
|
||||
return await events_manager.handle_workflow_message(
|
||||
message=message,
|
||||
stream_id=stream_id,
|
||||
action_usage=action_usage,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
async def publish_event(
|
||||
event_type: Union[EventType, str],
|
||||
message: Optional[MaiMessages] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||
"""发布事件(支持系统事件和自定义字符串事件)。"""
|
||||
return await events_manager.handle_mai_events(
|
||||
event_type=event_type,
|
||||
message=message,
|
||||
stream_id=stream_id,
|
||||
action_usage=action_usage,
|
||||
)
|
||||
|
|
@ -9,6 +9,9 @@ from .base_action import BaseAction
|
|||
from .base_tool import BaseTool
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .service_types import PluginServiceInfo
|
||||
from .workflow_types import WorkflowContext, WorkflowMessage, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
|
||||
from .workflow_errors import WorkflowErrorCode
|
||||
from .component_types import (
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
|
|
@ -59,4 +62,11 @@ __all__ = [
|
|||
"ReplyContent",
|
||||
"ForwardNode",
|
||||
"ReplySetModel",
|
||||
"PluginServiceInfo",
|
||||
"WorkflowContext",
|
||||
"WorkflowMessage",
|
||||
"WorkflowStage",
|
||||
"WorkflowStepInfo",
|
||||
"WorkflowStepResult",
|
||||
"WorkflowErrorCode",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List, Type, Tuple, Union
|
||||
from typing import Any, Callable, List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
|
||||
from src.plugin_system.base.workflow_types import WorkflowStepInfo
|
||||
from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
|
|
@ -44,6 +45,13 @@ class BasePlugin(PluginBase):
|
|||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
def get_workflow_steps(self) -> List[Tuple[WorkflowStepInfo, Callable[..., Any]]]:
|
||||
"""获取插件包含的workflow steps。
|
||||
|
||||
默认返回空列表,子类可按需覆盖。
|
||||
"""
|
||||
return []
|
||||
|
||||
def register_plugin(self) -> bool:
|
||||
"""注册插件及其所有组件"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
|
@ -69,7 +77,25 @@ class BasePlugin(PluginBase):
|
|||
|
||||
# 注册插件
|
||||
if component_registry.register_plugin(self.plugin_info):
|
||||
# 注册workflow steps(可选)
|
||||
registered_step_count = 0
|
||||
for step_info, step_handler in self.get_workflow_steps():
|
||||
if not step_info.plugin_name:
|
||||
step_info.plugin_name = self.plugin_name
|
||||
elif step_info.plugin_name != self.plugin_name:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} workflow step {step_info.name} 的plugin_name({step_info.plugin_name})与当前插件不一致,已覆盖为 {self.plugin_name}"
|
||||
)
|
||||
step_info.plugin_name = self.plugin_name
|
||||
|
||||
if component_registry.register_workflow_step(step_info, step_handler):
|
||||
registered_step_count += 1
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} workflow step {step_info.full_name} 注册失败")
|
||||
|
||||
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
|
||||
if registered_step_count > 0:
|
||||
logger.debug(f"{self.log_prefix} workflow steps 注册成功,数量: {registered_step_count}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 插件注册失败")
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import toml
|
|||
import json
|
||||
import shutil
|
||||
import datetime
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
|
|
@ -17,7 +18,7 @@ from src.plugin_system.base.config_types import (
|
|||
ConfigSection,
|
||||
ConfigLayout,
|
||||
)
|
||||
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
||||
from src.plugin_system.utils.manifest_utils import ManifestValidator, VersionComparator
|
||||
|
||||
logger = get_logger("plugin_base")
|
||||
|
||||
|
|
@ -41,7 +42,7 @@ class PluginBase(ABC):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dependencies(self) -> List[str]:
|
||||
def dependencies(self) -> List[Union[str, Dict[str, Any]]]:
|
||||
return [] # 依赖的其他插件
|
||||
|
||||
@property
|
||||
|
|
@ -104,7 +105,7 @@ class PluginBase(ABC):
|
|||
enabled=self.enable_plugin,
|
||||
is_built_in=False,
|
||||
config_file=self.config_file_name or "",
|
||||
dependencies=self.dependencies.copy(),
|
||||
dependencies=self._get_dependency_names(),
|
||||
python_dependencies=self.python_dependencies.copy(),
|
||||
# manifest相关信息
|
||||
manifest_data=self.manifest_data.copy(),
|
||||
|
|
@ -541,13 +542,101 @@ class PluginBase(ABC):
|
|||
if not self.dependencies:
|
||||
return True
|
||||
|
||||
for dep in self.dependencies:
|
||||
if not component_registry.get_plugin_info(dep):
|
||||
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep}")
|
||||
for dependency in self.dependencies:
|
||||
dep_name, version_spec, min_version, max_version = self._parse_dependency(dependency)
|
||||
if not dep_name:
|
||||
logger.warning(f"{self.log_prefix} 跳过无效依赖声明: {dependency}")
|
||||
continue
|
||||
|
||||
dep_plugin_info = component_registry.get_plugin_info(dep_name)
|
||||
if not dep_plugin_info:
|
||||
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep_name}")
|
||||
return False
|
||||
|
||||
dep_version = dep_plugin_info.version or "0.0.0"
|
||||
|
||||
if version_spec:
|
||||
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
|
||||
if not is_ok:
|
||||
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
|
||||
return False
|
||||
|
||||
if min_version or max_version:
|
||||
is_ok, msg = VersionComparator.is_version_in_range(dep_version, min_version, max_version)
|
||||
if not is_ok:
|
||||
logger.error(
|
||||
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} 要求区间[{min_version or '-inf'}, {max_version or '+inf'}], 当前版本={dep_version} ({msg})"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_dependency_names(self) -> List[str]:
|
||||
"""获取依赖插件名称列表(用于插件信息展示和统计)。"""
|
||||
dependency_names: List[str] = []
|
||||
for dependency in self.dependencies:
|
||||
dep_name, _, _, _ = self._parse_dependency(dependency)
|
||||
if dep_name:
|
||||
dependency_names.append(dep_name)
|
||||
return dependency_names
|
||||
|
||||
def _parse_dependency(self, dependency: Any) -> tuple[str, str, str, str]:
|
||||
"""解析依赖声明。
|
||||
|
||||
支持格式:
|
||||
- "plugin_a"
|
||||
- {"name": "plugin_a", "version": ">=1.2.0,<2.0.0"}
|
||||
- {"name": "plugin_a", "min_version": "1.2.0", "max_version": "2.0.0"}
|
||||
"""
|
||||
if isinstance(dependency, str):
|
||||
return dependency.strip(), "", "", ""
|
||||
|
||||
if isinstance(dependency, dict):
|
||||
dep_name = str(dependency.get("name", "")).strip()
|
||||
version_spec = str(dependency.get("version", "")).strip()
|
||||
min_version = str(dependency.get("min_version", "")).strip()
|
||||
max_version = str(dependency.get("max_version", "")).strip()
|
||||
return dep_name, version_spec, min_version, max_version
|
||||
|
||||
return "", "", "", ""
|
||||
|
||||
def _is_version_spec_satisfied(self, version: str, version_spec: str) -> tuple[bool, str]:
|
||||
"""检查版本是否满足表达式。
|
||||
|
||||
支持:==, >=, <=, >, <,可用逗号分隔多个条件。
|
||||
示例:">=1.2.0,<2.0.0"
|
||||
"""
|
||||
normalized_version = VersionComparator.normalize_version(version)
|
||||
clauses = [clause.strip() for clause in version_spec.split(",") if clause.strip()]
|
||||
if not clauses:
|
||||
return True, ""
|
||||
|
||||
operators_pattern = r"^(==|>=|<=|>|<)\s*(.+)$"
|
||||
|
||||
for clause in clauses:
|
||||
if not (match := re.match(operators_pattern, clause)):
|
||||
return False, f"无效版本约束表达式: {clause}"
|
||||
|
||||
operator, target_version = match.group(1), VersionComparator.normalize_version(match.group(2))
|
||||
compare_result = VersionComparator.compare_versions(normalized_version, target_version)
|
||||
|
||||
is_satisfied = False
|
||||
if operator == "==":
|
||||
is_satisfied = compare_result == 0
|
||||
elif operator == ">=":
|
||||
is_satisfied = compare_result >= 0
|
||||
elif operator == "<=":
|
||||
is_satisfied = compare_result <= 0
|
||||
elif operator == ">":
|
||||
is_satisfied = compare_result > 0
|
||||
elif operator == "<":
|
||||
is_satisfied = compare_result < 0
|
||||
|
||||
if not is_satisfied:
|
||||
return False, f"{normalized_version} 不满足约束 {operator}{target_version}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginServiceInfo:
|
||||
"""插件服务注册信息"""
|
||||
|
||||
name: str
|
||||
plugin_name: str
|
||||
version: str = "1.0.0"
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
params_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
return_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
return f"{self.plugin_name}.{self.name}"
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class WorkflowErrorCode(Enum):
|
||||
"""Workflow统一错误码"""
|
||||
|
||||
PLUGIN_NOT_READY = "PLUGIN_NOT_READY"
|
||||
STEP_TIMEOUT = "STEP_TIMEOUT"
|
||||
BAD_PAYLOAD = "BAD_PAYLOAD"
|
||||
DOWNSTREAM_FAILED = "DOWNSTREAM_FAILED"
|
||||
POLICY_BLOCKED = "POLICY_BLOCKED"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
import time
|
||||
|
||||
|
||||
class WorkflowStage(Enum):
|
||||
"""Workflow阶段定义(MVP固定阶段)"""
|
||||
|
||||
INGRESS = "ingress"
|
||||
PRE_PROCESS = "pre_process"
|
||||
PLAN = "plan"
|
||||
TOOL_EXECUTE = "tool_execute"
|
||||
POST_PROCESS = "post_process"
|
||||
EGRESS = "egress"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowContext:
|
||||
"""Workflow上下文"""
|
||||
|
||||
trace_id: str
|
||||
stream_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
timings: Dict[str, float] = field(default_factory=dict)
|
||||
errors: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowMessage:
|
||||
"""Workflow消息包装"""
|
||||
|
||||
msg_type: str
|
||||
payload: Dict[str, Any] = field(default_factory=dict)
|
||||
headers: Dict[str, Any] = field(default_factory=dict)
|
||||
mutable_flags: Dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowStepResult:
|
||||
"""Workflow步骤结果"""
|
||||
|
||||
status: Literal["continue", "stop", "failed"] = "continue"
|
||||
return_message: Optional[str] = None
|
||||
diagnostics: Dict[str, Any] = field(default_factory=dict)
|
||||
events: List[Dict[str, Any]] = field(default_factory=list)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowStepInfo:
|
||||
"""Workflow步骤元数据"""
|
||||
|
||||
name: str
|
||||
stage: WorkflowStage
|
||||
plugin_name: str
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
priority: int = 0
|
||||
timeout_ms: int = 0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
return f"{self.plugin_name}.{self.name}"
|
||||
|
|
@ -8,10 +8,14 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
|||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
|
||||
from src.plugin_system.core.workflow_engine import workflow_engine
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
"plugin_service_registry",
|
||||
"workflow_engine",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
from typing import Callable, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
|
|
@ -16,6 +16,7 @@ from src.plugin_system.base.base_command import BaseCommand
|
|||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.workflow_types import WorkflowStage, WorkflowStepInfo
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
|
|
@ -61,6 +62,10 @@ class ComponentRegistry:
|
|||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
# Workflow step注册表
|
||||
self._workflow_steps: Dict[WorkflowStage, Dict[str, WorkflowStepInfo]] = {stage: {} for stage in WorkflowStage}
|
||||
self._workflow_step_handlers: Dict[str, Callable[..., Any]] = {}
|
||||
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
# == 注册方法 ==
|
||||
|
|
@ -282,9 +287,116 @@ class ComponentRegistry:
|
|||
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
|
||||
return False
|
||||
del self._plugins[plugin_name]
|
||||
self.remove_workflow_steps_by_plugin(plugin_name)
|
||||
logger.info(f"插件 {plugin_name} 已移除")
|
||||
return True
|
||||
|
||||
# === Workflow step 注册与查询 ===
|
||||
|
||||
def register_workflow_step(self, step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
|
||||
"""注册workflow步骤。"""
|
||||
if not step_info.name or not step_info.plugin_name:
|
||||
logger.error("workflow step 注册失败: step名称或插件名称为空")
|
||||
return False
|
||||
if "." in step_info.name:
|
||||
logger.error(f"workflow step 名称 '{step_info.name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
if "." in step_info.plugin_name:
|
||||
logger.error(f"workflow step 所属插件名称 '{step_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
|
||||
full_name = step_info.full_name
|
||||
stage_registry = self._workflow_steps.get(step_info.stage)
|
||||
if stage_registry is None:
|
||||
logger.error(f"workflow step 注册失败: 未知阶段 {step_info.stage}")
|
||||
return False
|
||||
if full_name in stage_registry:
|
||||
logger.warning(f"workflow step 已存在,跳过注册: {full_name}")
|
||||
return False
|
||||
|
||||
stage_registry[full_name] = step_info
|
||||
self._workflow_step_handlers[full_name] = step_handler
|
||||
logger.debug(f"已注册workflow step: {full_name} @ {step_info.stage}")
|
||||
return True
|
||||
|
||||
def get_steps_by_stage(self, stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
|
||||
"""获取某阶段的workflow步骤。"""
|
||||
steps = self._workflow_steps.get(stage, {})
|
||||
if enabled_only:
|
||||
return {name: info for name, info in steps.items() if info.enabled}
|
||||
return steps.copy()
|
||||
|
||||
def get_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
|
||||
"""获取workflow step信息。
|
||||
|
||||
step_name支持两种:
|
||||
- full_name: plugin_name.step_name
|
||||
- short_name: step_name(若有冲突返回第一个并告警)
|
||||
"""
|
||||
candidates: List[WorkflowStepInfo] = []
|
||||
|
||||
target_stages = [stage] if stage else list(WorkflowStage)
|
||||
for current_stage in target_stages:
|
||||
current_steps = self._workflow_steps.get(current_stage, {})
|
||||
if "." in step_name:
|
||||
if step_info := current_steps.get(step_name):
|
||||
return step_info
|
||||
continue
|
||||
candidates.extend([step_info for step_info in current_steps.values() if step_info.name == step_name])
|
||||
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
if len(candidates) > 1:
|
||||
logger.warning(f"workflow step 名称 '{step_name}' 存在多义,使用第一个匹配: {candidates[0].full_name}")
|
||||
return candidates[0]
|
||||
return None
|
||||
|
||||
def get_workflow_step_handler(
|
||||
self, step_name: str, stage: Optional[WorkflowStage] = None
|
||||
) -> Optional[Callable[..., Any]]:
|
||||
"""获取workflow step处理函数。"""
|
||||
if "." in step_name:
|
||||
return self._workflow_step_handlers.get(step_name)
|
||||
|
||||
if step_info := self.get_workflow_step(step_name, stage):
|
||||
return self._workflow_step_handlers.get(step_info.full_name)
|
||||
return None
|
||||
|
||||
def enable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
||||
"""启用workflow step。"""
|
||||
step_info = self.get_workflow_step(step_name, stage)
|
||||
if not step_info:
|
||||
logger.warning(f"workflow step 未注册,无法启用: {step_name}")
|
||||
return False
|
||||
step_info.enabled = True
|
||||
logger.info(f"workflow step 已启用: {step_info.full_name}")
|
||||
return True
|
||||
|
||||
def disable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
||||
"""禁用workflow step。"""
|
||||
step_info = self.get_workflow_step(step_name, stage)
|
||||
if not step_info:
|
||||
logger.warning(f"workflow step 未注册,无法禁用: {step_name}")
|
||||
return False
|
||||
step_info.enabled = False
|
||||
logger.info(f"workflow step 已禁用: {step_info.full_name}")
|
||||
return True
|
||||
|
||||
def remove_workflow_steps_by_plugin(self, plugin_name: str) -> int:
|
||||
"""移除某插件注册的所有workflow step。"""
|
||||
removed_count = 0
|
||||
for stage in WorkflowStage:
|
||||
stage_registry = self._workflow_steps.get(stage, {})
|
||||
target_names = [name for name, info in stage_registry.items() if info.plugin_name == plugin_name]
|
||||
for full_name in target_names:
|
||||
stage_registry.pop(full_name, None)
|
||||
self._workflow_step_handlers.pop(full_name, None)
|
||||
removed_count += 1
|
||||
|
||||
if removed_count:
|
||||
logger.info(f"已移除插件 {plugin_name} 的 workflow step 数量: {removed_count}")
|
||||
return removed_count
|
||||
|
||||
# === 组件全局启用/禁用方法 ===
|
||||
|
||||
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
|
|
@ -602,6 +714,12 @@ class ComponentRegistry:
|
|||
tool_components += 1
|
||||
elif component.component_type == ComponentType.EVENT_HANDLER:
|
||||
events_handlers += 1
|
||||
|
||||
workflow_step_count = sum(len(steps) for steps in self._workflow_steps.values())
|
||||
enabled_workflow_step_count = sum(
|
||||
len([step for step in steps.values() if step.enabled]) for steps in self._workflow_steps.values()
|
||||
)
|
||||
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
|
|
@ -614,6 +732,11 @@ class ComponentRegistry:
|
|||
},
|
||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
"workflow_steps": workflow_step_count,
|
||||
"enabled_workflow_steps": enabled_workflow_step_count,
|
||||
"workflow_steps_by_stage": {
|
||||
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ from src.chat.message_receive.chat_stream import get_chat_manager
|
|||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStepResult
|
||||
from .global_announcement_manager import global_announcement_manager
|
||||
from .workflow_engine import workflow_engine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
|
@ -57,16 +59,20 @@ class EventsManager:
|
|||
return False
|
||||
|
||||
if handler_info.event_type not in self._history_enable_map:
|
||||
logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}")
|
||||
return False
|
||||
if isinstance(handler_info.event_type, str):
|
||||
self.register_event(handler_info.event_type, enable_history_result=False)
|
||||
logger.info(f"自动注册自定义事件类型: {handler_info.event_type}")
|
||||
else:
|
||||
logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}")
|
||||
return False
|
||||
|
||||
self._handler_mapping[handler_name] = handler_class
|
||||
return self._insert_event_handler(handler_class, handler_info)
|
||||
|
||||
async def handle_mai_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
message: Optional[MessageRecv | MessageSending] = None,
|
||||
event_type: EventType | str,
|
||||
message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
|
|
@ -121,6 +127,37 @@ class EventsManager:
|
|||
|
||||
return continue_flag, modified_message
|
||||
|
||||
async def handle_workflow_message(
|
||||
self,
|
||||
message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
|
||||
"""执行线性workflow消息流转(MVP兼容入口)。"""
|
||||
initial_message = self._prepare_message(EventType.ON_MESSAGE_PRE_PROCESS, message=message, stream_id=stream_id)
|
||||
|
||||
async def _dispatch(
|
||||
event_type: EventType | str,
|
||||
workflow_message: Optional[MaiMessages],
|
||||
workflow_stream_id: Optional[str],
|
||||
workflow_action_usage: Optional[List[str]],
|
||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||
return await self.handle_mai_events(
|
||||
event_type=event_type,
|
||||
message=workflow_message,
|
||||
stream_id=workflow_stream_id,
|
||||
action_usage=workflow_action_usage,
|
||||
)
|
||||
|
||||
return await workflow_engine.execute_linear(
|
||||
dispatch_event=_dispatch,
|
||||
message=initial_message,
|
||||
stream_id=stream_id,
|
||||
action_usage=action_usage,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
|
||||
|
|
@ -294,14 +331,17 @@ class EventsManager:
|
|||
|
||||
def _prepare_message(
|
||||
self,
|
||||
event_type: EventType,
|
||||
message: Optional[MessageRecv | MessageSending] = None,
|
||||
event_type: EventType | str,
|
||||
message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> Optional[MaiMessages]:
|
||||
"""根据事件类型和输入,准备和转换消息对象。"""
|
||||
if isinstance(message, MaiMessages):
|
||||
return message.deepcopy()
|
||||
|
||||
if message:
|
||||
return self._transform_event_message(message, llm_prompt, llm_response)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import os
|
||||
import traceback
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
||||
from collections import deque
|
||||
import os
|
||||
import traceback
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
|
@ -11,6 +11,7 @@ from src.plugin_system.base.plugin_base import PluginBase
|
|||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
from .plugin_service_registry import plugin_service_registry
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
|
|
@ -70,10 +71,36 @@ class PluginManager:
|
|||
|
||||
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
||||
|
||||
# 第二阶段前:根据依赖关系决定插件加载顺序
|
||||
dependency_graph, missing_dependencies = self._build_plugin_dependency_graph()
|
||||
sorted_plugins, cycle_plugins = self._resolve_plugin_load_order(dependency_graph)
|
||||
|
||||
total_registered = 0
|
||||
total_failed_registration = 0
|
||||
|
||||
for plugin_name in self.plugin_classes.keys():
|
||||
# 先处理缺失依赖
|
||||
for plugin_name, missing in missing_dependencies.items():
|
||||
if not missing:
|
||||
continue
|
||||
missing_dep_names = ", ".join(sorted(missing))
|
||||
self.failed_plugins[plugin_name] = f"缺少依赖插件: {missing_dep_names}"
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少依赖插件: {missing_dep_names}")
|
||||
total_failed_registration += 1
|
||||
|
||||
# 再处理循环依赖
|
||||
for plugin_name in sorted(cycle_plugins):
|
||||
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
|
||||
continue
|
||||
self.failed_plugins[plugin_name] = "检测到循环依赖"
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - 检测到循环依赖")
|
||||
total_failed_registration += 1
|
||||
|
||||
# 最后按拓扑序加载可加载插件
|
||||
for plugin_name in sorted_plugins:
|
||||
if plugin_name in cycle_plugins:
|
||||
continue
|
||||
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
|
||||
continue
|
||||
load_status, count = self.load_registered_plugin_classes(plugin_name)
|
||||
if load_status:
|
||||
total_registered += 1
|
||||
|
|
@ -165,6 +192,7 @@ class PluginManager:
|
|||
for component in plugin_info.components:
|
||||
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
|
||||
success &= component_registry.remove_plugin_registry(plugin_name)
|
||||
plugin_service_registry.remove_services_by_plugin(plugin_name)
|
||||
del self.loaded_plugins[plugin_name]
|
||||
return success
|
||||
|
||||
|
|
@ -313,6 +341,89 @@ class PluginManager:
|
|||
self.failed_plugins[module_name] = error_msg
|
||||
return False
|
||||
|
||||
# == 依赖解析与加载顺序 ==
|
||||
|
||||
def _extract_declared_dependencies(self, plugin_name: str, plugin_class: Type[PluginBase]) -> Set[str]:
|
||||
"""提取插件声明的依赖。
|
||||
|
||||
兼容声明格式:
|
||||
- list[str]
|
||||
- list[dict],其中dict至少包含name键
|
||||
"""
|
||||
dependencies: Set[str] = set()
|
||||
raw_dependencies = getattr(plugin_class, "dependencies", [])
|
||||
|
||||
# 兼容错误声明
|
||||
if isinstance(raw_dependencies, property):
|
||||
logger.warning(f"插件 {plugin_name} 的 dependencies 未声明为类属性,将按无依赖处理")
|
||||
return dependencies
|
||||
if not isinstance(raw_dependencies, list):
|
||||
logger.warning(f"插件 {plugin_name} 的 dependencies 不是列表,将按无依赖处理")
|
||||
return dependencies
|
||||
|
||||
for dependency in raw_dependencies:
|
||||
dependency_name = ""
|
||||
if isinstance(dependency, str):
|
||||
dependency_name = dependency.strip()
|
||||
elif isinstance(dependency, dict):
|
||||
dependency_name = str(dependency.get("name", "")).strip()
|
||||
else:
|
||||
logger.warning(f"插件 {plugin_name} 包含不支持的依赖声明类型: {type(dependency)}")
|
||||
continue
|
||||
|
||||
if not dependency_name:
|
||||
continue
|
||||
if dependency_name == plugin_name:
|
||||
logger.warning(f"插件 {plugin_name} 声明了对自身的依赖,已忽略")
|
||||
continue
|
||||
|
||||
dependencies.add(dependency_name)
|
||||
|
||||
return dependencies
|
||||
|
||||
def _build_plugin_dependency_graph(self) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
|
||||
"""构建依赖图并返回缺失依赖映射。"""
|
||||
plugin_names = set(self.plugin_classes.keys())
|
||||
dependency_graph: Dict[str, Set[str]] = {name: set() for name in plugin_names}
|
||||
missing_dependencies: Dict[str, Set[str]] = {name: set() for name in plugin_names}
|
||||
|
||||
for plugin_name, plugin_class in self.plugin_classes.items():
|
||||
declared_dependencies = self._extract_declared_dependencies(plugin_name, plugin_class)
|
||||
for dependency_name in declared_dependencies:
|
||||
if dependency_name in plugin_names:
|
||||
dependency_graph[plugin_name].add(dependency_name)
|
||||
else:
|
||||
missing_dependencies[plugin_name].add(dependency_name)
|
||||
|
||||
return dependency_graph, missing_dependencies
|
||||
|
||||
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
|
||||
"""根据依赖图计算加载顺序,并检测循环依赖。"""
|
||||
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
|
||||
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
|
||||
|
||||
for plugin_name, dependencies in dependency_graph.items():
|
||||
for dependency_name in dependencies:
|
||||
reverse_graph[dependency_name].add(plugin_name)
|
||||
|
||||
zero_indegree_queue = deque(sorted([name for name, degree in indegree.items() if degree == 0]))
|
||||
load_order: List[str] = []
|
||||
|
||||
while zero_indegree_queue:
|
||||
current_plugin = zero_indegree_queue.popleft()
|
||||
load_order.append(current_plugin)
|
||||
|
||||
for dependent_plugin in sorted(reverse_graph[current_plugin]):
|
||||
indegree[dependent_plugin] -= 1
|
||||
if indegree[dependent_plugin] == 0:
|
||||
zero_indegree_queue.append(dependent_plugin)
|
||||
|
||||
cycle_plugins = {name for name, degree in indegree.items() if degree > 0}
|
||||
if cycle_plugins:
|
||||
logger.error(f"检测到循环依赖插件: {', '.join(sorted(cycle_plugins))}")
|
||||
|
||||
return load_order, cycle_plugins
|
||||
|
||||
# == 兼容性检查 ==
|
||||
|
||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,140 @@
|
|||
from typing import Any, Callable, Dict, Optional
|
||||
import inspect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.service_types import PluginServiceInfo
|
||||
|
||||
logger = get_logger("plugin_service_registry")
|
||||
|
||||
|
||||
class PluginServiceRegistry:
|
||||
"""插件服务注册中心"""
|
||||
|
||||
def __init__(self):
|
||||
self._services: Dict[str, PluginServiceInfo] = {}
|
||||
self._service_handlers: Dict[str, Callable[..., Any]] = {}
|
||||
logger.info("插件服务注册中心初始化完成")
|
||||
|
||||
def register_service(self, service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
|
||||
"""注册插件服务。"""
|
||||
if not service_info.name or not service_info.plugin_name:
|
||||
logger.error("插件服务注册失败: service名称或插件名称为空")
|
||||
return False
|
||||
if "." in service_info.name:
|
||||
logger.error(f"插件服务名称 '{service_info.name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
if "." in service_info.plugin_name:
|
||||
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
|
||||
full_name = service_info.full_name
|
||||
if full_name in self._services:
|
||||
logger.warning(f"插件服务已存在,拒绝重复注册: {full_name}")
|
||||
return False
|
||||
|
||||
self._services[full_name] = service_info
|
||||
self._service_handlers[full_name] = service_handler
|
||||
logger.debug(f"已注册插件服务: {full_name} (version={service_info.version})")
|
||||
return True
|
||||
|
||||
def get_service(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
|
||||
"""获取插件服务元信息。
|
||||
|
||||
service_name支持:
|
||||
- full_name: plugin_name.service_name
|
||||
- short_name: service_name(当唯一时可解析)
|
||||
"""
|
||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
||||
return self._services.get(full_name) if full_name else None
|
||||
|
||||
def get_service_handler(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
|
||||
"""获取插件服务处理函数。"""
|
||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
||||
return self._service_handlers.get(full_name) if full_name else None
|
||||
|
||||
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
|
||||
"""列出插件服务。"""
|
||||
services = self._services.copy()
|
||||
if plugin_name:
|
||||
services = {name: info for name, info in services.items() if info.plugin_name == plugin_name}
|
||||
if enabled_only:
|
||||
services = {name: info for name, info in services.items() if info.enabled}
|
||||
return services
|
||||
|
||||
def enable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
|
||||
"""启用插件服务。"""
|
||||
if not (service_info := self.get_service(service_name, plugin_name)):
|
||||
logger.warning(f"插件服务未注册,无法启用: {service_name}")
|
||||
return False
|
||||
service_info.enabled = True
|
||||
logger.info(f"插件服务已启用: {service_info.full_name}")
|
||||
return True
|
||||
|
||||
def disable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
|
||||
"""禁用插件服务。"""
|
||||
if not (service_info := self.get_service(service_name, plugin_name)):
|
||||
logger.warning(f"插件服务未注册,无法禁用: {service_name}")
|
||||
return False
|
||||
service_info.enabled = False
|
||||
logger.info(f"插件服务已禁用: {service_info.full_name}")
|
||||
return True
|
||||
|
||||
def unregister_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
|
||||
"""注销单个插件服务。"""
|
||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
||||
if not full_name:
|
||||
logger.warning(f"插件服务未注册,无法注销: {service_name}")
|
||||
return False
|
||||
|
||||
self._services.pop(full_name, None)
|
||||
self._service_handlers.pop(full_name, None)
|
||||
logger.info(f"插件服务已注销: {full_name}")
|
||||
return True
|
||||
|
||||
def remove_services_by_plugin(self, plugin_name: str) -> int:
|
||||
"""移除某插件的所有注册服务。"""
|
||||
target_names = [full_name for full_name, info in self._services.items() if info.plugin_name == plugin_name]
|
||||
for full_name in target_names:
|
||||
self._services.pop(full_name, None)
|
||||
self._service_handlers.pop(full_name, None)
|
||||
|
||||
removed_count = len(target_names)
|
||||
if removed_count:
|
||||
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
|
||||
return removed_count
|
||||
|
||||
async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
|
||||
"""调用插件服务(支持同步/异步handler)。"""
|
||||
service_info = self.get_service(service_name, plugin_name)
|
||||
if not service_info:
|
||||
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
|
||||
raise ValueError(f"插件服务未注册: {target_name}")
|
||||
if not service_info.enabled:
|
||||
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
|
||||
|
||||
handler = self.get_service_handler(service_name, plugin_name)
|
||||
if not handler:
|
||||
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
|
||||
|
||||
result = handler(*args, **kwargs)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
|
||||
"""解析服务全名。"""
|
||||
if "." in service_name:
|
||||
return service_name if service_name in self._services else None
|
||||
|
||||
if plugin_name:
|
||||
full_name = f"{plugin_name}.{service_name}"
|
||||
return full_name if full_name in self._services else None
|
||||
|
||||
candidates = [full_name for full_name, info in self._services.items() if info.name == service_name]
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
if len(candidates) > 1:
|
||||
logger.warning(f"插件服务名称 '{service_name}' 存在多义,请传入plugin_name或使用完整服务名")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
plugin_service_registry = PluginServiceRegistry()
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
|
||||
import time
|
||||
import uuid
|
||||
import inspect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, MaiMessages
|
||||
from src.plugin_system.base.workflow_errors import WorkflowErrorCode
|
||||
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepResult
|
||||
|
||||
logger = get_logger("workflow_engine")
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
"""线性Workflow执行器(MVP)"""
|
||||
|
||||
STAGE_EVENT_SEQUENCE: List[Tuple[WorkflowStage, Union[EventType, str]]] = [
|
||||
(WorkflowStage.INGRESS, "workflow.ingress"),
|
||||
(WorkflowStage.PRE_PROCESS, EventType.ON_MESSAGE_PRE_PROCESS),
|
||||
(WorkflowStage.PLAN, EventType.ON_PLAN),
|
||||
(WorkflowStage.TOOL_EXECUTE, "workflow.tool_execute"),
|
||||
(WorkflowStage.POST_PROCESS, EventType.POST_SEND_PRE_PROCESS),
|
||||
(WorkflowStage.EGRESS, "workflow.egress"),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._execution_history: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async def execute_linear(
|
||||
self,
|
||||
dispatch_event: Callable[
|
||||
[Union[EventType, str], Optional[MaiMessages], Optional[str], Optional[List[str]]],
|
||||
Awaitable[Tuple[bool, Optional[MaiMessages]]],
|
||||
],
|
||||
message: Optional[MaiMessages] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
|
||||
"""执行线性workflow。"""
|
||||
workflow_context = context or WorkflowContext(trace_id=uuid.uuid4().hex, stream_id=stream_id)
|
||||
current_message = message.deepcopy() if message else None
|
||||
self._execution_history[workflow_context.trace_id] = {
|
||||
"trace_id": workflow_context.trace_id,
|
||||
"stream_id": workflow_context.stream_id,
|
||||
"stages": [],
|
||||
"errors": [],
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
for stage, event_type in self.STAGE_EVENT_SEQUENCE:
|
||||
stage_key = str(stage)
|
||||
stage_start = time.perf_counter()
|
||||
try:
|
||||
should_continue, modified_message = await dispatch_event(
|
||||
event_type,
|
||||
current_message,
|
||||
workflow_context.stream_id,
|
||||
action_usage,
|
||||
)
|
||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
||||
self._execution_history[workflow_context.trace_id]["stages"].append(
|
||||
{
|
||||
"stage": stage_key,
|
||||
"event_type": str(event_type),
|
||||
"event_continue": should_continue,
|
||||
"event_cost": workflow_context.timings[stage_key],
|
||||
}
|
||||
)
|
||||
|
||||
if modified_message:
|
||||
current_message = modified_message
|
||||
|
||||
if not should_continue:
|
||||
logger.info(f"[trace_id={workflow_context.trace_id}] Workflow在阶段 {stage_key} 被中断")
|
||||
return (
|
||||
WorkflowStepResult(
|
||||
status="stop",
|
||||
return_message=f"workflow stopped at stage {stage_key}",
|
||||
diagnostics={
|
||||
"stage": stage_key,
|
||||
"trace_id": workflow_context.trace_id,
|
||||
"error_code": WorkflowErrorCode.POLICY_BLOCKED.value,
|
||||
},
|
||||
),
|
||||
current_message,
|
||||
workflow_context,
|
||||
)
|
||||
|
||||
step_result = await self._execute_registered_steps(stage, workflow_context, current_message)
|
||||
if step_result.status in ["stop", "failed"]:
|
||||
self._execution_history[workflow_context.trace_id]["status"] = step_result.status
|
||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
||||
return step_result, current_message, workflow_context
|
||||
except Exception as e:
|
||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
||||
workflow_context.errors.append(f"{stage_key}: {e}")
|
||||
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
|
||||
self._execution_history[workflow_context.trace_id]["status"] = "failed"
|
||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
||||
return (
|
||||
WorkflowStepResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
diagnostics={
|
||||
"stage": stage_key,
|
||||
"trace_id": workflow_context.trace_id,
|
||||
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
|
||||
},
|
||||
),
|
||||
current_message,
|
||||
workflow_context,
|
||||
)
|
||||
|
||||
self._execution_history[workflow_context.trace_id]["status"] = "continue"
|
||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
||||
return (
|
||||
WorkflowStepResult(
|
||||
status="continue",
|
||||
return_message="workflow completed",
|
||||
diagnostics={"trace_id": workflow_context.trace_id},
|
||||
),
|
||||
current_message,
|
||||
workflow_context,
|
||||
)
|
||||
|
||||
async def _execute_registered_steps(
|
||||
self,
|
||||
stage: WorkflowStage,
|
||||
context: WorkflowContext,
|
||||
message: Optional[MaiMessages],
|
||||
) -> WorkflowStepResult:
|
||||
"""执行指定阶段已注册的workflow步骤。"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
stage_steps = component_registry.get_steps_by_stage(stage, enabled_only=True)
|
||||
sorted_steps = sorted(stage_steps.values(), key=lambda step_info: step_info.priority, reverse=True)
|
||||
|
||||
for step_info in sorted_steps:
|
||||
handler = component_registry.get_workflow_step_handler(step_info.full_name, stage)
|
||||
if not handler:
|
||||
context.errors.append(f"{step_info.full_name}: handler not found")
|
||||
continue
|
||||
|
||||
step_timing_key = f"{stage.value}:{step_info.full_name}"
|
||||
step_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
result = handler(context, message)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||
|
||||
normalized_result = self._normalize_step_result(result)
|
||||
if normalized_result.status == "continue":
|
||||
continue
|
||||
|
||||
normalized_result.diagnostics.setdefault("stage", stage.value)
|
||||
normalized_result.diagnostics.setdefault("step", step_info.full_name)
|
||||
normalized_result.diagnostics.setdefault("trace_id", context.trace_id)
|
||||
if normalized_result.status == "failed":
|
||||
context.errors.append(
|
||||
f"{step_info.full_name}: {normalized_result.return_message or 'workflow step failed'}"
|
||||
)
|
||||
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
|
||||
return normalized_result
|
||||
|
||||
except Exception as e:
|
||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||
context.errors.append(f"{step_info.full_name}: {e}")
|
||||
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
|
||||
return WorkflowStepResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
diagnostics={
|
||||
"stage": stage.value,
|
||||
"step": step_info.full_name,
|
||||
"trace_id": context.trace_id,
|
||||
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
|
||||
},
|
||||
)
|
||||
|
||||
return WorkflowStepResult(status="continue", diagnostics={"stage": stage.value, "trace_id": context.trace_id})
|
||||
|
||||
def _normalize_step_result(self, result: Any) -> WorkflowStepResult:
|
||||
"""归一化workflow step返回值。"""
|
||||
if isinstance(result, WorkflowStepResult):
|
||||
return result
|
||||
if isinstance(result, bool):
|
||||
if result:
|
||||
return WorkflowStepResult(status="continue")
|
||||
return WorkflowStepResult(
|
||||
status="failed",
|
||||
diagnostics={"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value},
|
||||
)
|
||||
if result is None:
|
||||
return WorkflowStepResult(status="continue")
|
||||
if isinstance(result, str):
|
||||
return WorkflowStepResult(status="continue", return_message=result)
|
||||
if isinstance(result, dict):
|
||||
status = result.get("status", "continue")
|
||||
if status not in ["continue", "stop", "failed"]:
|
||||
status = "failed"
|
||||
return WorkflowStepResult(
|
||||
status=status,
|
||||
return_message=result.get("return_message"),
|
||||
diagnostics=result.get("diagnostics", {}),
|
||||
events=result.get("events", []),
|
||||
)
|
||||
return WorkflowStepResult(
|
||||
status="failed",
|
||||
return_message=f"unsupported step result type: {type(result)}",
|
||||
diagnostics={"error_code": WorkflowErrorCode.BAD_PAYLOAD.value},
|
||||
)
|
||||
|
||||
def get_execution_trace(self, trace_id: str) -> Optional[dict[str, Any]]:
|
||||
"""按trace_id获取workflow执行路径。"""
|
||||
trace = self._execution_history.get(trace_id)
|
||||
return trace.copy() if trace else None
|
||||
|
||||
def clear_execution_trace(self, trace_id: str) -> bool:
|
||||
"""清理trace执行记录。"""
|
||||
if trace_id in self._execution_history:
|
||||
del self._execution_history[trace_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
workflow_engine = WorkflowEngine()
|
||||
Loading…
Reference in New Issue