mirror of https://github.com/Mai-with-u/MaiBot.git
EventType回归
parent
f54022c17a
commit
316a2b6567
3
bot.py
3
bot.py
|
|
@ -8,6 +8,7 @@ import traceback
|
|||
from dotenv import load_dotenv
|
||||
from pathlib import Path
|
||||
from rich.traceback import install
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
if os.path.exists(".env"):
|
||||
load_dotenv(".env", override=True)
|
||||
|
|
@ -65,7 +66,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
|||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
# 触发 ON_STOP 事件
|
||||
await event_manager.trigger_event("on_stop")
|
||||
await event_manager.trigger_event(EventType.ON_STOP)
|
||||
|
||||
# 停止所有异步任务
|
||||
await async_task_manager.stop_and_wait_all_tasks()
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from src.chat.frequency_control.focus_value_control import focus_value_control
|
|||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import ChatMode, ActionInfo
|
||||
from src.plugin_system.base.component_types import ChatMode, ActionInfo, EventType
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
|
|
@ -440,7 +440,7 @@ class HeartFChatting:
|
|||
)
|
||||
|
||||
# 触发 ON_PLAN 事件
|
||||
result = await event_manager.trigger_event("on_plan", prompt=prompt_info[0],stream_id=self.chat_stream.stream_id)
|
||||
result = await event_manager.trigger_event(EventType.ON_PLAN, prompt=prompt_info[0],stream_id=self.chat_stream.stream_id)
|
||||
if result and not result.all_continue_process():
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -258,7 +258,7 @@ class ChatBot:
|
|||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
result = await event_manager.trigger_event("on_message",message=message)
|
||||
result = await event_manager.trigger_event(EventType.ON_MESSAGE,message=message)
|
||||
if not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers','')}于消息到达时取消了消息处理")
|
||||
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class DefaultReplyer:
|
|||
from src.plugin_system.core.event_manager import event_manager
|
||||
# 触发 POST_LLM 事件
|
||||
if not from_plugin:
|
||||
result = await event_manager.trigger_event("post_llm",prompt=prompt,llm_response=llm_response,stream_id=stream_id)
|
||||
result = await event_manager.trigger_event(EventType.POST_LLM,prompt=prompt,llm_response=llm_response,stream_id=stream_id)
|
||||
if not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于请求前中断了内容生成")
|
||||
|
||||
|
|
@ -226,7 +226,7 @@ class DefaultReplyer:
|
|||
|
||||
# 触发 AFTER_LLM 事件
|
||||
if not from_plugin:
|
||||
result = await event_manager.trigger_event("after_llm",prompt=prompt,llm_response=llm_response,stream_id=stream_id)
|
||||
result = await event_manager.trigger_event(EventType.AFTER_LLM,prompt=prompt,llm_response=llm_response,stream_id=stream_id)
|
||||
if not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get("stopped_handlers","")}于请求后取消了内容生成")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from src.chat.knowledge import lpmm_start_up
|
|||
from rich.traceback import install
|
||||
from src.migrate_helper.migrate import check_and_run_migrations
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器
|
||||
|
|
@ -126,7 +127,7 @@ class MainSystem:
|
|||
|
||||
|
||||
# 触发 ON_START 事件
|
||||
await event_manager.trigger_event("on_start")
|
||||
await event_manager.trigger_event(EventType.ON_START)
|
||||
# logger.info("已触发 ON_START 事件")
|
||||
try:
|
||||
init_time = int(1000 * (time.time() - init_start_time))
|
||||
|
|
|
|||
|
|
@ -35,9 +35,6 @@ from .utils import (
|
|||
# generate_plugin_manifest,
|
||||
)
|
||||
|
||||
from .core.event_manager import (
|
||||
event_manager
|
||||
)
|
||||
|
||||
from .apis import (
|
||||
chat_api,
|
||||
|
|
@ -102,8 +99,6 @@ __all__ = [
|
|||
# 工具函数
|
||||
"ManifestValidator",
|
||||
"get_logger",
|
||||
# 事件管理器
|
||||
"event_manager",
|
||||
# "ManifestGenerator",
|
||||
# "validate_plugin_manifest",
|
||||
# "generate_plugin_manifest",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, Dict, List
|
||||
from typing import Tuple, Optional, Dict, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
||||
from .component_types import EventType, EventHandlerInfo, ComponentType
|
||||
|
||||
logger = get_logger("base_event_handler")
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ class BaseEventHandler(ABC):
|
|||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
init_subscribe: List[str] = []
|
||||
init_subscribe: List[Union[EventType, str]] = []
|
||||
"""初始化时订阅的事件名称"""
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -181,6 +181,25 @@ class EventInfo(ComponentInfo):
|
|||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
事件类型枚举类
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
AFTER_LLM = "after_llm"
|
||||
POST_SEND = "post_send"
|
||||
AFTER_SEND = "after_send"
|
||||
UNKNOWN = "unknown" # 未知事件类型
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@
|
|||
提供统一的事件注册、管理和触发接口
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, List, Optional, Any
|
||||
from typing import Dict, Type, List, Optional, Any, Union
|
||||
from threading import Lock
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
logger = get_logger("event_manager")
|
||||
|
||||
|
||||
|
|
@ -41,11 +41,11 @@ class EventManager:
|
|||
self._initialized = True
|
||||
logger.info("EventManager 单例初始化完成")
|
||||
|
||||
def register_event(self, event_name: str) -> bool:
|
||||
def register_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
"""注册一个新的事件
|
||||
|
||||
Args:
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
Returns:
|
||||
bool: 注册成功返回True,已存在返回False
|
||||
|
|
@ -63,11 +63,11 @@ class EventManager:
|
|||
|
||||
return True
|
||||
|
||||
def get_event(self, event_name: str) -> Optional[BaseEvent]:
|
||||
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
|
||||
"""获取指定事件实例
|
||||
|
||||
Args:
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
Returns:
|
||||
BaseEvent: 事件实例,不存在返回None
|
||||
|
|
@ -98,11 +98,11 @@ class EventManager:
|
|||
"""
|
||||
return {name: event for name, event in self._events.items() if not event.enabled}
|
||||
|
||||
def enable_event(self, event_name: str) -> bool:
|
||||
def enable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
"""启用指定事件
|
||||
|
||||
Args:
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
Returns:
|
||||
bool: 成功返回True,事件不存在返回False
|
||||
|
|
@ -116,11 +116,11 @@ class EventManager:
|
|||
logger.info(f"事件 {event_name} 已启用")
|
||||
return True
|
||||
|
||||
def disable_event(self, event_name: str) -> bool:
|
||||
def disable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
"""禁用指定事件
|
||||
|
||||
Args:
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
Returns:
|
||||
bool: 成功返回True,事件不存在返回False
|
||||
|
|
@ -185,12 +185,12 @@ class EventManager:
|
|||
"""
|
||||
return self._event_handlers.copy()
|
||||
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: str) -> bool:
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
"""订阅事件处理器到指定事件
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
Returns:
|
||||
bool: 订阅成功返回True
|
||||
|
|
@ -217,12 +217,12 @@ class EventManager:
|
|||
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
|
||||
return True
|
||||
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: str) -> bool:
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
"""从指定事件取消订阅事件处理器
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
Returns:
|
||||
bool: 取消订阅成功返回True
|
||||
|
|
@ -247,11 +247,11 @@ class EventManager:
|
|||
|
||||
return removed
|
||||
|
||||
def get_event_subscribers(self, event_name: str) -> Dict[str, BaseEventHandler]:
|
||||
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
|
||||
"""获取订阅指定事件的所有事件处理器
|
||||
|
||||
Args:
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEventHandler]: 处理器字典,键为处理器名称,值为处理器实例
|
||||
|
|
@ -262,11 +262,11 @@ class EventManager:
|
|||
|
||||
return {handler.handler_name: handler for handler in event.subscribers}
|
||||
|
||||
async def trigger_event(self, event_name: str, **kwargs) -> Optional[HandlerResultsCollection]:
|
||||
async def trigger_event(self, event_name: Union[EventType, str], **kwargs) -> Optional[HandlerResultsCollection]:
|
||||
"""触发指定事件
|
||||
|
||||
Args:
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
**kwargs: 传递给处理器的参数
|
||||
|
||||
Returns:
|
||||
|
|
@ -284,14 +284,15 @@ class EventManager:
|
|||
def init_default_events(self) -> None:
|
||||
"""初始化默认事件"""
|
||||
default_events = [
|
||||
"on_start",
|
||||
"on_stop",
|
||||
"on_plan",
|
||||
"on_message",
|
||||
"post_llm",
|
||||
"after_llm",
|
||||
"post_send",
|
||||
"after_send"
|
||||
EventType.ON_START,
|
||||
EventType.ON_STOP,
|
||||
EventType.ON_PLAN,
|
||||
EventType.ON_MESSAGE,
|
||||
EventType.POST_LLM,
|
||||
EventType.AFTER_LLM,
|
||||
EventType.POST_SEND,
|
||||
EventType.AFTER_SEND,
|
||||
EventType.UNKNOWN
|
||||
]
|
||||
|
||||
for event_name in default_events:
|
||||
|
|
@ -324,11 +325,11 @@ class EventManager:
|
|||
"pending_subscriptions": len(self._pending_subscriptions)
|
||||
}
|
||||
|
||||
def _process_pending_subscriptions(self, event_name: str) -> None:
|
||||
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
|
||||
"""处理指定事件的缓存订阅
|
||||
|
||||
Args:
|
||||
event_name (str): 事件名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
"""
|
||||
handlers_to_remove = []
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue