mirror of https://github.com/Mai-with-u/MaiBot.git
更改改改events系统
parent
d2f98145da
commit
80783439b1
|
|
@ -0,0 +1,38 @@
|
||||||
|
from typing import TYPE_CHECKING, List, Type
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system.base.component_types import EventType, MaiMessages
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .base_events_handler import BaseEventHandler
|
||||||
|
|
||||||
|
logger = get_logger("base_event")
|
||||||
|
|
||||||
|
class BaseEvent:
|
||||||
|
def __init__(self, event_type: EventType | str) -> None:
|
||||||
|
self.event_type = event_type
|
||||||
|
self.subscribers: List["BaseEventHandler"] = []
|
||||||
|
|
||||||
|
def register_handler_to_event(self, handler: "BaseEventHandler") -> bool:
|
||||||
|
if handler not in self.subscribers:
|
||||||
|
self.subscribers.append(handler)
|
||||||
|
return True
|
||||||
|
logger.warning(f"Handler {handler.handler_name} 已经注册,不可多次注册")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def remove_handler_from_event(self, handler_class: Type["BaseEventHandler"]) -> bool:
|
||||||
|
for handler in self.subscribers:
|
||||||
|
if isinstance(handler, handler_class):
|
||||||
|
self.subscribers.remove(handler)
|
||||||
|
return True
|
||||||
|
logger.warning(f"Handler {handler_class.__name__} 未注册,无法移除")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def trigger_event(self, message: MaiMessages):
|
||||||
|
copied_message = message.deepcopy()
|
||||||
|
for handler in self.subscribers:
|
||||||
|
result = handler.execute(copied_message)
|
||||||
|
|
||||||
|
# TODO: Unfinished Events Handler
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -13,7 +13,7 @@ class BaseEventHandler(ABC):
|
||||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
event_type: EventType = EventType.UNKNOWN
|
event_type: EventType | str = EventType.UNKNOWN
|
||||||
"""事件类型,默认为未知"""
|
"""事件类型,默认为未知"""
|
||||||
handler_name: str = ""
|
handler_name: str = ""
|
||||||
"""处理器名称"""
|
"""处理器名称"""
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import copy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Any, List, Optional, Tuple
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
@ -165,7 +166,7 @@ class ToolInfo(ComponentInfo):
|
||||||
class EventHandlerInfo(ComponentInfo):
|
class EventHandlerInfo(ComponentInfo):
|
||||||
"""事件处理器组件信息"""
|
"""事件处理器组件信息"""
|
||||||
|
|
||||||
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
|
event_type: EventType | str = EventType.ON_MESSAGE # 监听事件类型
|
||||||
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
|
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
|
||||||
weight: int = 0 # 事件处理器权重,决定执行顺序
|
weight: int = 0 # 事件处理器权重,决定执行顺序
|
||||||
|
|
||||||
|
|
@ -281,3 +282,6 @@ class MaiMessages:
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.message_segments is None:
|
if self.message_segments is None:
|
||||||
self.message_segments = []
|
self.message_segments = []
|
||||||
|
|
||||||
|
def deepcopy(self):
|
||||||
|
return copy.deepcopy(self)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
from typing import List, Dict, Optional, Type, Tuple, Any, TYPE_CHECKING
|
from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
@ -18,7 +18,7 @@ logger = get_logger("events_manager")
|
||||||
class EventsManager:
|
class EventsManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 有权重的 events 订阅者注册表
|
# 有权重的 events 订阅者注册表
|
||||||
self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
self._events_subscribers: Dict[EventType | str, List[BaseEventHandler]] = {event: [] for event in EventType}
|
||||||
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||||
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
||||||
|
|
||||||
|
|
@ -152,7 +152,8 @@ class EventsManager:
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
if handler_class.event_type == EventType.UNKNOWN:
|
||||||
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
||||||
return False
|
return False
|
||||||
|
if handler_class.event_type not in self._events_subscribers:
|
||||||
|
self._events_subscribers[handler_class.event_type] = []
|
||||||
handler_instance = handler_class()
|
handler_instance = handler_class()
|
||||||
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
|
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
|
||||||
self._events_subscribers[handler_class.event_type].append(handler_instance)
|
self._events_subscribers[handler_class.event_type].append(handler_instance)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
- [x] 自定义事件
|
||||||
|
- [ ] <del>允许handler随时订阅</del>
|
||||||
|
- [ ] 允许handler随时取消订阅
|
||||||
|
- [ ] 允许其他组件给handler增加订阅
|
||||||
|
- [ ] 允许其他组件给handler取消订阅
|
||||||
|
- [ ] <del>允许一个handler订阅多个事件</del>
|
||||||
|
- [ ] event激活时给handler传递参数
|
||||||
|
- [ ] handler能拿到所有handlers的结果(按照处理权重)
|
||||||
|
- [x] 随时注册
|
||||||
|
- [ ] 删除event
|
||||||
|
- [ ] 必要性?
|
||||||
|
- [ ] 能够更改prompt
|
||||||
|
- [ ] 能够更改llm_response
|
||||||
|
- [ ] 能够更改message
|
||||||
Loading…
Reference in New Issue