From c6f0c518250c7c81a674334e9203c8e7d35a7437 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 21 Aug 2025 23:21:56 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8A=8A=E5=AD=97=E5=85=B8=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E4=B8=BA=E6=95=B0=E6=8D=AE=E6=A8=A1=E5=9E=8B=E5=B9=B6=E6=81=A2?= =?UTF-8?q?=E5=A4=8D=E5=85=A8=E7=B3=BB=E7=BB=9F=E5=8F=AF=E7=94=A8=E6=80=A7?= =?UTF-8?q?=EF=BC=8C=E4=B8=B4=E6=97=B6=E4=BF=AE=E5=A4=8DInstantMemory?= =?UTF-8?q?=E8=AE=A9=E5=A4=A7=E6=A8=A1=E5=9E=8B=E8=87=B3=E5=B0=91=E7=9F=A5?= =?UTF-8?q?=E9=81=93=E5=9C=A8=E8=81=8A=E4=BB=80=E4=B9=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chat_loop/heartFC_chat.py | 254 ++++++++++-------- src/chat/message_receive/message.py | 10 +- src/chat/planner_actions/action_manager.py | 7 +- src/chat/planner_actions/planner.py | 82 +++--- src/chat/replyer/default_generator.py | 61 +++-- src/chat/utils/utils.py | 3 +- src/common/data_models/database_data_model.py | 41 ++- src/common/data_models/info_data_model.py | 19 +- src/plugin_system/apis/generator_api.py | 54 ++-- src/plugin_system/apis/send_api.py | 83 ++++-- src/plugin_system/base/base_action.py | 45 +++- src/plugin_system/base/base_command.py | 61 ++++- 12 files changed, 462 insertions(+), 258 deletions(-) diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index b781dc16..35c67663 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -3,13 +3,13 @@ import time import traceback import math import random -from typing import List, Optional, Dict, Any, Tuple +from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING from rich.traceback import install from collections import deque from src.config.config import global_config from src.common.logger import get_logger -from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import ActionPlannerInfo from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer @@ -24,12 +24,15 @@ from src.chat.frequency_control.focus_value_control import focus_value_control from src.chat.express.expression_learner import expression_learner_manager from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.person_info import Person -from src.plugin_system.base.component_types import ChatMode, EventType +from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo from src.plugin_system.core import events_manager from src.plugin_system.apis import generator_api, send_api, message_api, database_api from src.mais4u.mai_think import mai_thinking_manager from src.mais4u.s4u_config import s4u_config +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + ERROR_LOOP_INFO = { "loop_plan_info": { @@ -141,7 +144,7 @@ class HeartFChatting: except asyncio.CancelledError: logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") - def start_cycle(self): + def start_cycle(self) -> Tuple[Dict[str, float], str]: self._cycle_counter += 1 self._current_cycle_detail = CycleDetail(self._cycle_counter) self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" @@ -172,7 +175,8 @@ class HeartFChatting: action_type = action_result.get("action_type", "未知动作") elif isinstance(action_result, list) and action_result: # 新格式:action_result是actions列表 - action_type = action_result[0].get("action_type", "未知动作") + # TODO: 把这里写明白 + action_type = action_result[0].action_type or "未知动作" elif isinstance(loop_plan_info, list) and loop_plan_info: # 直接是actions列表的情况 action_type = loop_plan_info[0].get("action_type", "未知动作") @@ -207,7 +211,7 @@ class HeartFChatting: logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息") self.focus_energy = 1 - async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]: + async def _should_process_messages(self, new_message: List["DatabaseMessages"]) -> tuple[bool, float]: """ 判断是否应该处理消息 @@ -290,11 +294,11 @@ class HeartFChatting: async def _send_and_store_reply( self, response_set, - action_message, + action_message: "DatabaseMessages", cycle_timers: Dict[str, float], thinking_id, actions, - selected_expressions: List[int] = None, + selected_expressions: Optional[List[int]] = None, ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: with Timer("回复发送", cycle_timers): reply_text = await self._send_response( @@ -304,11 +308,11 @@ class HeartFChatting: ) # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.get("chat_info_platform") + platform = action_message.chat_info.platform if platform is None: platform = getattr(self.chat_stream, "platform", "unknown") - person = Person(platform=platform, user_id=action_message.get("user_id", "")) + person = Person(platform=platform, user_id=action_message.user_info.user_id) person_name = person.person_name action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" @@ -353,9 +357,13 @@ class HeartFChatting: k = 2.0 # 控制曲线陡峭程度 x0 = 1.0 # 控制曲线中心点 return 1.0 / (1.0 + math.exp(-k * (interest_val - x0))) - - normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency() - + + normal_mode_probability = ( + calculate_normal_mode_probability(interest_value) + * 2 + * self.talk_frequency_control.get_current_talk_frequency() + ) + # 根据概率决定使用哪种模式 if random.random() < normal_mode_probability: mode = ChatMode.NORMAL @@ -383,17 +391,17 @@ class HeartFChatting: except Exception as e: logger.error(f"{self.log_prefix} 记忆构建失败: {e}") + available_actions: Dict[str, ActionInfo] = {} if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS: # 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考 - actions = [ - { - "action_type": "no_action", - "reasoning": "专注不足", - "action_data": {}, - } + action_to_use_info = [ + ActionPlannerInfo( + action_type="no_action", + reasoning="专注不足", + action_data={}, + ) ] else: - available_actions = {} # 第一步:动作修改 with Timer("动作修改", cycle_timers): try: @@ -414,105 +422,19 @@ class HeartFChatting: ): return False with Timer("规划器", cycle_timers): - actions, _ = await self.action_planner.plan( + action_to_use_info, _ = await self.action_planner.plan( mode=mode, loop_start_time=self.last_read_time, available_actions=available_actions, ) # 3. 并行执行所有动作 - async def execute_action(action_info, actions): - """执行单个动作的通用函数""" - try: - if action_info["action_type"] == "no_action": - # 直接处理no_action逻辑,不再通过动作系统 - reason = action_info.get("reasoning", "选择不回复") - logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") - - # 存储no_action信息到数据库 - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={"reason": reason}, - action_name="no_action", - ) - - return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} - elif action_info["action_type"] != "reply": - # 执行普通动作 - with Timer("动作执行", cycle_timers): - success, reply_text, command = await self._handle_action( - action_info["action_type"], - action_info["reasoning"], - action_info["action_data"], - cycle_timers, - thinking_id, - action_info["action_message"], - ) - return { - "action_type": action_info["action_type"], - "success": success, - "reply_text": reply_text, - "command": command, - } - else: - try: - success, response_set, prompt_selected_expressions = await generator_api.generate_reply( - chat_stream=self.chat_stream, - reply_message=action_info["action_message"], - available_actions=available_actions, - choosen_actions=actions, - reply_reason=action_info.get("reasoning", ""), - enable_tool=global_config.tool.enable_tool, - request_type="replyer", - from_plugin=False, - return_expressions=True, - ) - - if prompt_selected_expressions and len(prompt_selected_expressions) > 1: - _, selected_expressions = prompt_selected_expressions - else: - selected_expressions = [] - - if not success or not response_set: - logger.info( - f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败" - ) - return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} - - except asyncio.CancelledError: - logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消") - return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} - - loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( - response_set=response_set, - action_message=action_info["action_message"], - cycle_timers=cycle_timers, - thinking_id=thinking_id, - actions=actions, - selected_expressions=selected_expressions, - ) - return { - "action_type": "reply", - "success": True, - "reply_text": reply_text, - "loop_info": loop_info, - } - except Exception as e: - logger.error(f"{self.log_prefix} 执行动作时出错: {e}") - logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") - return { - "action_type": action_info["action_type"], - "success": False, - "reply_text": "", - "loop_info": None, - "error": str(e), - } - - action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions] + action_tasks = [ + asyncio.create_task( + self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers) + ) + for action in action_to_use_info + ] # 并行执行所有任务 results = await asyncio.gather(*action_tasks, return_exceptions=True) @@ -529,7 +451,7 @@ class HeartFChatting: logger.error(f"{self.log_prefix} 动作执行异常: {result}") continue - _cur_action = actions[i] + _cur_action = action_to_use_info[i] if result["action_type"] != "reply": action_success = result["success"] action_reply_text = result["reply_text"] @@ -558,7 +480,7 @@ class HeartFChatting: # 没有回复信息,构建纯动作的loop_info loop_info = { "loop_plan_info": { - "action_result": actions, + "action_result": action_to_use_info, }, "loop_action_info": { "action_taken": action_success, @@ -578,7 +500,7 @@ class HeartFChatting: # await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", "")) - action_type = actions[0]["action_type"] if actions else "no_action" + action_type = action_to_use_info[0].action_type if action_to_use_info else "no_action" # 管理no_action计数器:当执行了非no_action动作时,重置计数器 if action_type != "no_action": @@ -620,7 +542,7 @@ class HeartFChatting: action_data: dict, cycle_timers: Dict[str, float], thinking_id: str, - action_message: dict, + action_message: Optional["DatabaseMessages"] = None, ) -> tuple[bool, str, str]: """ 处理规划动作,使用动作工厂创建相应的动作处理器 @@ -672,8 +594,8 @@ class HeartFChatting: async def _send_response( self, reply_set, - message_data, - selected_expressions: List[int] = None, + message_data: "DatabaseMessages", + selected_expressions: Optional[List[int]] = None, ) -> str: new_message_count = message_api.count_new_messages( chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time() @@ -710,3 +632,97 @@ class HeartFChatting: reply_text += data return reply_text + + async def _execute_action( + self, + action_planner_info: ActionPlannerInfo, + chosen_action_plan_infos: List[ActionPlannerInfo], + thinking_id: str, + available_actions: Dict[str, ActionInfo], + cycle_timers: Dict[str, float], + ): + """执行单个动作的通用函数""" + try: + if action_planner_info.action_type == "no_action": + # 直接处理no_action逻辑,不再通过动作系统 + reason = action_planner_info.reasoning or "选择不回复" + logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") + + # 存储no_action信息到数据库 + await database_api.store_action_info( + chat_stream=self.chat_stream, + action_build_into_prompt=False, + action_prompt_display=reason, + action_done=True, + thinking_id=thinking_id, + action_data={"reason": reason}, + action_name="no_action", + ) + + return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} + elif action_planner_info.action_type != "reply": + # 执行普通动作 + with Timer("动作执行", cycle_timers): + success, reply_text, command = await self._handle_action( + action_planner_info.action_type, + action_planner_info.reasoning or "", + action_planner_info.action_data or {}, + cycle_timers, + thinking_id, + action_planner_info.action_message, + ) + return { + "action_type": action_planner_info.action_type, + "success": success, + "reply_text": reply_text, + "command": command, + } + else: + try: + success, response_set, prompt, selected_expressions = await generator_api.generate_reply( + chat_stream=self.chat_stream, + reply_message=action_planner_info.action_message, + available_actions=available_actions, + chosen_actions=chosen_action_plan_infos, + reply_reason=action_planner_info.reasoning or "", + enable_tool=global_config.tool.enable_tool, + request_type="replyer", + from_plugin=False, + return_expressions=True, + ) + + if not success or not response_set: + if action_planner_info.action_message: + logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败") + else: + logger.info("回复生成失败") + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} + + except asyncio.CancelledError: + logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消") + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} + + loop_info, reply_text, _ = await self._send_and_store_reply( + response_set=response_set, + action_message=action_planner_info.action_message, # type: ignore + cycle_timers=cycle_timers, + thinking_id=thinking_id, + actions=chosen_action_plan_infos, + selected_expressions=selected_expressions, + ) + return { + "action_type": "reply", + "success": True, + "reply_text": reply_text, + "loop_info": loop_info, + } + except Exception as e: + logger.error(f"{self.log_prefix} 执行动作时出错: {e}") + logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") + return { + "action_type": action_planner_info.action_type, + "success": False, + "reply_text": "", + "loop_info": None, + "error": str(e), + } diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 098e6600..66a1c029 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -115,7 +115,7 @@ class MessageRecv(Message): self.priority_mode = "interest" self.priority_info = None self.interest_value: float = None # type: ignore - + self.key_words = [] self.key_words_lite = [] @@ -213,9 +213,9 @@ class MessageRecvS4U(MessageRecv): self.is_screen = False self.is_internal = False self.voice_done = None - + self.chat_info = None - + async def process(self) -> None: self.processed_plain_text = await self._process_message_segments(self.message_segment) @@ -420,7 +420,7 @@ class MessageSending(MessageProcessBase): thinking_start_time: float = 0, apply_set_reply_logic: bool = False, reply_to: Optional[str] = None, - selected_expressions:List[int] = None, + selected_expressions: Optional[List[int]] = None, ): # 调用父类初始化 super().__init__( @@ -445,7 +445,7 @@ class MessageSending(MessageProcessBase): self.display_message = display_message self.interest_value = 0.0 - + self.selected_expressions = selected_expressions def build_reply(self): diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 267b7a8f..b4587474 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -2,6 +2,7 @@ from typing import Dict, Optional, Type from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages from src.plugin_system.core.component_registry import component_registry from src.plugin_system.base.component_types import ComponentType, ActionInfo from src.plugin_system.base.base_action import BaseAction @@ -37,7 +38,7 @@ class ActionManager: chat_stream: ChatStream, log_prefix: str, shutting_down: bool = False, - action_message: Optional[dict] = None, + action_message: Optional[DatabaseMessages] = None, ) -> Optional[BaseAction]: """ 创建动作处理器实例 @@ -83,7 +84,7 @@ class ActionManager: log_prefix=log_prefix, shutting_down=shutting_down, plugin_config=plugin_config, - action_message=action_message, + action_message=action_message.flatten() if action_message else None, ) logger.debug(f"创建Action实例成功: {action_name}") @@ -123,4 +124,4 @@ class ActionManager: """恢复到默认动作集""" actions_to_restore = list(self._using_actions.keys()) self._using_actions = component_registry.get_default_actions() - logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") + logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") \ No newline at end of file diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 5e0695c3..2cb2a469 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -1,7 +1,7 @@ import json import time import traceback -from typing import Dict, Any, Optional, Tuple, List +from typing import Dict, Optional, Tuple, List from rich.traceback import install from datetime import datetime from json_repair import repair_json @@ -9,6 +9,8 @@ from json_repair import repair_json from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import ActionPlannerInfo from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( build_readable_actions, @@ -97,7 +99,9 @@ class ActionPlanner: self.plan_retry_count = 0 self.max_plan_retries = 3 - def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: + def find_message_by_id( + self, message_id: str, message_id_list: List[DatabaseMessages] + ) -> Optional[DatabaseMessages]: # sourcery skip: use-next """ 根据message_id从message_id_list中查找对应的原始消息 @@ -110,37 +114,37 @@ class ActionPlanner: 找到的原始消息字典,如果未找到则返回None """ for item in message_id_list: - if item.get("id") == message_id: - return item.get("message") + if item.message_id == message_id: + return item return None - def get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]: + def get_latest_message(self, message_id_list: List[DatabaseMessages]) -> Optional[DatabaseMessages]: """ 获取消息列表中的最新消息 - + Args: message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...] - + Returns: 最新的消息字典,如果列表为空则返回None """ - return message_id_list[-1].get("message") if message_id_list else None + return message_id_list[-1] if message_id_list else None async def plan( self, mode: ChatMode = ChatMode.FOCUS, - loop_start_time:float = 0.0, + loop_start_time: float = 0.0, available_actions: Optional[Dict[str, ActionInfo]] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + ) -> Tuple[List[ActionPlannerInfo], Optional[DatabaseMessages]]: """ 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 """ - action = "no_action" # 默认动作 - reasoning = "规划器初始化默认" + action: str = "no_action" # 默认动作 + reasoning: str = "规划器初始化默认" action_data = {} current_available_actions: Dict[str, ActionInfo] = {} - target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量 + target_message: Optional[DatabaseMessages] = None # 初始化target_message变量 prompt: str = "" message_id_list: list = [] @@ -208,19 +212,21 @@ class ActionPlanner: # 如果获取的target_message为None,输出warning并重新plan if target_message is None: self.plan_retry_count += 1 - logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}") + logger.warning( + f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}" + ) # 仍有重试次数 if self.plan_retry_count < self.max_plan_retries: # 递归重新plan return await self.plan(mode, loop_start_time, available_actions) - logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message") + logger.error( + f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message" + ) target_message = self.get_latest_message(message_id_list) self.plan_retry_count = 0 # 重置计数器 else: logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id") - - if action != "no_action" and action != "reply" and action not in current_available_actions: logger.warning( f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'" @@ -244,38 +250,37 @@ class ActionPlanner: if mode == ChatMode.NORMAL and action in current_available_actions: is_parallel = current_available_actions[action].parallel_action - action_data["loop_start_time"] = loop_start_time actions = [ - { - "action_type": action, - "reasoning": reasoning, - "action_data": action_data, - "action_message": target_message, - "available_actions": available_actions, - } + ActionPlannerInfo( + action_type=action, + reasoning=reasoning, + action_data=action_data, + action_message=target_message, + available_actions=available_actions, + ) ] if action != "reply" and is_parallel: - actions.append({ - "action_type": "reply", - "action_message": target_message, - "available_actions": available_actions - }) + actions.append( + ActionPlannerInfo( + action_type="reply", + action_message=target_message, + available_actions=available_actions, + ) + ) - return actions,target_message - - + return actions, target_message async def build_planner_prompt( self, is_group_chat: bool, # Now passed as argument chat_target_info: Optional[dict], # Now passed as argument current_available_actions: Dict[str, ActionInfo], - refresh_time :bool = False, + refresh_time: bool = False, mode: ChatMode = ChatMode.FOCUS, - ) -> tuple[str, list]: # sourcery skip: use-join + ) -> tuple[str, List[DatabaseMessages]]: # sourcery skip: use-join """构建 Planner LLM 的提示词 (获取模板并填充数据)""" try: message_list_before_now = get_raw_msg_before_timestamp_with_chat( @@ -305,13 +310,12 @@ class ActionPlanner: actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" if refresh_time: self.last_obs_time_mark = time.time() - + mentioned_bonus = "" if global_config.chat.mentioned_bot_inevitable_reply: mentioned_bonus = "\n- 有人提到你" if global_config.chat.at_bot_inevitable_reply: mentioned_bonus = "\n- 有人提到你,或者at你" - if mode == ChatMode.FOCUS: no_action_block = """ @@ -332,7 +336,7 @@ class ActionPlanner: """ chat_context_description = "你现在正在一个群聊中" - chat_target_name = None + chat_target_name = None if not is_group_chat and chat_target_info: chat_target_name = ( chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方" @@ -388,7 +392,7 @@ class ActionPlanner: action_options_text=action_options_block, moderation_prompt=moderation_prompt_block, identity_block=identity_block, - plan_style = global_config.personality.plan_style + plan_style=global_config.personality.plan_style, ) return prompt, message_id_list except Exception as e: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 87ff7bdb..0dca9f60 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -9,6 +9,7 @@ from datetime import datetime from src.mais4u.mai_think import mai_thinking_manager from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import ActionPlannerInfo from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality from src.llm_models.utils_model import LLMRequest @@ -157,12 +158,12 @@ class DefaultReplyer: extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - chosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List[ActionPlannerInfo]] = None, enable_tool: bool = True, from_plugin: bool = True, stream_id: Optional[str] = None, - reply_message: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]: + reply_message: Optional[DatabaseMessages] = None, + ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], Optional[List[int]]]: # sourcery skip: merge-nested-ifs """ 回复器 (Replier): 负责生成回复文本的核心逻辑。 @@ -181,7 +182,7 @@ class DefaultReplyer: """ prompt = None - selected_expressions = None + selected_expressions: Optional[List[int]] = None if available_actions is None: available_actions = {} try: @@ -374,7 +375,12 @@ class DefaultReplyer: ) if global_config.memory.enable_instant_memory: - asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history)) + chat_history_str = build_readable_messages( + messages=chat_history, + replace_bot_name=True, + timestamp_mode="normal" + ) + asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history_str)) instant_memory = await self.instant_memory.get_memory(target) logger.info(f"即时记忆:{instant_memory}") @@ -527,7 +533,7 @@ class DefaultReplyer: Returns: Tuple[str, str]: (核心对话prompt, 背景对话prompt) """ - core_dialogue_list = [] + core_dialogue_list: List[DatabaseMessages] = [] bot_id = str(global_config.bot.qq_account) # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话 @@ -559,7 +565,7 @@ class DefaultReplyer: if core_dialogue_list: # 检查最新五条消息中是否包含bot自己说的消息 latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list - has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) + has_bot_message = any(str(msg.user_info.user_id) == bot_id for msg in latest_5_messages) # logger.info(f"最新五条消息:{latest_5_messages}") # logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}") @@ -634,7 +640,7 @@ class DefaultReplyer: return mai_think async def build_actions_prompt( - self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None + self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None ) -> str: """构建动作提示""" @@ -646,20 +652,21 @@ class DefaultReplyer: action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += "\n" - choosen_action_descriptions = "" - if choosen_actions: - for action in choosen_actions: - action_name = action.get("action_type", "unknown_action") + chosen_action_descriptions = "" + if chosen_actions_info: + for action_plan_info in chosen_actions_info: + action_name = action_plan_info.action_type if action_name == "reply": continue - action_description = action.get("reason", "无描述") - reasoning = action.get("reasoning", "无原因") + if action := available_actions.get(action_name): + action_description = action.description or "无描述" + reasoning = action_plan_info.reasoning or "无原因" - choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n" + chosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n" - if choosen_action_descriptions: + if chosen_action_descriptions: action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n" - action_descriptions += choosen_action_descriptions + action_descriptions += chosen_action_descriptions return action_descriptions @@ -668,9 +675,9 @@ class DefaultReplyer: extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - chosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List[ActionPlannerInfo]] = None, enable_tool: bool = True, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional[DatabaseMessages] = None, ) -> Tuple[str, List[int]]: """ 构建回复器上下文 @@ -694,11 +701,11 @@ class DefaultReplyer: platform = chat_stream.platform if reply_message: - user_id = reply_message.get("user_id", "") + user_id = reply_message.user_info.user_id person = Person(platform=platform, user_id=user_id) person_name = person.person_name or user_id sender = person_name - target = reply_message.get("processed_plain_text") + target = reply_message.processed_plain_text else: person_name = "用户" sender = "用户" @@ -774,11 +781,13 @@ class DefaultReplyer: logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s") expression_habits_block, selected_expressions = results_dict["expression_habits"] - relation_info = results_dict["relation_info"] - memory_block = results_dict["memory_block"] - tool_info = results_dict["tool_info"] - prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果 - actions_info = results_dict["actions_info"] + expression_habits_block: str + selected_expressions: List[int] + relation_info: str = results_dict["relation_info"] + memory_block: str = results_dict["memory_block"] + tool_info: str = results_dict["tool_info"] + prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果 + actions_info: str = results_dict["actions_info"] keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target) if extra_info: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 3528fe4b..472a9cdd 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -11,7 +11,6 @@ from collections import Counter from typing import Optional, Tuple, Dict, List, Any from src.common.logger import get_logger -from src.common.data_models.info_data_model import TargetPersonInfo from src.common.data_models.database_data_model import DatabaseMessages from src.common.message_repository import find_messages, count_messages from src.config.config import global_config, model_config @@ -641,6 +640,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: platform: str = chat_stream.platform user_id: str = user_info.user_id # type: ignore + from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题 + # Initialize target_info with basic info target_info = TargetPersonInfo( platform=platform, diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 59761d09..1f671890 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Optional, Any, Dict from dataclasses import dataclass, field from . import BaseDataModel @@ -157,3 +157,42 @@ class DatabaseMessages(BaseDataModel): # assert isinstance(self.interest_value, float) or self.interest_value is None, ( # "interest_value must be a float or None" # ) + def flatten(self) -> Dict[str, Any]: + """ + 将消息数据模型转换为字典格式,便于存储或传输 + """ + return { + "message_id": self.message_id, + "time": self.time, + "chat_id": self.chat_id, + "reply_to": self.reply_to, + "interest_value": self.interest_value, + "key_words": self.key_words, + "key_words_lite": self.key_words_lite, + "is_mentioned": self.is_mentioned, + "processed_plain_text": self.processed_plain_text, + "display_message": self.display_message, + "priority_mode": self.priority_mode, + "priority_info": self.priority_info, + "additional_config": self.additional_config, + "is_emoji": self.is_emoji, + "is_picid": self.is_picid, + "is_command": self.is_command, + "is_notify": self.is_notify, + "selected_expressions": self.selected_expressions, + "user_id": self.user_info.user_id, + "user_nickname": self.user_info.user_nickname, + "user_cardname": self.user_info.user_cardname, + "user_platform": self.user_info.platform, + "chat_info_group_id": self.group_info.group_id if self.group_info else None, + "chat_info_group_name": self.group_info.group_name if self.group_info else None, + "chat_info_group_platform": self.group_info.group_platform if self.group_info else None, + "chat_info_stream_id": self.chat_info.stream_id, + "chat_info_platform": self.chat_info.platform, + "chat_info_create_time": self.chat_info.create_time, + "chat_info_last_active_time": self.chat_info.last_active_time, + "chat_info_user_platform": self.chat_info.user_info.platform, + "chat_info_user_id": self.chat_info.user_info.user_id, + "chat_info_user_nickname": self.chat_info.user_info.user_nickname, + "chat_info_user_cardname": self.chat_info.user_info.user_cardname, + } diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index ae3678d1..0f7b1f95 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,12 +1,25 @@ from dataclasses import dataclass, field -from typing import Optional - +from typing import Optional, Dict, TYPE_CHECKING from . import BaseDataModel +if TYPE_CHECKING: + from .database_data_model import DatabaseMessages + from src.plugin_system.base.component_types import ActionInfo + + @dataclass class TargetPersonInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) person_id: Optional[str] = None - person_name: Optional[str] = None \ No newline at end of file + person_name: Optional[str] = None + + +@dataclass +class ActionPlannerInfo(BaseDataModel): + action_type: str = field(default_factory=str) + reasoning: Optional[str] = None + action_data: Optional[Dict] = None + action_message: Optional["DatabaseMessages"] = None + available_actions: Optional[Dict[str, "ActionInfo"]] = None diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 3ffbc715..b0ef9995 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -9,7 +9,7 @@ """ import traceback -from typing import Tuple, Any, Dict, List, Optional +from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING from rich.traceback import install from src.common.logger import get_logger from src.chat.replyer.default_generator import DefaultReplyer @@ -18,6 +18,10 @@ from src.chat.utils.utils import process_llm_response from src.chat.replyer.replyer_manager import replyer_manager from src.plugin_system.base.component_types import ActionInfo +if TYPE_CHECKING: + from src.common.data_models.info_data_model import ActionPlannerInfo + from src.common.data_models.database_data_model import DatabaseMessages + install(extra_lines=3) logger = get_logger("generator_api") @@ -73,11 +77,11 @@ async def generate_reply( chat_stream: Optional[ChatStream] = None, chat_id: Optional[str] = None, action_data: Optional[Dict[str, Any]] = None, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - choosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List["ActionPlannerInfo"]] = None, enable_tool: bool = False, enable_splitter: bool = True, enable_chinese_typo: bool = True, @@ -85,7 +89,7 @@ async def generate_reply( request_type: str = "generator_api", from_plugin: bool = True, return_expressions: bool = False, -) -> Tuple[bool, List[Tuple[str, Any]], Optional[Tuple[str, List[Dict[str, Any]]]]]: +) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]: """生成回复 Args: @@ -96,7 +100,7 @@ async def generate_reply( extra_info: 额外信息,用于补充上下文 reply_reason: 回复原因 available_actions: 可用动作 - choosen_actions: 已选动作 + chosen_actions: 已选动作 enable_tool: 是否启用工具调用 enable_splitter: 是否启用消息分割器 enable_chinese_typo: 是否启用错字生成器 @@ -110,16 +114,14 @@ async def generate_reply( try: # 获取回复器 logger.debug("[GeneratorAPI] 开始生成回复") - replyer = get_replyer( - chat_stream, chat_id, request_type=request_type - ) + replyer = get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") - return False, [], None + return False, [], None, None if not extra_info and action_data: extra_info = action_data.get("extra_info", "") - + if not reply_reason and action_data: reply_reason = action_data.get("reason", "") @@ -127,7 +129,7 @@ async def generate_reply( success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context( extra_info=extra_info, available_actions=available_actions, - chosen_actions=choosen_actions, + chosen_actions=chosen_actions, enable_tool=enable_tool, reply_message=reply_message, reply_reason=reply_reason, @@ -136,7 +138,7 @@ async def generate_reply( ) if not success: logger.warning("[GeneratorAPI] 回复生成失败") - return False, [], None + return False, [], None, None assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况 if content := llm_response_dict.get("content", ""): reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) @@ -144,17 +146,23 @@ async def generate_reply( reply_set = [] logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项") - if return_prompt: - if return_expressions: - return success, reply_set, (prompt, selected_expressions) - else: - return success, reply_set, prompt - else: - if return_expressions: - return success, reply_set, (None, selected_expressions) - else: - return success, reply_set, None - + # if return_prompt: + # if return_expressions: + # return success, reply_set, prompt, selected_expressions + # else: + # return success, reply_set, prompt, None + # else: + # if return_expressions: + # return success, reply_set, (None, selected_expressions) + # else: + # return success, reply_set, None + return ( + success, + reply_set, + prompt if return_prompt else None, + selected_expressions if return_expressions else None, + ) + except ValueError as ve: raise ve diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 700042de..4bdab41e 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -21,15 +21,17 @@ import traceback import time -from typing import Optional, Union, Dict, Any, List -from src.common.logger import get_logger +from typing import Optional, Union, Dict, Any, List, TYPE_CHECKING -# 导入依赖 +from src.common.logger import get_logger +from src.config.config import global_config from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.message_receive.message import MessageSending, MessageRecv from maim_message import Seg, UserInfo -from src.config.config import global_config + +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger("send_api") @@ -46,10 +48,10 @@ async def _send_to_target( display_message: str = "", typing: bool = False, set_reply: bool = False, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, storage_message: bool = True, show_log: bool = True, - selected_expressions:List[int] = None, + selected_expressions: Optional[List[int]] = None, ) -> bool: """向指定目标发送消息的内部实现 @@ -70,7 +72,7 @@ async def _send_to_target( if set_reply and not reply_message: logger.warning("[SendAPI] 使用引用回复,但未提供回复消息") return False - + if show_log: logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}") @@ -98,13 +100,13 @@ async def _send_to_target( message_segment = Seg(type=message_type, data=content) # type: ignore if reply_message: - anchor_message = message_dict_to_message_recv(reply_message) + anchor_message = message_dict_to_message_recv(reply_message.flatten()) if anchor_message: anchor_message.update_chat_stream(target_stream) assert anchor_message.message_info.user_info, "用户信息缺失" reply_to_platform_id = ( f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" - ) + ) else: reply_to_platform_id = "" anchor_message = None @@ -192,12 +194,11 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa } message_recv = MessageRecv(message_dict_recv) - + logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}") return message_recv - # ============================================================================= # 公共API函数 - 预定义类型的发送函数 # ============================================================================= @@ -208,9 +209,9 @@ async def text_to_stream( stream_id: str, typing: bool = False, set_reply: bool = False, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, storage_message: bool = True, - selected_expressions:List[int] = None, + selected_expressions: Optional[List[int]] = None, ) -> bool: """向指定流发送文本消息 @@ -237,7 +238,13 @@ async def text_to_stream( ) -async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: +async def emoji_to_stream( + emoji_base64: str, + stream_id: str, + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, +) -> bool: """向指定流发送表情包 Args: @@ -248,10 +255,25 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo Returns: bool: 是否发送成功 """ - return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message) + return await _send_to_target( + "emoji", + emoji_base64, + stream_id, + "", + typing=False, + storage_message=storage_message, + set_reply=set_reply, + reply_message=reply_message, + ) -async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: +async def image_to_stream( + image_base64: str, + stream_id: str, + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, +) -> bool: """向指定流发送图片 Args: @@ -262,11 +284,25 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo Returns: bool: 是否发送成功 """ - return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message) + return await _send_to_target( + "image", + image_base64, + stream_id, + "", + typing=False, + storage_message=storage_message, + set_reply=set_reply, + reply_message=reply_message, + ) async def command_to_stream( - command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + command: Union[str, dict], + stream_id: str, + storage_message: bool = True, + display_message: str = "", + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """向指定流发送命令 @@ -279,7 +315,14 @@ async def command_to_stream( bool: 是否发送成功 """ return await _send_to_target( - "command", command, stream_id, display_message, typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message + "command", + command, + stream_id, + display_message, + typing=False, + storage_message=storage_message, + set_reply=set_reply, + reply_message=reply_message, ) @@ -289,7 +332,7 @@ async def custom_to_stream( stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 174b6fea..03bbc0d6 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -2,13 +2,15 @@ import time import asyncio from abc import ABC, abstractmethod -from typing import Tuple, Optional, Dict, Any +from typing import Tuple, Optional, TYPE_CHECKING from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream -from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType +from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType from src.plugin_system.apis import send_api, database_api, message_api +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger("base_action") @@ -206,7 +208,11 @@ class BaseAction(ABC): return False, f"等待新消息失败: {str(e)}" async def send_text( - self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False + self, + content: str, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + typing: bool = False, ) -> bool: """发送文本消息 @@ -229,7 +235,9 @@ class BaseAction(ABC): typing=typing, ) - async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_emoji( + self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None + ) -> bool: """发送表情包 Args: @@ -242,9 +250,13 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 缺少聊天ID") return False - return await send_api.emoji_to_stream(emoji_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message) + return await send_api.emoji_to_stream( + emoji_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message + ) - async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_image( + self, image_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None + ) -> bool: """发送图片 Args: @@ -257,9 +269,18 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 缺少聊天ID") return False - return await send_api.image_to_stream(image_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message) + return await send_api.image_to_stream( + image_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message + ) - async def send_custom(self, message_type: str, content: str, typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_custom( + self, + message_type: str, + content: str, + typing: bool = False, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + ) -> bool: """发送自定义类型消息 Args: @@ -308,7 +329,13 @@ class BaseAction(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + self, + command_name: str, + args: Optional[dict] = None, + display_message: str = "", + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """发送命令消息 diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 35fed909..633eba34 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod -from typing import Dict, Tuple, Optional, Any +from typing import Dict, Tuple, Optional, TYPE_CHECKING from src.common.logger import get_logger from src.plugin_system.base.component_types import CommandInfo, ComponentType from src.chat.message_receive.message import MessageRecv from src.plugin_system.apis import send_api +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + logger = get_logger("base_command") @@ -84,7 +87,13 @@ class BaseCommand(ABC): return current - async def send_text(self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool: + async def send_text( + self, + content: str, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + storage_message: bool = True, + ) -> bool: """发送回复消息 Args: @@ -100,10 +109,22 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, set_reply=set_reply,reply_message=reply_message,storage_message=storage_message) + return await send_api.text_to_stream( + text=content, + stream_id=chat_stream.stream_id, + set_reply=set_reply, + reply_message=reply_message, + storage_message=storage_message, + ) async def send_type( - self, message_type: str, content: str, display_message: str = "", typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + self, + message_type: str, + content: str, + display_message: str = "", + typing: bool = False, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """发送指定类型的回复消息到当前聊天环境 @@ -134,7 +155,13 @@ class BaseCommand(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + self, + command_name: str, + args: Optional[dict] = None, + display_message: str = "", + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """发送命令消息 @@ -177,7 +204,9 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 发送命令时出错: {e}") return False - async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_emoji( + self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None + ) -> bool: """发送表情包 Args: @@ -191,9 +220,17 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message) + return await send_api.emoji_to_stream( + emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message + ) - async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool: + async def send_image( + self, + image_base64: str, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + storage_message: bool = True, + ) -> bool: """发送图片 Args: @@ -207,7 +244,13 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.image_to_stream(image_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message,storage_message=storage_message) + return await send_api.image_to_stream( + image_base64, + chat_stream.stream_id, + set_reply=set_reply, + reply_message=reply_message, + storage_message=storage_message, + ) @classmethod def get_command_info(cls) -> "CommandInfo":