mirror of https://github.com/Mai-with-u/MaiBot.git
events主体框架完成
parent
2cec83e0dc
commit
8a55e14aa4
2
bot.py
2
bot.py
|
|
@ -66,7 +66,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.plugin_system.base.component_types import EventType
|
||||||
# 触发 ON_STOP 事件
|
# 触发 ON_STOP 事件
|
||||||
_ = await events_manager.handle_mai_events(event_type=EventType.ON_STOP)
|
await events_manager.handle_mai_events(event_type=EventType.ON_STOP)
|
||||||
|
|
||||||
# 停止所有异步任务
|
# 停止所有异步任务
|
||||||
await async_task_manager.stop_and_wait_all_tasks()
|
await async_task_manager.stop_and_wait_all_tasks()
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from src.plugin_system import (
|
||||||
BaseEventHandler,
|
BaseEventHandler,
|
||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
ToolParamType
|
ToolParamType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -136,12 +136,12 @@ class PrintMessage(BaseEventHandler):
|
||||||
handler_name = "print_message_handler"
|
handler_name = "print_message_handler"
|
||||||
handler_description = "打印接收到的消息"
|
handler_description = "打印接收到的消息"
|
||||||
|
|
||||||
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
|
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None]:
|
||||||
"""执行打印消息事件处理"""
|
"""执行打印消息事件处理"""
|
||||||
# 打印接收到的消息
|
# 打印接收到的消息
|
||||||
if self.get_config("print_message.enabled", False):
|
if self.get_config("print_message.enabled", False):
|
||||||
print(f"接收到消息: {message.raw_message}")
|
print(f"接收到消息: {message.raw_message if message else '无效消息'}")
|
||||||
return True, True, "消息已打印"
|
return True, True, "消息已打印", None
|
||||||
|
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
# ===== 插件注册 =====
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from .base import (
|
||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
ToolParamType,
|
ToolParamType,
|
||||||
|
CustomEventHandlerResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导入工具模块
|
# 导入工具模块
|
||||||
|
|
@ -37,7 +38,7 @@ from .utils import (
|
||||||
|
|
||||||
from .apis import (
|
from .apis import (
|
||||||
chat_api,
|
chat_api,
|
||||||
tool_api,
|
tool_api,
|
||||||
component_manage_api,
|
component_manage_api,
|
||||||
config_api,
|
config_api,
|
||||||
database_api,
|
database_api,
|
||||||
|
|
@ -92,6 +93,7 @@ __all__ = [
|
||||||
"ToolParamType",
|
"ToolParamType",
|
||||||
# 消息
|
# 消息
|
||||||
"MaiMessages",
|
"MaiMessages",
|
||||||
|
"CustomEventHandlerResult",
|
||||||
# 装饰器
|
# 装饰器
|
||||||
"register_plugin",
|
"register_plugin",
|
||||||
"ConfigField",
|
"ConfigField",
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from .component_types import (
|
||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
ToolParamType,
|
ToolParamType,
|
||||||
|
CustomEventHandlerResult,
|
||||||
)
|
)
|
||||||
from .config_types import ConfigField
|
from .config_types import ConfigField
|
||||||
|
|
||||||
|
|
@ -46,4 +47,5 @@ __all__ = [
|
||||||
"BaseEventHandler",
|
"BaseEventHandler",
|
||||||
"MaiMessages",
|
"MaiMessages",
|
||||||
"ToolParamType",
|
"ToolParamType",
|
||||||
|
"CustomEventHandlerResult",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
from typing import TYPE_CHECKING, List, Type
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import EventType, MaiMessages
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .base_events_handler import BaseEventHandler
|
|
||||||
|
|
||||||
logger = get_logger("base_event")
|
|
||||||
|
|
||||||
class BaseEvent:
|
|
||||||
def __init__(self, event_type: EventType | str) -> None:
|
|
||||||
self.event_type = event_type
|
|
||||||
self.subscribers: List["BaseEventHandler"] = []
|
|
||||||
|
|
||||||
def register_handler_to_event(self, handler: "BaseEventHandler") -> bool:
|
|
||||||
if handler not in self.subscribers:
|
|
||||||
self.subscribers.append(handler)
|
|
||||||
return True
|
|
||||||
logger.warning(f"Handler {handler.handler_name} 已经注册,不可多次注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def remove_handler_from_event(self, handler_class: Type["BaseEventHandler"]) -> bool:
|
|
||||||
for handler in self.subscribers:
|
|
||||||
if isinstance(handler, handler_class):
|
|
||||||
self.subscribers.remove(handler)
|
|
||||||
return True
|
|
||||||
logger.warning(f"Handler {handler_class.__name__} 未注册,无法移除")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def trigger_event(self, message: MaiMessages):
|
|
||||||
copied_message = message.deepcopy()
|
|
||||||
for handler in self.subscribers:
|
|
||||||
result = handler.execute(copied_message)
|
|
||||||
|
|
||||||
# TODO: Unfinished Events Handler
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Tuple, Optional, Dict
|
from typing import Tuple, Optional, Dict, List
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType, CustomEventHandlerResult
|
||||||
|
|
||||||
logger = get_logger("base_event_handler")
|
logger = get_logger("base_event_handler")
|
||||||
|
|
||||||
|
|
@ -30,16 +30,19 @@ class BaseEventHandler(ABC):
|
||||||
"""对应插件名"""
|
"""对应插件名"""
|
||||||
self.plugin_config: Optional[Dict] = None
|
self.plugin_config: Optional[Dict] = None
|
||||||
"""插件配置字典"""
|
"""插件配置字典"""
|
||||||
|
self._events_subscribed: List[EventType | str] = []
|
||||||
if self.event_type == EventType.UNKNOWN:
|
if self.event_type == EventType.UNKNOWN:
|
||||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, Optional[str]]:
|
async def execute(
|
||||||
|
self, message: MaiMessages | None
|
||||||
|
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]:
|
||||||
"""执行事件处理的抽象方法,子类必须实现
|
"""执行事件处理的抽象方法,子类必须实现
|
||||||
Args:
|
Args:
|
||||||
message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
|
message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息)
|
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("子类必须实现 execute 方法")
|
raise NotImplementedError("子类必须实现 execute 方法")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -285,3 +285,9 @@ class MaiMessages:
|
||||||
|
|
||||||
def deepcopy(self):
|
def deepcopy(self):
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CustomEventHandlerResult:
|
||||||
|
message: str = ""
|
||||||
|
timestamp: float = 0.0
|
||||||
|
extra_info: Optional[Dict] = None
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||||
from .global_announcement_manager import global_announcement_manager
|
from .global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
|
|
@ -18,9 +18,23 @@ logger = get_logger("events_manager")
|
||||||
class EventsManager:
|
class EventsManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 有权重的 events 订阅者注册表
|
# 有权重的 events 订阅者注册表
|
||||||
self._events_subscribers: Dict[EventType | str, List[BaseEventHandler]] = {event: [] for event in EventType}
|
self._events_subscribers: Dict[EventType | str, List[BaseEventHandler]] = {}
|
||||||
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||||
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
||||||
|
self._events_result_history: Dict[EventType | str, List[CustomEventHandlerResult]] = {} # 事件的结果历史记录
|
||||||
|
self._history_enable_map: Dict[EventType | str, bool] = {} # 是否启用历史记录的映射表,同时作为events注册表
|
||||||
|
|
||||||
|
# 事件注册(同时作为注册样例)
|
||||||
|
for event in EventType:
|
||||||
|
self.register_event(event, enable_history_result=False)
|
||||||
|
|
||||||
|
def register_event(self, event_type: EventType | str, enable_history_result: bool = False):
|
||||||
|
if event_type in self._events_subscribers:
|
||||||
|
raise ValueError(f"事件类型 {event_type} 已存在")
|
||||||
|
self._events_subscribers[event_type] = []
|
||||||
|
self._history_enable_map[event_type] = enable_history_result
|
||||||
|
if enable_history_result:
|
||||||
|
self._events_result_history[event_type] = []
|
||||||
|
|
||||||
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
||||||
"""注册事件处理器
|
"""注册事件处理器
|
||||||
|
|
@ -32,69 +46,23 @@ class EventsManager:
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否注册成功
|
bool: 是否注册成功
|
||||||
"""
|
"""
|
||||||
|
if not issubclass(handler_class, BaseEventHandler):
|
||||||
|
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||||
|
return False
|
||||||
|
|
||||||
handler_name = handler_info.name
|
handler_name = handler_info.name
|
||||||
|
|
||||||
if handler_name in self._handler_mapping:
|
if handler_name in self._handler_mapping:
|
||||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not issubclass(handler_class, BaseEventHandler):
|
if handler_info.event_type not in self._history_enable_map:
|
||||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._handler_mapping[handler_name] = handler_class
|
self._handler_mapping[handler_name] = handler_class
|
||||||
return self._insert_event_handler(handler_class, handler_info)
|
return self._insert_event_handler(handler_class, handler_info)
|
||||||
|
|
||||||
def _prepare_message(
|
|
||||||
self,
|
|
||||||
event_type: EventType,
|
|
||||||
message: Optional[MessageRecv] = None,
|
|
||||||
llm_prompt: Optional[str] = None,
|
|
||||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
) -> Optional[MaiMessages]:
|
|
||||||
"""根据事件类型和输入,准备和转换消息对象。"""
|
|
||||||
if message:
|
|
||||||
return self._transform_event_message(message, llm_prompt, llm_response)
|
|
||||||
|
|
||||||
if event_type not in [EventType.ON_START, EventType.ON_STOP]:
|
|
||||||
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
|
|
||||||
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
|
|
||||||
return self._build_message_from_stream(stream_id, llm_prompt, llm_response)
|
|
||||||
else:
|
|
||||||
return self._transform_event_without_message(stream_id, llm_prompt, llm_response, action_usage)
|
|
||||||
|
|
||||||
return None # ON_START, ON_STOP事件没有消息体
|
|
||||||
|
|
||||||
def _dispatch_handler_task(self, handler: BaseEventHandler, message: Optional[MaiMessages]):
|
|
||||||
"""分发一个非阻塞(异步)的事件处理任务。"""
|
|
||||||
try:
|
|
||||||
task = asyncio.create_task(handler.execute(message))
|
|
||||||
|
|
||||||
task_name = f"{handler.plugin_name}-{handler.handler_name}"
|
|
||||||
task.set_name(task_name)
|
|
||||||
task.add_done_callback(self._task_done_callback)
|
|
||||||
|
|
||||||
self._handler_tasks.setdefault(handler.handler_name, []).append(task)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
|
|
||||||
|
|
||||||
async def _dispatch_intercepting_handler(self, handler: BaseEventHandler, message: Optional[MaiMessages]) -> bool:
|
|
||||||
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
|
|
||||||
try:
|
|
||||||
success, continue_processing, result = await handler.execute(message)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
|
|
||||||
|
|
||||||
return continue_processing
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
|
||||||
return True # 发生异常时默认不中断其他处理
|
|
||||||
|
|
||||||
async def handle_mai_events(
|
async def handle_mai_events(
|
||||||
self,
|
self,
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
|
|
@ -115,6 +83,8 @@ class EventsManager:
|
||||||
transformed_message = self._prepare_message(
|
transformed_message = self._prepare_message(
|
||||||
event_type, message, llm_prompt, llm_response, stream_id, action_usage
|
event_type, message, llm_prompt, llm_response, stream_id, action_usage
|
||||||
)
|
)
|
||||||
|
if transformed_message:
|
||||||
|
transformed_message = transformed_message.deepcopy()
|
||||||
|
|
||||||
# 2. 获取并遍历处理器
|
# 2. 获取并遍历处理器
|
||||||
handlers = self._events_subscribers.get(event_type, [])
|
handlers = self._events_subscribers.get(event_type, [])
|
||||||
|
|
@ -137,16 +107,68 @@ class EventsManager:
|
||||||
handler.set_plugin_config(plugin_config)
|
handler.set_plugin_config(plugin_config)
|
||||||
|
|
||||||
# 4. 根据类型分发任务
|
# 4. 根据类型分发任务
|
||||||
if handler.intercept_message:
|
if handler.intercept_message or event_type == EventType.ON_STOP: # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
|
||||||
# 阻塞执行,并更新 continue_flag
|
# 阻塞执行,并更新 continue_flag
|
||||||
should_continue = await self._dispatch_intercepting_handler(handler, transformed_message)
|
should_continue = await self._dispatch_intercepting_handler(handler, event_type, transformed_message)
|
||||||
continue_flag = continue_flag and should_continue
|
continue_flag = continue_flag and should_continue
|
||||||
else:
|
else:
|
||||||
# 异步执行,不阻塞
|
# 异步执行,不阻塞
|
||||||
self._dispatch_handler_task(handler, transformed_message)
|
self._dispatch_handler_task(handler, event_type, transformed_message)
|
||||||
|
|
||||||
return continue_flag
|
return continue_flag
|
||||||
|
|
||||||
|
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||||
|
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||||
|
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
|
||||||
|
for task in remaining_tasks:
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
|
||||||
|
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
|
||||||
|
if handler_name in self._handler_tasks:
|
||||||
|
del self._handler_tasks[handler_name]
|
||||||
|
|
||||||
|
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
||||||
|
"""取消注册事件处理器"""
|
||||||
|
if handler_name not in self._handler_mapping:
|
||||||
|
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self.cancel_handler_tasks(handler_name)
|
||||||
|
|
||||||
|
handler_class = self._handler_mapping.pop(handler_name)
|
||||||
|
if not self._remove_event_handler_instance(handler_class):
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_event_result_history(self, event_type: EventType | str) -> List[CustomEventHandlerResult]:
|
||||||
|
"""获取事件的结果历史记录"""
|
||||||
|
if event_type == EventType.UNKNOWN:
|
||||||
|
raise ValueError("未知事件类型")
|
||||||
|
if event_type not in self._history_enable_map:
|
||||||
|
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||||
|
if not self._history_enable_map[event_type]:
|
||||||
|
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
|
||||||
|
|
||||||
|
return self._events_result_history[event_type]
|
||||||
|
|
||||||
|
async def clear_event_result_history(self, event_type: EventType | str) -> None:
|
||||||
|
"""清空事件的结果历史记录"""
|
||||||
|
if event_type == EventType.UNKNOWN:
|
||||||
|
raise ValueError("未知事件类型")
|
||||||
|
if event_type not in self._history_enable_map:
|
||||||
|
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||||
|
if not self._history_enable_map[event_type]:
|
||||||
|
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
|
||||||
|
|
||||||
|
self._events_result_history[event_type] = []
|
||||||
|
|
||||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
|
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
|
||||||
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
|
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
if handler_class.event_type == EventType.UNKNOWN:
|
||||||
|
|
@ -179,7 +201,10 @@ class EventsManager:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _transform_event_message(
|
def _transform_event_message(
|
||||||
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
|
self,
|
||||||
|
message: MessageRecv,
|
||||||
|
llm_prompt: Optional[str] = None,
|
||||||
|
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||||
) -> MaiMessages:
|
) -> MaiMessages:
|
||||||
"""转换事件消息格式"""
|
"""转换事件消息格式"""
|
||||||
# 直接赋值部分内容
|
# 直接赋值部分内容
|
||||||
|
|
@ -263,52 +288,100 @@ class EventsManager:
|
||||||
additional_data={"response_is_processed": True},
|
additional_data={"response_is_processed": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
|
def _prepare_message(
|
||||||
|
self,
|
||||||
|
event_type: EventType,
|
||||||
|
message: Optional[MessageRecv] = None,
|
||||||
|
llm_prompt: Optional[str] = None,
|
||||||
|
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||||
|
stream_id: Optional[str] = None,
|
||||||
|
action_usage: Optional[List[str]] = None,
|
||||||
|
) -> Optional[MaiMessages]:
|
||||||
|
"""根据事件类型和输入,准备和转换消息对象。"""
|
||||||
|
if message:
|
||||||
|
return self._transform_event_message(message, llm_prompt, llm_response)
|
||||||
|
|
||||||
|
if event_type not in [EventType.ON_START, EventType.ON_STOP]:
|
||||||
|
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
|
||||||
|
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
|
||||||
|
return self._build_message_from_stream(stream_id, llm_prompt, llm_response)
|
||||||
|
else:
|
||||||
|
return self._transform_event_without_message(stream_id, llm_prompt, llm_response, action_usage)
|
||||||
|
|
||||||
|
return None # ON_START, ON_STOP事件没有消息体
|
||||||
|
|
||||||
|
def _dispatch_handler_task(
|
||||||
|
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
|
||||||
|
):
|
||||||
|
"""分发一个非阻塞(异步)的事件处理任务。"""
|
||||||
|
if event_type == EventType.UNKNOWN:
|
||||||
|
raise ValueError("未知事件类型")
|
||||||
|
try:
|
||||||
|
task = asyncio.create_task(handler.execute(message))
|
||||||
|
|
||||||
|
task_name = f"{handler.plugin_name}-{handler.handler_name}"
|
||||||
|
task.set_name(task_name)
|
||||||
|
task.add_done_callback(lambda t: self._task_done_callback(t, event_type))
|
||||||
|
|
||||||
|
self._handler_tasks.setdefault(handler.handler_name, []).append(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _dispatch_intercepting_handler(
|
||||||
|
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
|
||||||
|
) -> bool:
|
||||||
|
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
|
||||||
|
if event_type == EventType.UNKNOWN:
|
||||||
|
raise ValueError("未知事件类型")
|
||||||
|
if event_type not in self._history_enable_map:
|
||||||
|
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||||
|
try:
|
||||||
|
success, continue_processing, return_message, custom_result = await handler.execute(message)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {return_message}")
|
||||||
|
|
||||||
|
if self._history_enable_map[event_type] and custom_result:
|
||||||
|
self._events_result_history[event_type].append(custom_result)
|
||||||
|
return continue_processing
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
||||||
|
return True # 发生异常时默认不中断其他处理
|
||||||
|
|
||||||
|
def _task_done_callback(
|
||||||
|
self,
|
||||||
|
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None]],
|
||||||
|
event_type: EventType | str,
|
||||||
|
):
|
||||||
"""任务完成回调"""
|
"""任务完成回调"""
|
||||||
task_name = task.get_name() or "Unknown Task"
|
task_name = task.get_name() or "Unknown Task"
|
||||||
|
if event_type == EventType.UNKNOWN:
|
||||||
|
raise ValueError("未知事件类型")
|
||||||
|
if event_type not in self._history_enable_map:
|
||||||
|
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||||
try:
|
try:
|
||||||
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
success, _, result, custom_result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
||||||
if success:
|
if success:
|
||||||
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
||||||
else:
|
else:
|
||||||
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
|
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
|
||||||
|
|
||||||
|
if self._history_enable_map[event_type] and custom_result:
|
||||||
|
self._events_result_history[event_type].append(custom_result)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
|
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
|
||||||
finally:
|
finally:
|
||||||
with contextlib.suppress(ValueError, KeyError):
|
with contextlib.suppress(ValueError, KeyError):
|
||||||
self._handler_tasks[task_name].remove(task)
|
self._handler_tasks[task_name].remove(task)
|
||||||
|
|
||||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
|
||||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
|
||||||
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
|
|
||||||
for task in remaining_tasks:
|
|
||||||
task.cancel()
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
|
|
||||||
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
|
|
||||||
if handler_name in self._handler_tasks:
|
|
||||||
del self._handler_tasks[handler_name]
|
|
||||||
|
|
||||||
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
|
||||||
"""取消注册事件处理器"""
|
|
||||||
if handler_name not in self._handler_mapping:
|
|
||||||
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
await self.cancel_handler_tasks(handler_name)
|
|
||||||
|
|
||||||
handler_class = self._handler_mapping.pop(handler_name)
|
|
||||||
if not self._remove_event_handler_instance(handler_class):
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
events_manager = EventsManager()
|
events_manager = EventsManager()
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
- [x] 自定义事件
|
- [x] 自定义事件
|
||||||
- [ ] <del>允许handler随时订阅</del>
|
- [ ] <del>允许handler随时订阅</del>
|
||||||
- [ ] 允许handler随时取消订阅
|
- [x] 允许其他组件给handler增加订阅
|
||||||
- [ ] 允许其他组件给handler增加订阅
|
- [x] 允许其他组件给handler取消订阅
|
||||||
- [ ] 允许其他组件给handler取消订阅
|
|
||||||
- [ ] <del>允许一个handler订阅多个事件</del>
|
- [ ] <del>允许一个handler订阅多个事件</del>
|
||||||
- [ ] event激活时给handler传递参数
|
- [x] event激活时给handler传递参数
|
||||||
- [ ] handler能拿到所有handlers的结果(按照处理权重)
|
- [ ] handler能拿到所有handlers的结果(按照处理权重)
|
||||||
- [x] 随时注册
|
- [x] 随时注册
|
||||||
- [ ] 删除event
|
- [ ] <del>删除event</del>
|
||||||
- [ ] 必要性?
|
- [ ] 必要性?
|
||||||
- [ ] 能够更改prompt
|
- [ ] 能够更改prompt
|
||||||
- [ ] 能够更改llm_response
|
- [ ] 能够更改llm_response
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue