MaiBot/src/plugin_system/base/base_action.py

545 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import time
import asyncio
from abc import ABC, abstractmethod
from typing import Tuple, Optional, TYPE_CHECKING, Dict, List
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_action")
class BaseAction(ABC):
"""Action组件基类
Action是插件的一种组件类型用于处理聊天中的动作逻辑
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
- focus_activation_type: 专注模式激活类型
- normal_activation_type: 普通模式激活类型
- activation_keywords: 激活关键词列表
- keyword_case_sensitive: 关键词是否区分大小写
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
- llm_judge_prompt: LLM判断提示词
"""
def __init__(
self,
action_data: dict,
action_reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream,
plugin_config: Optional[dict] = None,
action_message: Optional["DatabaseMessages"] = None,
**kwargs,
):
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
"""初始化Action组件
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
chat_stream: 聊天流对象
log_prefix: 日志前缀
plugin_config: 插件配置字典
action_message: 消息数据
**kwargs: 其他参数
"""
if plugin_config is None:
plugin_config = {}
self.action_data = action_data
self.reasoning = ""
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.action_reasoning = action_reasoning
self.plugin_config = plugin_config or {}
"""对应的插件配置"""
# 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
"""Action的名字"""
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
"""Action的描述"""
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
"""NORMAL模式下的激活类型"""
self.activation_type = getattr(self.__class__, "activation_type", self.focus_activation_type)
"""激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "") # 已弃用
"""协助LLM进行判断的Prompt"""
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
"""激活类型为KEYWORD时的KEYWORDS列表"""
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id
self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解)
self.action_message = action_message
self.group_id = None
self.group_name = None
self.user_id = None
self.user_nickname = None
self.is_group = False
self.target_id = None
self.group_id = (
str(self.action_message.chat_info.group_info.group_id)
if self.action_message.chat_info.group_info
else None
)
self.group_name = (
self.action_message.chat_info.group_info.group_name
if self.action_message.chat_info.group_info
else None
)
self.user_id = str(self.action_message.user_info.user_id)
self.user_nickname = self.action_message.user_info.user_nickname
if self.group_id:
self.is_group = True
self.target_id = self.group_id
self.log_prefix = f"[{self.group_name}]"
else:
self.is_group = False
self.target_id = self.user_id
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
)
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
storage_message: bool = True,
) -> bool:
"""发送文本消息
Args:
content: 文本内容
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
typing: 是否计算输入时间
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.text_to_stream(
text=content,
stream_id=self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
typing=typing,
storage_message=storage_message,
)
async def send_emoji(
self,
emoji_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(
emoji_base64,
self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(
image_base64,
self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_command(
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
) -> bool:
"""发送命令消息
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
return await send_api.command_to_stream(
command=command_data,
stream_id=self.chat_id,
storage_message=storage_message,
display_message=display_message,
)
async def send_custom(
self,
message_type: str,
content: str | Dict,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送自定义类型消息
Args:
message_type: 消息类型,如"video""file""audio"
content: 消息内容
typing: 是否显示正在输入
set_reply: 是否作为回复发送
reply_message: 回复的消息对象set_reply 为 True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=self.chat_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_hybrid(
self,
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""
发送混合类型消息
Args:
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
typing: 是否计算打字时间
set_reply: 是否作为回复发送
reply_message: 回复的消息对象
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_forward(
self,
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
storage_message: bool = True,
) -> bool:
"""转发消息
Args:
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体)当传入消息ID时元素格式为 "message_id"
其中消息体的格式为 [(内容类型, 内容), ...]
任意长度的消息都需要使用列表的形式传入
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
for message in messages_list:
if isinstance(message, str):
forward_message_node = ForwardNode.construct_as_id_reference(message)
elif isinstance(message, Tuple) and len(message) == 3:
sender_id, nickname, content_list = message
single_node_content_list: List[ReplyContent] = []
for node_content_type, node_content in content_list:
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
single_node_content_list.append(reply_node_content)
forward_message_node = ForwardNode.construct_as_created_node(
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
)
else:
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
continue
forward_message_nodes.append(forward_message_node)
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
)
async def send_voice(self, audio_base64: str) -> bool:
"""
发送语音消息
Args:
audio_base64: 语音的base64编码
Returns:
bool: 是否发送成功
"""
if not audio_base64:
logger.error(f"{self.log_prefix} 缺少音频内容")
return False
reply_set = ReplySetModel()
reply_set.add_voice_content(audio_base64)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
storage_message=False,
)
async def store_action_info(
self,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
) -> None:
"""存储动作信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 显示的action提示信息
action_done: action是否完成
"""
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
action_done=action_done,
thinking_id=self.thinking_id,
action_data=self.action_data,
action_name=self.action_name,
action_reasoning=self.action_reasoning,
)
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
"""等待新消息或超时
在loop_start_time之后等待新消息如果没有新消息且没有超时就一直等待。
使用message_api检查self.chat_id对应的聊天中是否有新消息。
Args:
timeout: 超时时间默认1200秒
Returns:
Tuple[bool, str]: (是否收到新消息, 空字符串)
"""
try:
# 获取循环开始时间,如果没有则使用当前时间
loop_start_time = self.action_data.get("loop_start_time", time.time())
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
# 确保有有效的chat_id
if not self.chat_id:
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
return False, "没有有效的chat_id"
wait_start_time = asyncio.get_event_loop().time()
while True:
# 检查新消息
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
)
if new_message_count > 0:
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息聊天ID: {self.chat_id}")
return True, ""
# 检查超时
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
if elapsed_time > timeout:
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)聊天ID: {self.chat_id}")
return False, ""
# 每30秒记录一次等待状态
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
# 短暂休眠
await asyncio.sleep(0.5)
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
return False, ""
except Exception as e:
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
@classmethod
def get_action_info(cls) -> "ActionInfo":
"""从类属性生成ActionInfo
所有信息都从类属性中读取,确保一致性和完整性。
Action类必须定义所有必要的类属性。
Returns:
ActionInfo: 生成的Action信息对象
"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
if "." in name:
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
# 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type
activation_type = getattr(cls, "activation_type", focus_activation_type)
return ActionInfo(
name=name,
component_type=ComponentType.ACTION,
description=getattr(cls, "action_description", "Action动作"),
activation_type=activation_type,
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),
action_require=getattr(cls, "action_require", []).copy(),
associated_types=getattr(cls, "associated_types", []).copy(),
)
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current