mirror of https://github.com/Mai-with-u/MaiBot.git
545 lines
20 KiB
Python
545 lines
20 KiB
Python
import time
|
||
import asyncio
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Tuple, Optional, TYPE_CHECKING, Dict, List
|
||
|
||
from src.common.logger import get_logger
|
||
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
|
||
from src.chat.message_receive.chat_stream import ChatStream
|
||
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
|
||
from src.plugin_system.apis import send_api, database_api, message_api
|
||
|
||
if TYPE_CHECKING:
|
||
from src.common.data_models.database_data_model import DatabaseMessages
|
||
|
||
logger = get_logger("base_action")
|
||
|
||
|
||
class BaseAction(ABC):
|
||
"""Action组件基类
|
||
|
||
Action是插件的一种组件类型,用于处理聊天中的动作逻辑
|
||
|
||
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
|
||
- focus_activation_type: 专注模式激活类型
|
||
- normal_activation_type: 普通模式激活类型
|
||
- activation_keywords: 激活关键词列表
|
||
- keyword_case_sensitive: 关键词是否区分大小写
|
||
- parallel_action: 是否允许并行执行
|
||
- random_activation_probability: 随机激活概率
|
||
- llm_judge_prompt: LLM判断提示词
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
action_data: dict,
|
||
action_reasoning: str,
|
||
cycle_timers: dict,
|
||
thinking_id: str,
|
||
chat_stream: ChatStream,
|
||
plugin_config: Optional[dict] = None,
|
||
action_message: Optional["DatabaseMessages"] = None,
|
||
**kwargs,
|
||
):
|
||
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
|
||
"""初始化Action组件
|
||
|
||
Args:
|
||
action_data: 动作数据
|
||
reasoning: 执行该动作的理由
|
||
cycle_timers: 计时器字典
|
||
thinking_id: 思考ID
|
||
chat_stream: 聊天流对象
|
||
log_prefix: 日志前缀
|
||
plugin_config: 插件配置字典
|
||
action_message: 消息数据
|
||
**kwargs: 其他参数
|
||
"""
|
||
if plugin_config is None:
|
||
plugin_config = {}
|
||
self.action_data = action_data
|
||
self.reasoning = ""
|
||
self.cycle_timers = cycle_timers
|
||
self.thinking_id = thinking_id
|
||
|
||
self.action_reasoning = action_reasoning
|
||
|
||
self.plugin_config = plugin_config or {}
|
||
"""对应的插件配置"""
|
||
|
||
# 设置动作基本信息实例属性
|
||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||
"""Action的名字"""
|
||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||
"""Action的描述"""
|
||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||
|
||
"""NORMAL模式下的激活类型"""
|
||
self.activation_type = getattr(self.__class__, "activation_type", self.focus_activation_type)
|
||
"""激活类型"""
|
||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||
"""当激活类型为RANDOM时的概率"""
|
||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "") # 已弃用
|
||
"""协助LLM进行判断的Prompt"""
|
||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||
|
||
# =============================================================================
|
||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||
# =============================================================================
|
||
|
||
# 获取聊天流对象
|
||
self.chat_stream = chat_stream or kwargs.get("chat_stream")
|
||
self.chat_id = self.chat_stream.stream_id
|
||
self.platform = getattr(self.chat_stream, "platform", None)
|
||
|
||
# 初始化基础信息(带类型注解)
|
||
self.action_message = action_message
|
||
|
||
self.group_id = None
|
||
self.group_name = None
|
||
self.user_id = None
|
||
self.user_nickname = None
|
||
self.is_group = False
|
||
self.target_id = None
|
||
|
||
|
||
self.group_id = (
|
||
str(self.action_message.chat_info.group_info.group_id)
|
||
if self.action_message.chat_info.group_info
|
||
else None
|
||
)
|
||
self.group_name = (
|
||
self.action_message.chat_info.group_info.group_name
|
||
if self.action_message.chat_info.group_info
|
||
else None
|
||
)
|
||
|
||
self.user_id = str(self.action_message.user_info.user_id)
|
||
self.user_nickname = self.action_message.user_info.user_nickname
|
||
|
||
if self.group_id:
|
||
self.is_group = True
|
||
self.target_id = self.group_id
|
||
self.log_prefix = f"[{self.group_name}]"
|
||
else:
|
||
self.is_group = False
|
||
self.target_id = self.user_id
|
||
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
|
||
|
||
|
||
logger.debug(
|
||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||
)
|
||
|
||
@abstractmethod
|
||
async def execute(self) -> Tuple[bool, str]:
|
||
"""执行Action的抽象方法,子类必须实现
|
||
|
||
Returns:
|
||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||
"""
|
||
pass
|
||
|
||
async def send_text(
|
||
self,
|
||
content: str,
|
||
set_reply: bool = False,
|
||
reply_message: Optional["DatabaseMessages"] = None,
|
||
typing: bool = False,
|
||
storage_message: bool = True,
|
||
) -> bool:
|
||
"""发送文本消息
|
||
|
||
Args:
|
||
content: 文本内容
|
||
set_reply: 是否作为回复发送
|
||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
||
typing: 是否计算输入时间
|
||
|
||
Returns:
|
||
bool: 是否发送成功
|
||
"""
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||
return False
|
||
|
||
return await send_api.text_to_stream(
|
||
text=content,
|
||
stream_id=self.chat_id,
|
||
set_reply=set_reply,
|
||
reply_message=reply_message,
|
||
typing=typing,
|
||
storage_message=storage_message,
|
||
)
|
||
|
||
async def send_emoji(
|
||
self,
|
||
emoji_base64: str,
|
||
set_reply: bool = False,
|
||
reply_message: Optional["DatabaseMessages"] = None,
|
||
storage_message: bool = True,
|
||
) -> bool:
|
||
"""发送表情包
|
||
|
||
Args:
|
||
emoji_base64: 表情包的base64编码
|
||
set_reply: 是否作为回复发送
|
||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
||
|
||
Returns:
|
||
bool: 是否发送成功
|
||
"""
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||
return False
|
||
|
||
return await send_api.emoji_to_stream(
|
||
emoji_base64,
|
||
self.chat_id,
|
||
set_reply=set_reply,
|
||
reply_message=reply_message,
|
||
storage_message=storage_message,
|
||
)
|
||
|
||
async def send_image(
|
||
self,
|
||
image_base64: str,
|
||
set_reply: bool = False,
|
||
reply_message: Optional["DatabaseMessages"] = None,
|
||
storage_message: bool = True,
|
||
) -> bool:
|
||
"""发送图片
|
||
|
||
Args:
|
||
image_base64: 图片的base64编码
|
||
set_reply: 是否作为回复发送
|
||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
||
|
||
Returns:
|
||
bool: 是否发送成功
|
||
"""
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||
return False
|
||
|
||
return await send_api.image_to_stream(
|
||
image_base64,
|
||
self.chat_id,
|
||
set_reply=set_reply,
|
||
reply_message=reply_message,
|
||
storage_message=storage_message,
|
||
)
|
||
|
||
async def send_command(
|
||
self,
|
||
command_name: str,
|
||
args: Optional[dict] = None,
|
||
display_message: str = "",
|
||
storage_message: bool = True,
|
||
) -> bool:
|
||
"""发送命令消息
|
||
|
||
Args:
|
||
command_name: 命令名称
|
||
args: 命令参数
|
||
display_message: 显示消息
|
||
storage_message: 是否存储消息到数据库
|
||
|
||
Returns:
|
||
bool: 是否发送成功
|
||
"""
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||
return False
|
||
|
||
# 构造命令数据
|
||
command_data = {"name": command_name, "args": args or {}}
|
||
|
||
return await send_api.command_to_stream(
|
||
command=command_data,
|
||
stream_id=self.chat_id,
|
||
storage_message=storage_message,
|
||
display_message=display_message,
|
||
)
|
||
|
||
async def send_custom(
|
||
self,
|
||
message_type: str,
|
||
content: str | Dict,
|
||
typing: bool = False,
|
||
set_reply: bool = False,
|
||
reply_message: Optional["DatabaseMessages"] = None,
|
||
storage_message: bool = True,
|
||
) -> bool:
|
||
"""发送自定义类型消息
|
||
|
||
Args:
|
||
message_type: 消息类型,如"video"、"file"、"audio"等
|
||
content: 消息内容
|
||
typing: 是否显示正在输入
|
||
set_reply: 是否作为回复发送
|
||
reply_message: 回复的消息对象(set_reply 为 True时必填)
|
||
storage_message: 是否存储消息到数据库
|
||
|
||
Returns:
|
||
bool: 是否发送成功
|
||
"""
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||
return False
|
||
|
||
return await send_api.custom_to_stream(
|
||
message_type=message_type,
|
||
content=content,
|
||
stream_id=self.chat_id,
|
||
typing=typing,
|
||
set_reply=set_reply,
|
||
reply_message=reply_message,
|
||
storage_message=storage_message,
|
||
)
|
||
|
||
async def send_hybrid(
|
||
self,
|
||
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
|
||
typing: bool = False,
|
||
set_reply: bool = False,
|
||
reply_message: Optional["DatabaseMessages"] = None,
|
||
storage_message: bool = True,
|
||
) -> bool:
|
||
"""
|
||
发送混合类型消息
|
||
|
||
Args:
|
||
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
|
||
typing: 是否计算打字时间
|
||
set_reply: 是否作为回复发送
|
||
reply_message: 回复的消息对象
|
||
"""
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||
return False
|
||
reply_set = ReplySetModel()
|
||
reply_set.add_hybrid_content_by_raw(message_tuple_list)
|
||
return await send_api.custom_reply_set_to_stream(
|
||
reply_set=reply_set,
|
||
stream_id=self.chat_id,
|
||
typing=typing,
|
||
set_reply=set_reply,
|
||
reply_message=reply_message,
|
||
storage_message=storage_message,
|
||
)
|
||
|
||
async def send_forward(
|
||
self,
|
||
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
|
||
storage_message: bool = True,
|
||
) -> bool:
|
||
"""转发消息
|
||
|
||
Args:
|
||
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体);当传入消息ID时,元素格式为 "message_id"
|
||
其中消息体的格式为 [(内容类型, 内容), ...]
|
||
任意长度的消息都需要使用列表的形式传入
|
||
storage_message: 是否存储消息到数据库
|
||
|
||
Returns:
|
||
bool: 是否发送成功
|
||
"""
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||
return False
|
||
reply_set = ReplySetModel()
|
||
forward_message_nodes: List[ForwardNode] = []
|
||
for message in messages_list:
|
||
if isinstance(message, str):
|
||
forward_message_node = ForwardNode.construct_as_id_reference(message)
|
||
elif isinstance(message, Tuple) and len(message) == 3:
|
||
sender_id, nickname, content_list = message
|
||
single_node_content_list: List[ReplyContent] = []
|
||
for node_content_type, node_content in content_list:
|
||
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
|
||
single_node_content_list.append(reply_node_content)
|
||
forward_message_node = ForwardNode.construct_as_created_node(
|
||
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
|
||
)
|
||
else:
|
||
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
|
||
continue
|
||
forward_message_nodes.append(forward_message_node)
|
||
reply_set.add_forward_content(forward_message_nodes)
|
||
return await send_api.custom_reply_set_to_stream(
|
||
reply_set=reply_set,
|
||
stream_id=self.chat_id,
|
||
storage_message=storage_message,
|
||
set_reply=False,
|
||
reply_message=None,
|
||
)
|
||
|
||
async def send_voice(self, audio_base64: str) -> bool:
|
||
"""
|
||
发送语音消息
|
||
Args:
|
||
audio_base64: 语音的base64编码
|
||
Returns:
|
||
bool: 是否发送成功
|
||
"""
|
||
if not audio_base64:
|
||
logger.error(f"{self.log_prefix} 缺少音频内容")
|
||
return False
|
||
reply_set = ReplySetModel()
|
||
reply_set.add_voice_content(audio_base64)
|
||
return await send_api.custom_reply_set_to_stream(
|
||
reply_set=reply_set,
|
||
stream_id=self.chat_id,
|
||
storage_message=False,
|
||
)
|
||
|
||
async def store_action_info(
|
||
self,
|
||
action_build_into_prompt: bool = False,
|
||
action_prompt_display: str = "",
|
||
action_done: bool = True,
|
||
) -> None:
|
||
"""存储动作信息到数据库
|
||
|
||
Args:
|
||
action_build_into_prompt: 是否构建到提示中
|
||
action_prompt_display: 显示的action提示信息
|
||
action_done: action是否完成
|
||
"""
|
||
await database_api.store_action_info(
|
||
chat_stream=self.chat_stream,
|
||
action_build_into_prompt=action_build_into_prompt,
|
||
action_prompt_display=action_prompt_display,
|
||
action_done=action_done,
|
||
thinking_id=self.thinking_id,
|
||
action_data=self.action_data,
|
||
action_name=self.action_name,
|
||
action_reasoning=self.action_reasoning,
|
||
)
|
||
|
||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||
"""等待新消息或超时
|
||
|
||
在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。
|
||
使用message_api检查self.chat_id对应的聊天中是否有新消息。
|
||
|
||
Args:
|
||
timeout: 超时时间(秒),默认1200秒
|
||
|
||
Returns:
|
||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
||
"""
|
||
try:
|
||
# 获取循环开始时间,如果没有则使用当前时间
|
||
loop_start_time = self.action_data.get("loop_start_time", time.time())
|
||
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
|
||
|
||
# 确保有有效的chat_id
|
||
if not self.chat_id:
|
||
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
|
||
return False, "没有有效的chat_id"
|
||
|
||
wait_start_time = asyncio.get_event_loop().time()
|
||
while True:
|
||
|
||
# 检查新消息
|
||
current_time = time.time()
|
||
new_message_count = message_api.count_new_messages(
|
||
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
|
||
)
|
||
|
||
if new_message_count > 0:
|
||
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息,聊天ID: {self.chat_id}")
|
||
return True, ""
|
||
|
||
# 检查超时
|
||
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
|
||
if elapsed_time > timeout:
|
||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒),聊天ID: {self.chat_id}")
|
||
return False, ""
|
||
|
||
# 每30秒记录一次等待状态
|
||
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
|
||
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
|
||
|
||
# 短暂休眠
|
||
await asyncio.sleep(0.5)
|
||
|
||
except asyncio.CancelledError:
|
||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
||
return False, ""
|
||
except Exception as e:
|
||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||
return False, f"等待新消息失败: {str(e)}"
|
||
|
||
@classmethod
|
||
def get_action_info(cls) -> "ActionInfo":
|
||
"""从类属性生成ActionInfo
|
||
|
||
所有信息都从类属性中读取,确保一致性和完整性。
|
||
Action类必须定义所有必要的类属性。
|
||
|
||
Returns:
|
||
ActionInfo: 生成的Action信息对象
|
||
"""
|
||
|
||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
|
||
if "." in name:
|
||
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||
# 获取focus_activation_type和normal_activation_type
|
||
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
||
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
||
|
||
# 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type
|
||
activation_type = getattr(cls, "activation_type", focus_activation_type)
|
||
|
||
return ActionInfo(
|
||
name=name,
|
||
component_type=ComponentType.ACTION,
|
||
description=getattr(cls, "action_description", "Action动作"),
|
||
activation_type=activation_type,
|
||
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
|
||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
||
parallel_action=getattr(cls, "parallel_action", True),
|
||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
||
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
|
||
# 使用正确的字段名
|
||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||
action_require=getattr(cls, "action_require", []).copy(),
|
||
associated_types=getattr(cls, "associated_types", []).copy(),
|
||
)
|
||
|
||
def get_config(self, key: str, default=None):
|
||
"""获取插件配置值,使用嵌套键访问
|
||
|
||
Args:
|
||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||
default: 默认值
|
||
|
||
Returns:
|
||
Any: 配置值或默认值
|
||
"""
|
||
if not self.plugin_config:
|
||
return default
|
||
|
||
# 支持嵌套键访问
|
||
keys = key.split(".")
|
||
current = self.plugin_config
|
||
|
||
for k in keys:
|
||
if isinstance(current, dict) and k in current:
|
||
current = current[k]
|
||
else:
|
||
return default
|
||
|
||
return current
|