diff --git a/bot.py b/bot.py index 83e3cea8..d2c165e8 100644 --- a/bot.py +++ b/bot.py @@ -8,6 +8,7 @@ import traceback from dotenv import load_dotenv from pathlib import Path from rich.traceback import install +from src.plugin_system.base.component_types import EventType if os.path.exists(".env"): load_dotenv(".env", override=True) @@ -65,7 +66,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression from src.plugin_system.core.event_manager import event_manager # 触发 ON_STOP 事件 - await event_manager.trigger_event("on_stop") + await event_manager.trigger_event(EventType.ON_STOP) # 停止所有异步任务 await async_task_manager.stop_and_wait_all_tasks() diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 5e37d590..a23139ec 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -24,7 +24,7 @@ from src.chat.frequency_control.focus_value_control import focus_value_control from src.chat.express.expression_learner import expression_learner_manager from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.person_info import Person -from src.plugin_system.base.component_types import ChatMode, ActionInfo +from src.plugin_system.base.component_types import ChatMode, ActionInfo, EventType from src.plugin_system.core.event_manager import event_manager from src.plugin_system.apis import generator_api, send_api, message_api, database_api from src.mais4u.mai_think import mai_thinking_manager @@ -440,7 +440,7 @@ class HeartFChatting: ) # 触发 ON_PLAN 事件 - result = await event_manager.trigger_event("on_plan", prompt=prompt_info[0],stream_id=self.chat_stream.stream_id) + result = await event_manager.trigger_event(EventType.ON_PLAN, prompt=prompt_info[0],stream_id=self.chat_stream.stream_id) if result and not result.all_continue_process(): return diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 80dd7d26..1d1d74b1 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -258,7 +258,7 @@ class ChatBot: logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") return - result = await event_manager.trigger_event("on_message",message=message) + result = await event_manager.trigger_event(EventType.ON_MESSAGE,message=message) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers','')}于消息到达时取消了消息处理") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index eec41e32..e61db7a6 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -207,7 +207,7 @@ class DefaultReplyer: from src.plugin_system.core.event_manager import event_manager # 触发 POST_LLM 事件 if not from_plugin: - result = await event_manager.trigger_event("post_llm",prompt=prompt,llm_response=llm_response,stream_id=stream_id) + result = await event_manager.trigger_event(EventType.POST_LLM,prompt=prompt,llm_response=llm_response,stream_id=stream_id) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于请求前中断了内容生成") @@ -226,7 +226,7 @@ class DefaultReplyer: # 触发 AFTER_LLM 事件 if not from_plugin: - result = await event_manager.trigger_event("after_llm",prompt=prompt,llm_response=llm_response,stream_id=stream_id) + result = await event_manager.trigger_event(EventType.AFTER_LLM,prompt=prompt,llm_response=llm_response,stream_id=stream_id) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get("stopped_handlers","")}于请求后取消了内容生成") diff --git a/src/main.py b/src/main.py index 86e08672..0d28f377 100644 --- a/src/main.py +++ b/src/main.py @@ -16,6 +16,7 @@ from src.chat.knowledge import lpmm_start_up from rich.traceback import install from src.migrate_helper.migrate import check_and_run_migrations from src.plugin_system.core.event_manager import event_manager +from src.plugin_system.base.component_types import EventType # from src.api.main import start_api_server # 导入新的插件管理器 @@ -126,7 +127,7 @@ class MainSystem: # 触发 ON_START 事件 - await event_manager.trigger_event("on_start") + await event_manager.trigger_event(EventType.ON_START) # logger.info("已触发 ON_START 事件") try: init_time = int(1000 * (time.time() - init_start_time)) diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index a995c0cd..bf3e43cb 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -35,9 +35,6 @@ from .utils import ( # generate_plugin_manifest, ) -from .core.event_manager import ( - event_manager -) from .apis import ( chat_api, @@ -102,8 +99,6 @@ __all__ = [ # 工具函数 "ManifestValidator", "get_logger", - # 事件管理器 - "event_manager", # "ManifestGenerator", # "validate_plugin_manifest", # "generate_plugin_manifest", diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index db520b22..21ddd83d 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod -from typing import Tuple, Optional, Dict, List +from typing import Tuple, Optional, Dict, List, Union from src.common.logger import get_logger -from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType +from .component_types import EventType, EventHandlerInfo, ComponentType logger = get_logger("base_event_handler") @@ -21,7 +21,7 @@ class BaseEventHandler(ABC): """处理器权重,越大权重越高""" intercept_message: bool = False """是否拦截消息,默认为否""" - init_subscribe: List[str] = [] + init_subscribe: List[Union[EventType, str]] = [] """初始化时订阅的事件名称""" def __init__(self): diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 362b55fe..704703b2 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -181,6 +181,25 @@ class EventInfo(ComponentInfo): super().__post_init__() self.component_type = ComponentType.EVENT +# 事件类型枚举 +class EventType(Enum): + """ + 事件类型枚举类 + """ + + ON_START = "on_start" # 启动事件,用于调用按时任务 + ON_STOP = "on_stop" # 停止事件,用于调用按时任务 + ON_MESSAGE = "on_message" + ON_PLAN = "on_plan" + POST_LLM = "post_llm" + AFTER_LLM = "after_llm" + POST_SEND = "post_send" + AFTER_SEND = "after_send" + UNKNOWN = "unknown" # 未知事件类型 + + def __str__(self) -> str: + return self.value + @dataclass class PluginInfo: """插件信息""" diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index d51a69e5..1af2d797 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -3,13 +3,13 @@ 提供统一的事件注册、管理和触发接口 """ -from typing import Dict, Type, List, Optional, Any +from typing import Dict, Type, List, Optional, Any, Union from threading import Lock from src.common.logger import get_logger from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection from src.plugin_system.base.base_events_handler import BaseEventHandler - +from src.plugin_system.base.component_types import EventType logger = get_logger("event_manager") @@ -41,11 +41,11 @@ class EventManager: self._initialized = True logger.info("EventManager 单例初始化完成") - def register_event(self, event_name: str) -> bool: + def register_event(self, event_name: Union[EventType, str]) -> bool: """注册一个新的事件 Args: - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 Returns: bool: 注册成功返回True,已存在返回False @@ -63,11 +63,11 @@ class EventManager: return True - def get_event(self, event_name: str) -> Optional[BaseEvent]: + def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]: """获取指定事件实例 Args: - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 Returns: BaseEvent: 事件实例,不存在返回None @@ -98,11 +98,11 @@ class EventManager: """ return {name: event for name, event in self._events.items() if not event.enabled} - def enable_event(self, event_name: str) -> bool: + def enable_event(self, event_name: Union[EventType, str]) -> bool: """启用指定事件 Args: - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 Returns: bool: 成功返回True,事件不存在返回False @@ -116,11 +116,11 @@ class EventManager: logger.info(f"事件 {event_name} 已启用") return True - def disable_event(self, event_name: str) -> bool: + def disable_event(self, event_name: Union[EventType, str]) -> bool: """禁用指定事件 Args: - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 Returns: bool: 成功返回True,事件不存在返回False @@ -185,12 +185,12 @@ class EventManager: """ return self._event_handlers.copy() - def subscribe_handler_to_event(self, handler_name: str, event_name: str) -> bool: + def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool: """订阅事件处理器到指定事件 Args: handler_name (str): 处理器名称 - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 Returns: bool: 订阅成功返回True @@ -217,12 +217,12 @@ class EventManager: logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成") return True - def unsubscribe_handler_from_event(self, handler_name: str, event_name: str) -> bool: + def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool: """从指定事件取消订阅事件处理器 Args: handler_name (str): 处理器名称 - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 Returns: bool: 取消订阅成功返回True @@ -247,11 +247,11 @@ class EventManager: return removed - def get_event_subscribers(self, event_name: str) -> Dict[str, BaseEventHandler]: + def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]: """获取订阅指定事件的所有事件处理器 Args: - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 Returns: Dict[str, BaseEventHandler]: 处理器字典,键为处理器名称,值为处理器实例 @@ -262,11 +262,11 @@ class EventManager: return {handler.handler_name: handler for handler in event.subscribers} - async def trigger_event(self, event_name: str, **kwargs) -> Optional[HandlerResultsCollection]: + async def trigger_event(self, event_name: Union[EventType, str], **kwargs) -> Optional[HandlerResultsCollection]: """触发指定事件 Args: - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 **kwargs: 传递给处理器的参数 Returns: @@ -284,14 +284,15 @@ class EventManager: def init_default_events(self) -> None: """初始化默认事件""" default_events = [ - "on_start", - "on_stop", - "on_plan", - "on_message", - "post_llm", - "after_llm", - "post_send", - "after_send" + EventType.ON_START, + EventType.ON_STOP, + EventType.ON_PLAN, + EventType.ON_MESSAGE, + EventType.POST_LLM, + EventType.AFTER_LLM, + EventType.POST_SEND, + EventType.AFTER_SEND, + EventType.UNKNOWN ] for event_name in default_events: @@ -324,11 +325,11 @@ class EventManager: "pending_subscriptions": len(self._pending_subscriptions) } - def _process_pending_subscriptions(self, event_name: str) -> None: + def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None: """处理指定事件的缓存订阅 Args: - event_name (str): 事件名称 + event_name Union[EventType, str]: 事件名称 """ handlers_to_remove = []