MaiBot/src/chat/focus_chat/planners/actions/plugin_action.py

275 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import traceback
from typing import Tuple, Dict, List, Any, Optional
from src.chat.focus_chat.planners.actions.base_action import BaseAction
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger
from src.person_info.person_info import person_info_manager
from abc import abstractmethod
import os
import inspect
import toml # 导入 toml 库
logger = get_logger("plugin_action")
class PluginAction(BaseAction):
"""插件动作基类
封装了主程序内部依赖提供简化的API接口给插件开发者
"""
action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
global_config: Optional[dict] = None,
**kwargs,
):
"""初始化插件动作基类"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
# 存储内部服务和对象引用
self._services = {}
self._global_config = global_config # 存储全局配置的只读引用
self.config: Dict[str, Any] = {} # 用于存储插件自身的配置
# 从kwargs提取必要的内部服务
if "observations" in kwargs:
self._services["observations"] = kwargs["observations"]
if "expressor" in kwargs:
self._services["expressor"] = kwargs["expressor"]
if "chat_stream" in kwargs:
self._services["chat_stream"] = kwargs["chat_stream"]
self.log_prefix = kwargs.get("log_prefix", "")
self._load_plugin_config() # 初始化时加载插件配置
def _load_plugin_config(self):
"""
加载插件自身的配置文件。
配置文件应与插件模块在同一目录下。
插件可以通过覆盖 `action_config_file_name` 类属性来指定文件名。
如果 `action_config_file_name` 未指定,则不加载配置。
仅支持 TOML (.toml) 格式。
"""
if not self.action_config_file_name:
logger.debug(
f"{self.log_prefix} 插件 {self.__class__.__name__} 未指定 action_config_file_name不加载插件配置。"
)
return
try:
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
config_file_path = os.path.join(plugin_dir, self.action_config_file_name)
if not os.path.exists(config_file_path):
logger.warning(
f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置文件 {config_file_path} 不存在。"
)
return
file_ext = os.path.splitext(self.action_config_file_name)[1].lower()
if file_ext == ".toml":
with open(config_file_path, "r", encoding="utf-8") as f:
self.config = toml.load(f) or {}
logger.info(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置已从 {config_file_path} 加载。")
else:
logger.warning(
f"{self.log_prefix} 不支持的插件配置文件格式: {file_ext}。仅支持 .toml。插件配置未加载。"
)
self.config = {} # 确保未加载时为空字典
return
except Exception as e:
logger.error(
f"{self.log_prefix} 加载插件 {self.__class__.__name__} 的配置文件 {self.action_config_file_name} 时出错: {e}"
)
self.config = {} # 出错时确保 config 是一个空字典
def get_global_config(self, key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值。
插件应使用此方法读取全局配置,以保证只读和隔离性。
"""
if self._global_config:
return self._global_config.get(key, default)
logger.debug(f"{self.log_prefix} 尝试访问全局配置项 '{key}',但全局配置未提供。")
return default
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
"""根据用户名获取用户ID"""
person_id = person_info_manager.get_person_id_by_person_name(person_name)
user_id = await person_info_manager.get_value(person_id, "user_id")
platform = await person_info_manager.get_value(person_id, "platform")
return platform, user_id
# 提供简化的API方法
async def send_message(self, type: str, data: str, target: Optional[str] = "") -> bool:
"""发送消息的简化方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
try:
expressor = self._services.get("expressor")
chat_stream = self._services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
# reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = self._services.get("observations", [])
chatting_observation: ChattingObservation = next(
obs for obs in observations if isinstance(obs, ChattingObservation)
)
anchor_message = chatting_observation.search_message_by_text(target)
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
response_set = [
(type, data),
]
# 调用内部方法发送消息
success = await expressor.send_response_messages(
anchor_message=anchor_message,
response_set=response_set,
)
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
traceback.print_exc()
return False
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
"""发送消息的简化方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
try:
expressor = self._services.get("expressor")
chat_stream = self._services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = self._services.get("observations", [])
chatting_observation: ChattingObservation = next(
obs for obs in observations if isinstance(obs, ChattingObservation)
)
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
# 调用内部方法发送消息
success, _ = await expressor.deal_reply(
cycle_timers=self.cycle_timers,
action_data=reply_data,
anchor_message=anchor_message,
reasoning=self.reasoning,
thinking_id=self.thinking_id,
)
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
return False
def get_chat_type(self) -> str:
"""获取当前聊天类型
Returns:
str: 聊天类型 ("group""private")
"""
chat_stream = self._services.get("chat_stream")
if chat_stream and hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
"""获取最近的消息
Args:
count: 要获取的消息数量
Returns:
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
"""
messages = []
observations = self._services.get("observations", [])
if observations and len(observations) > 0:
obs = observations[0]
if hasattr(obs, "get_talking_message"):
raw_messages = obs.get_talking_message()
# 转换为简化格式
for msg in raw_messages[-count:]:
simple_msg = {
"sender": msg.get("sender", "未知"),
"content": msg.get("content", ""),
"timestamp": msg.get("timestamp", 0),
}
messages.append(simple_msg)
return messages
@abstractmethod
async def process(self) -> Tuple[bool, str]:
"""插件处理逻辑,子类必须实现此方法
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def handle_action(self) -> Tuple[bool, str]:
"""实现BaseAction的抽象方法调用子类的process方法
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
return await self.process()