diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index b781dc16..7f8c861b 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -3,13 +3,13 @@ import time import traceback import math import random -from typing import List, Optional, Dict, Any, Tuple +from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING from rich.traceback import install from collections import deque from src.config.config import global_config from src.common.logger import get_logger -from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import ActionPlannerInfo from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer @@ -24,12 +24,15 @@ from src.chat.frequency_control.focus_value_control import focus_value_control from src.chat.express.expression_learner import expression_learner_manager from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.person_info import Person -from src.plugin_system.base.component_types import ChatMode, EventType +from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo from src.plugin_system.core import events_manager from src.plugin_system.apis import generator_api, send_api, message_api, database_api from src.mais4u.mai_think import mai_thinking_manager from src.mais4u.s4u_config import s4u_config +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + ERROR_LOOP_INFO = { "loop_plan_info": { @@ -100,7 +103,7 @@ class HeartFChatting: self.reply_timeout_count = 0 self.plan_timeout_count = 0 - self.last_read_time = time.time() - 1 + self.last_read_time = time.time() - 10 self.focus_energy = 1 self.no_action_consecutive = 0 @@ -141,7 +144,7 @@ class HeartFChatting: except asyncio.CancelledError: logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") - def start_cycle(self): + def start_cycle(self) -> Tuple[Dict[str, float], str]: self._cycle_counter += 1 self._current_cycle_detail = CycleDetail(self._cycle_counter) self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" @@ -172,7 +175,8 @@ class HeartFChatting: action_type = action_result.get("action_type", "未知动作") elif isinstance(action_result, list) and action_result: # 新格式:action_result是actions列表 - action_type = action_result[0].get("action_type", "未知动作") + # TODO: 把这里写明白 + action_type = action_result[0].action_type or "未知动作" elif isinstance(loop_plan_info, list) and loop_plan_info: # 直接是actions列表的情况 action_type = loop_plan_info[0].get("action_type", "未知动作") @@ -207,7 +211,7 @@ class HeartFChatting: logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息") self.focus_energy = 1 - async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]: + async def _should_process_messages(self, new_message: List["DatabaseMessages"]) -> tuple[bool, float]: """ 判断是否应该处理消息 @@ -265,7 +269,7 @@ class HeartFChatting: return False, 0.0 async def _loopbody(self): - recent_messages_dict = message_api.get_messages_by_time_in_chat( + recent_messages_list = message_api.get_messages_by_time_in_chat( chat_id=self.stream_id, start_time=self.last_read_time, end_time=time.time(), @@ -275,7 +279,7 @@ class HeartFChatting: filter_command=True, ) # 统一的消息处理逻辑 - should_process, interest_value = await self._should_process_messages(recent_messages_dict) + should_process, interest_value = await self._should_process_messages(recent_messages_list) if should_process: self.last_read_time = time.time() @@ -290,11 +294,11 @@ class HeartFChatting: async def _send_and_store_reply( self, response_set, - action_message, + action_message: "DatabaseMessages", cycle_timers: Dict[str, float], thinking_id, actions, - selected_expressions: List[int] = None, + selected_expressions: Optional[List[int]] = None, ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: with Timer("回复发送", cycle_timers): reply_text = await self._send_response( @@ -304,11 +308,11 @@ class HeartFChatting: ) # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.get("chat_info_platform") + platform = action_message.chat_info.platform if platform is None: platform = getattr(self.chat_stream, "platform", "unknown") - person = Person(platform=platform, user_id=action_message.get("user_id", "")) + person = Person(platform=platform, user_id=action_message.user_info.user_id) person_name = person.person_name action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" @@ -353,9 +357,13 @@ class HeartFChatting: k = 2.0 # 控制曲线陡峭程度 x0 = 1.0 # 控制曲线中心点 return 1.0 / (1.0 + math.exp(-k * (interest_val - x0))) - - normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency() - + + normal_mode_probability = ( + calculate_normal_mode_probability(interest_value) + * 2 + * self.talk_frequency_control.get_current_talk_frequency() + ) + # 根据概率决定使用哪种模式 if random.random() < normal_mode_probability: mode = ChatMode.NORMAL @@ -383,17 +391,17 @@ class HeartFChatting: except Exception as e: logger.error(f"{self.log_prefix} 记忆构建失败: {e}") + available_actions: Dict[str, ActionInfo] = {} if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS: # 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考 - actions = [ - { - "action_type": "no_action", - "reasoning": "专注不足", - "action_data": {}, - } + action_to_use_info = [ + ActionPlannerInfo( + action_type="no_action", + reasoning="专注不足", + action_data={}, + ) ] else: - available_actions = {} # 第一步:动作修改 with Timer("动作修改", cycle_timers): try: @@ -414,105 +422,19 @@ class HeartFChatting: ): return False with Timer("规划器", cycle_timers): - actions, _ = await self.action_planner.plan( + action_to_use_info, _ = await self.action_planner.plan( mode=mode, loop_start_time=self.last_read_time, available_actions=available_actions, ) # 3. 并行执行所有动作 - async def execute_action(action_info, actions): - """执行单个动作的通用函数""" - try: - if action_info["action_type"] == "no_action": - # 直接处理no_action逻辑,不再通过动作系统 - reason = action_info.get("reasoning", "选择不回复") - logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") - - # 存储no_action信息到数据库 - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={"reason": reason}, - action_name="no_action", - ) - - return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} - elif action_info["action_type"] != "reply": - # 执行普通动作 - with Timer("动作执行", cycle_timers): - success, reply_text, command = await self._handle_action( - action_info["action_type"], - action_info["reasoning"], - action_info["action_data"], - cycle_timers, - thinking_id, - action_info["action_message"], - ) - return { - "action_type": action_info["action_type"], - "success": success, - "reply_text": reply_text, - "command": command, - } - else: - try: - success, response_set, prompt_selected_expressions = await generator_api.generate_reply( - chat_stream=self.chat_stream, - reply_message=action_info["action_message"], - available_actions=available_actions, - choosen_actions=actions, - reply_reason=action_info.get("reasoning", ""), - enable_tool=global_config.tool.enable_tool, - request_type="replyer", - from_plugin=False, - return_expressions=True, - ) - - if prompt_selected_expressions and len(prompt_selected_expressions) > 1: - _, selected_expressions = prompt_selected_expressions - else: - selected_expressions = [] - - if not success or not response_set: - logger.info( - f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败" - ) - return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} - - except asyncio.CancelledError: - logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消") - return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} - - loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( - response_set=response_set, - action_message=action_info["action_message"], - cycle_timers=cycle_timers, - thinking_id=thinking_id, - actions=actions, - selected_expressions=selected_expressions, - ) - return { - "action_type": "reply", - "success": True, - "reply_text": reply_text, - "loop_info": loop_info, - } - except Exception as e: - logger.error(f"{self.log_prefix} 执行动作时出错: {e}") - logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") - return { - "action_type": action_info["action_type"], - "success": False, - "reply_text": "", - "loop_info": None, - "error": str(e), - } - - action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions] + action_tasks = [ + asyncio.create_task( + self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers) + ) + for action in action_to_use_info + ] # 并行执行所有任务 results = await asyncio.gather(*action_tasks, return_exceptions=True) @@ -529,7 +451,7 @@ class HeartFChatting: logger.error(f"{self.log_prefix} 动作执行异常: {result}") continue - _cur_action = actions[i] + _cur_action = action_to_use_info[i] if result["action_type"] != "reply": action_success = result["success"] action_reply_text = result["reply_text"] @@ -558,7 +480,7 @@ class HeartFChatting: # 没有回复信息,构建纯动作的loop_info loop_info = { "loop_plan_info": { - "action_result": actions, + "action_result": action_to_use_info, }, "loop_action_info": { "action_taken": action_success, @@ -578,7 +500,7 @@ class HeartFChatting: # await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", "")) - action_type = actions[0]["action_type"] if actions else "no_action" + action_type = action_to_use_info[0].action_type if action_to_use_info else "no_action" # 管理no_action计数器:当执行了非no_action动作时,重置计数器 if action_type != "no_action": @@ -620,7 +542,7 @@ class HeartFChatting: action_data: dict, cycle_timers: Dict[str, float], thinking_id: str, - action_message: dict, + action_message: Optional["DatabaseMessages"] = None, ) -> tuple[bool, str, str]: """ 处理规划动作,使用动作工厂创建相应的动作处理器 @@ -672,8 +594,8 @@ class HeartFChatting: async def _send_response( self, reply_set, - message_data, - selected_expressions: List[int] = None, + message_data: "DatabaseMessages", + selected_expressions: Optional[List[int]] = None, ) -> str: new_message_count = message_api.count_new_messages( chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time() @@ -710,3 +632,97 @@ class HeartFChatting: reply_text += data return reply_text + + async def _execute_action( + self, + action_planner_info: ActionPlannerInfo, + chosen_action_plan_infos: List[ActionPlannerInfo], + thinking_id: str, + available_actions: Dict[str, ActionInfo], + cycle_timers: Dict[str, float], + ): + """执行单个动作的通用函数""" + try: + if action_planner_info.action_type == "no_action": + # 直接处理no_action逻辑,不再通过动作系统 + reason = action_planner_info.reasoning or "选择不回复" + logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") + + # 存储no_action信息到数据库 + await database_api.store_action_info( + chat_stream=self.chat_stream, + action_build_into_prompt=False, + action_prompt_display=reason, + action_done=True, + thinking_id=thinking_id, + action_data={"reason": reason}, + action_name="no_action", + ) + + return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} + elif action_planner_info.action_type != "reply": + # 执行普通动作 + with Timer("动作执行", cycle_timers): + success, reply_text, command = await self._handle_action( + action_planner_info.action_type, + action_planner_info.reasoning or "", + action_planner_info.action_data or {}, + cycle_timers, + thinking_id, + action_planner_info.action_message, + ) + return { + "action_type": action_planner_info.action_type, + "success": success, + "reply_text": reply_text, + "command": command, + } + else: + try: + success, response_set, prompt, selected_expressions = await generator_api.generate_reply( + chat_stream=self.chat_stream, + reply_message=action_planner_info.action_message, + available_actions=available_actions, + chosen_actions=chosen_action_plan_infos, + reply_reason=action_planner_info.reasoning or "", + enable_tool=global_config.tool.enable_tool, + request_type="replyer", + from_plugin=False, + return_expressions=True, + ) + + if not success or not response_set: + if action_planner_info.action_message: + logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败") + else: + logger.info("回复生成失败") + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} + + except asyncio.CancelledError: + logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消") + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} + + loop_info, reply_text, _ = await self._send_and_store_reply( + response_set=response_set, + action_message=action_planner_info.action_message, # type: ignore + cycle_timers=cycle_timers, + thinking_id=thinking_id, + actions=chosen_action_plan_infos, + selected_expressions=selected_expressions, + ) + return { + "action_type": "reply", + "success": True, + "reply_text": reply_text, + "loop_info": loop_info, + } + except Exception as e: + logger.error(f"{self.log_prefix} 执行动作时出错: {e}") + logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") + return { + "action_type": action_planner_info.action_type, + "success": False, + "reply_text": "", + "loop_info": None, + "error": str(e), + } diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 781b1152..8716d6bc 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -303,4 +303,4 @@ init_prompt() try: expression_selector = ExpressionSelector() except Exception as e: - print(f"ExpressionSelector初始化失败: {e}") + logger.error(f"ExpressionSelector初始化失败: {e}") diff --git a/src/chat/frequency_control/focus_value_control.py b/src/chat/frequency_control/focus_value_control.py index 0c2b323d..290dcc9e 100644 --- a/src/chat/frequency_control/focus_value_control.py +++ b/src/chat/frequency_control/focus_value_control.py @@ -4,44 +4,43 @@ from src.chat.frequency_control.utils import parse_stream_config_to_chat_id class FocusValueControl: - def __init__(self,chat_id:str): + def __init__(self, chat_id: str): self.chat_id = chat_id - self.focus_value_adjust = 1 - - + self.focus_value_adjust: float = 1 + def get_current_focus_value(self) -> float: return get_current_focus_value(self.chat_id) * self.focus_value_adjust - + class FocusValueControlManager: def __init__(self): - self.focus_value_controls = {} - - def get_focus_value_control(self,chat_id:str) -> FocusValueControl: + self.focus_value_controls: dict[str, FocusValueControl] = {} + + def get_focus_value_control(self, chat_id: str) -> FocusValueControl: if chat_id not in self.focus_value_controls: self.focus_value_controls[chat_id] = FocusValueControl(chat_id) return self.focus_value_controls[chat_id] - def get_current_focus_value(chat_id: Optional[str] = None) -> float: """ 根据当前时间和聊天流获取对应的 focus_value """ if not global_config.chat.focus_value_adjust: return global_config.chat.focus_value - + if chat_id: stream_focus_value = get_stream_specific_focus_value(chat_id) if stream_focus_value is not None: return stream_focus_value - + global_focus_value = get_global_focus_value() if global_focus_value is not None: return global_focus_value - + return global_config.chat.focus_value + def get_stream_specific_focus_value(chat_id: str) -> Optional[float]: """ 获取特定聊天流在当前时间的专注度 @@ -140,4 +139,5 @@ def get_global_focus_value() -> Optional[float]: return None -focus_value_control = FocusValueControlManager() \ No newline at end of file + +focus_value_control = FocusValueControlManager() diff --git a/src/chat/frequency_control/talk_frequency_control.py b/src/chat/frequency_control/talk_frequency_control.py index 382a06ba..ad81fbd8 100644 --- a/src/chat/frequency_control/talk_frequency_control.py +++ b/src/chat/frequency_control/talk_frequency_control.py @@ -2,20 +2,21 @@ from typing import Optional from src.config.config import global_config from src.chat.frequency_control.utils import parse_stream_config_to_chat_id + class TalkFrequencyControl: - def __init__(self,chat_id:str): + def __init__(self, chat_id: str): self.chat_id = chat_id - self.talk_frequency_adjust = 1 - + self.talk_frequency_adjust: float = 1 + def get_current_talk_frequency(self) -> float: return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust - + class TalkFrequencyControlManager: def __init__(self): self.talk_frequency_controls = {} - - def get_talk_frequency_control(self,chat_id:str) -> TalkFrequencyControl: + + def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl: if chat_id not in self.talk_frequency_controls: self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id) return self.talk_frequency_controls[chat_id] @@ -44,6 +45,7 @@ def get_current_talk_frequency(chat_id: Optional[str] = None) -> float: global_frequency = get_global_frequency() return global_config.chat.talk_frequency if global_frequency is None else global_frequency + def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]: """ 根据时间配置列表获取当前时段的频率 @@ -124,6 +126,7 @@ def get_stream_specific_frequency(chat_stream_id: str): return None + def get_global_frequency() -> Optional[float]: """ 获取全局默认频率配置 @@ -141,4 +144,5 @@ def get_global_frequency() -> Optional[float]: return None -talk_frequency_control = TalkFrequencyControlManager() \ No newline at end of file + +talk_frequency_control = TalkFrequencyControlManager() diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 41ba6942..8227306f 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -12,7 +12,7 @@ from src.chat.message_receive.storage import MessageStorage from src.chat.heart_flow.heartflow import heartflow from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.timer_calculator import Timer -from src.chat.utils.chat_message_builder import replace_user_references_sync +from src.chat.utils.chat_message_builder import replace_user_references from src.common.logger import get_logger from src.mood.mood_manager import mood_manager from src.person_info.person_info import Person @@ -131,7 +131,7 @@ class HeartFCMessageReceiver: processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text) # 应用用户引用格式替换,将回复和@格式转换为可读格式 - processed_plain_text = replace_user_references_sync( + processed_plain_text = replace_user_references( processed_plain_text, message.message_info.platform, # type: ignore replace_bot_name=True diff --git a/src/chat/knowledge/__init__.py b/src/chat/knowledge/__init__.py index e69de29b..38f88e10 100644 --- a/src/chat/knowledge/__init__.py +++ b/src/chat/knowledge/__init__.py @@ -0,0 +1,82 @@ +from src.chat.knowledge.embedding_store import EmbeddingManager +from src.chat.knowledge.qa_manager import QAManager +from src.chat.knowledge.kg_manager import KGManager +from src.chat.knowledge.global_logger import logger +from src.config.config import global_config +import os + +INVALID_ENTITY = [ + "", + "你", + "他", + "她", + "它", + "我们", + "你们", + "他们", + "她们", + "它们", +] + +RAG_GRAPH_NAMESPACE = "rag-graph" +RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" +RAG_PG_HASH_NAMESPACE = "rag-pg-hash" + + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +DATA_PATH = os.path.join(ROOT_PATH, "data") + + +qa_manager = None +inspire_manager = None + +def lpmm_start_up(): # sourcery skip: extract-duplicate-method + # 检查LPMM知识库是否启用 + if global_config.lpmm_knowledge.enable: + logger.info("正在初始化Mai-LPMM") + logger.info("创建LLM客户端") + + # 初始化Embedding库 + embed_manager = EmbeddingManager() + logger.info("正在从文件加载Embedding库") + try: + embed_manager.load_from_file() + except Exception as e: + logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}") + # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") + logger.info("Embedding库加载完成") + # 初始化KG + kg_manager = KGManager() + logger.info("正在从文件加载KG") + try: + kg_manager.load_from_file() + except Exception as e: + logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}") + # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") + logger.info("KG加载完成") + + logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") + logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") + + # 数据比对:Embedding库与KG的段落hash集合 + for pg_hash in kg_manager.stored_paragraph_hashes: + # 使用与EmbeddingStore中一致的命名空间格式 + key = f"paragraph-{pg_hash}" + if key not in embed_manager.stored_pg_hashes: + logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") + global qa_manager + # 问答系统(用于知识库) + qa_manager = QAManager( + embed_manager, + kg_manager, + ) + + # # 记忆激活(用于记忆库) + # global inspire_manager + # inspire_manager = MemoryActiveManager( + # embed_manager, + # llm_client_list[global_config["embedding"]["provider"]], + # ) + else: + logger.info("LPMM知识库已禁用,跳过初始化") + # 创建空的占位符对象,避免导入错误 diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index 340a678d..4f7bb68a 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -5,7 +5,7 @@ from typing import List, Union from .global_logger import logger from . import prompt_template -from .knowledge_lib import INVALID_ENTITY +from . import INVALID_ENTITY from src.llm_models.utils_model import LLMRequest from json_repair import repair_json diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py deleted file mode 100644 index f3e6eca6..00000000 --- a/src/chat/knowledge/knowledge_lib.py +++ /dev/null @@ -1,80 +0,0 @@ -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.qa_manager import QAManager -from src.chat.knowledge.kg_manager import KGManager -from src.chat.knowledge.global_logger import logger -from src.config.config import global_config -import os - -INVALID_ENTITY = [ - "", - "你", - "他", - "她", - "它", - "我们", - "你们", - "他们", - "她们", - "它们", -] - -RAG_GRAPH_NAMESPACE = "rag-graph" -RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" -RAG_PG_HASH_NAMESPACE = "rag-pg-hash" - - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -DATA_PATH = os.path.join(ROOT_PATH, "data") - - -qa_manager = None -inspire_manager = None - -# 检查LPMM知识库是否启用 -if global_config.lpmm_knowledge.enable: - logger.info("正在初始化Mai-LPMM") - logger.info("创建LLM客户端") - - # 初始化Embedding库 - embed_manager = EmbeddingManager() - logger.info("正在从文件加载Embedding库") - try: - embed_manager.load_from_file() - except Exception as e: - logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") - logger.info("Embedding库加载完成") - # 初始化KG - kg_manager = KGManager() - logger.info("正在从文件加载KG") - try: - kg_manager.load_from_file() - except Exception as e: - logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") - logger.info("KG加载完成") - - logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") - logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") - - # 数据比对:Embedding库与KG的段落hash集合 - for pg_hash in kg_manager.stored_paragraph_hashes: - # 使用与EmbeddingStore中一致的命名空间格式 - key = f"paragraph-{pg_hash}" - if key not in embed_manager.stored_pg_hashes: - logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") - - # 问答系统(用于知识库) - qa_manager = QAManager( - embed_manager, - kg_manager, - ) - - # # 记忆激活(用于记忆库) - # inspire_manager = MemoryActiveManager( - # embed_manager, - # llm_client_list[global_config["embedding"]["provider"]], - # ) -else: - logger.info("LPMM知识库已禁用,跳过初始化") - # 创建空的占位符对象,避免导入错误 diff --git a/src/chat/knowledge/open_ie.py b/src/chat/knowledge/open_ie.py index 90977fb8..b7ad2060 100644 --- a/src/chat/knowledge/open_ie.py +++ b/src/chat/knowledge/open_ie.py @@ -4,7 +4,7 @@ import glob from typing import Any, Dict, List -from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH +from . import INVALID_ENTITY, ROOT_PATH, DATA_PATH # from src.manager.local_store_manager import local_storage diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index b8b31efb..6bbc1dd5 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -60,7 +60,7 @@ class QAManager: for res in relation_search_res: if store_item := self.embed_manager.relation_embedding_store.store.get(res[0]): rel_str = store_item.str - print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") + logger.info(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") # TODO: 使用LLM过滤三元组结果 # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s") @@ -94,7 +94,7 @@ class QAManager: for res in result: raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str - print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") + logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") return result, ppr_node_weights diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 04ceccb2..1b15d717 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -30,9 +30,7 @@ def cosine_similarity(v1, v2): dot_product = np.dot(v1, v2) norm1 = np.linalg.norm(v1) norm2 = np.linalg.norm(v2) - if norm1 == 0 or norm2 == 0: - return 0 - return dot_product / (norm1 * norm2) + return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2) install(extra_lines=3) @@ -142,11 +140,10 @@ class MemoryGraph: # 获取当前节点的记忆项 node_data = self.get_dot(topic) if node_data: - concept, data = node_data + _, data = node_data if "memory_items" in data: - memory_items = data["memory_items"] # 直接使用完整的记忆内容 - if memory_items: + if memory_items := data["memory_items"]: first_layer_items.append(memory_items) # 只在depth=2时获取第二层记忆 @@ -154,11 +151,10 @@ class MemoryGraph: # 获取相邻节点的记忆项 for neighbor in neighbors: if node_data := self.get_dot(neighbor): - concept, data = node_data + _, data = node_data if "memory_items" in data: - memory_items = data["memory_items"] # 直接使用完整的记忆内容 - if memory_items: + if memory_items := data["memory_items"]: second_layer_items.append(memory_items) return first_layer_items, second_layer_items @@ -224,27 +220,17 @@ class MemoryGraph: # 获取话题节点数据 node_data = self.G.nodes[topic] + # 删除整个节点 + self.G.remove_node(topic) # 如果节点存在memory_items if "memory_items" in node_data: - memory_items = node_data["memory_items"] - - # 既然每个节点现在是一个完整的记忆内容,直接删除整个节点 - if memory_items: - # 删除整个节点 - self.G.remove_node(topic) + if memory_items := node_data["memory_items"]: return ( f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." if len(memory_items) > 50 else f"删除了节点 {topic} 的完整记忆: {memory_items}" ) - else: - # 如果没有记忆项,删除该节点 - self.G.remove_node(topic) - return None - else: - # 如果没有memory_items字段,删除该节点 - self.G.remove_node(topic) - return None + return None # 海马体 @@ -392,9 +378,8 @@ class Hippocampus: # 如果相似度超过阈值,获取该节点的记忆 if similarity >= 0.3: # 可以调整这个阈值 node_data = self.memory_graph.G.nodes[node] - memory_items = node_data.get("memory_items", "") # 直接使用完整的记忆内容 - if memory_items: + if memory_items := node_data.get("memory_items", ""): memories.append((node, memory_items, similarity)) # 按相似度降序排序 @@ -411,7 +396,7 @@ class Hippocampus: 如果为False,使用LLM提取关键词,速度较慢但更准确。 """ if not text: - return [] + return [], [] # 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算 text_length = len(text) @@ -587,7 +572,7 @@ class Hippocampus: unique_memories = [] for topic, memory_items, activation_value in all_memories: # memory_items现在是完整的字符串格式 - memory = memory_items if memory_items else "" + memory = memory_items or "" if memory not in seen_memories: seen_memories.add(memory) unique_memories.append((topic, memory_items, activation_value)) @@ -599,7 +584,7 @@ class Hippocampus: result = [] for topic, memory_items, _ in unique_memories: # memory_items现在是完整的字符串格式 - memory = memory_items if memory_items else "" + memory = memory_items or "" result.append((topic, memory)) logger.debug(f"选中记忆: {memory} (来自节点: {topic})") @@ -1435,13 +1420,11 @@ class HippocampusManager: if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") try: - response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text( - text, max_depth, fast_retrieval - ) + return await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) except Exception as e: logger.error(f"文本产生激活值失败: {e}") logger.error(traceback.format_exc()) - return 0.0, [], [] + return 0.0, [], [] def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: """从关键词获取相关记忆的公共接口""" @@ -1473,6 +1456,7 @@ class MemoryBuilder: self.last_processed_time: float = 0.0 def should_trigger_memory_build(self) -> bool: + # sourcery skip: assign-if-exp, boolean-if-exp-identity, reintroduce-else """检查是否应该触发记忆构建""" current_time = time.time() diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index a6be80ef..f8e91b5c 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -11,7 +11,7 @@ from datetime import datetime, timedelta from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.common.database.database_model import Memory # Peewee Models导入 -from src.config.config import model_config +from src.config.config import model_config, global_config logger = get_logger(__name__) @@ -42,7 +42,7 @@ class InstantMemory: request_type="memory.summary", ) - async def if_need_build(self, text): + async def if_need_build(self, text: str): prompt = f""" 请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0 {text} @@ -51,8 +51,9 @@ class InstantMemory: try: response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) - print(prompt) - print(response) + if global_config.debug.show_prompt: + print(prompt) + print(response) return "1" in response except Exception as e: @@ -94,7 +95,7 @@ class InstantMemory: logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}") return None - async def create_and_store_memory(self, text): + async def create_and_store_memory(self, text: str): if_need = await self.if_need_build(text) if if_need: logger.info(f"需要记忆:{text}") @@ -126,24 +127,25 @@ class InstantMemory: from json_repair import repair_json prompt = f""" - 请根据以下发言内容,判断是否需要提取记忆 - {target} - 请用json格式输出,包含以下字段: - 其中,time的要求是: - 可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD - 可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前 - 可以选择留空进行模糊搜索 - {{ - "need_memory": 1, - "keywords": "希望获取的记忆关键词,用/划分", - "time": "希望获取的记忆大致时间" - }} - 请只输出json格式,不要输出其他多余内容 - """ +请根据以下发言内容,判断是否需要提取记忆 +{target} +请用json格式输出,包含以下字段: +其中,time的要求是: +可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD +可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前 +可以选择留空进行模糊搜索 +{{ + "need_memory": 1, + "keywords": "希望获取的记忆关键词,用/划分", + "time": "希望获取的记忆大致时间" +}} +请只输出json格式,不要输出其他多余内容 +""" try: response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) - print(prompt) - print(response) + if global_config.debug.show_prompt: + print(prompt) + print(response) if not response: return None try: diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index beae4136..bb667cbf 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -145,7 +145,7 @@ class ChatBot: logger.error(f"处理命令时出错: {e}") return False, None, True # 出错时继续处理消息 - async def hanle_notice_message(self, message: MessageRecv): + async def handle_notice_message(self, message: MessageRecv): if message.message_info.message_id == "notice": message.is_notify = True logger.info("notice消息") @@ -212,7 +212,7 @@ class ChatBot: # logger.debug(str(message_data)) message = MessageRecv(message_data) - if await self.hanle_notice_message(message): + if await self.handle_notice_message(message): # return pass diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 098e6600..66a1c029 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -115,7 +115,7 @@ class MessageRecv(Message): self.priority_mode = "interest" self.priority_info = None self.interest_value: float = None # type: ignore - + self.key_words = [] self.key_words_lite = [] @@ -213,9 +213,9 @@ class MessageRecvS4U(MessageRecv): self.is_screen = False self.is_internal = False self.voice_done = None - + self.chat_info = None - + async def process(self) -> None: self.processed_plain_text = await self._process_message_segments(self.message_segment) @@ -420,7 +420,7 @@ class MessageSending(MessageProcessBase): thinking_start_time: float = 0, apply_set_reply_logic: bool = False, reply_to: Optional[str] = None, - selected_expressions:List[int] = None, + selected_expressions: Optional[List[int]] = None, ): # 调用父类初始化 super().__init__( @@ -445,7 +445,7 @@ class MessageSending(MessageProcessBase): self.display_message = display_message self.interest_value = 0.0 - + self.selected_expressions = selected_expressions def build_reply(self): diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 267b7a8f..b4587474 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -2,6 +2,7 @@ from typing import Dict, Optional, Type from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages from src.plugin_system.core.component_registry import component_registry from src.plugin_system.base.component_types import ComponentType, ActionInfo from src.plugin_system.base.base_action import BaseAction @@ -37,7 +38,7 @@ class ActionManager: chat_stream: ChatStream, log_prefix: str, shutting_down: bool = False, - action_message: Optional[dict] = None, + action_message: Optional[DatabaseMessages] = None, ) -> Optional[BaseAction]: """ 创建动作处理器实例 @@ -83,7 +84,7 @@ class ActionManager: log_prefix=log_prefix, shutting_down=shutting_down, plugin_config=plugin_config, - action_message=action_message, + action_message=action_message.flatten() if action_message else None, ) logger.debug(f"创建Action实例成功: {action_name}") @@ -123,4 +124,4 @@ class ActionManager: """恢复到默认动作集""" actions_to_restore = list(self._using_actions.keys()) self._using_actions = component_registry.get_default_actions() - logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") + logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") \ No newline at end of file diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 5e0695c3..2cb2a469 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -1,7 +1,7 @@ import json import time import traceback -from typing import Dict, Any, Optional, Tuple, List +from typing import Dict, Optional, Tuple, List from rich.traceback import install from datetime import datetime from json_repair import repair_json @@ -9,6 +9,8 @@ from json_repair import repair_json from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import ActionPlannerInfo from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( build_readable_actions, @@ -97,7 +99,9 @@ class ActionPlanner: self.plan_retry_count = 0 self.max_plan_retries = 3 - def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: + def find_message_by_id( + self, message_id: str, message_id_list: List[DatabaseMessages] + ) -> Optional[DatabaseMessages]: # sourcery skip: use-next """ 根据message_id从message_id_list中查找对应的原始消息 @@ -110,37 +114,37 @@ class ActionPlanner: 找到的原始消息字典,如果未找到则返回None """ for item in message_id_list: - if item.get("id") == message_id: - return item.get("message") + if item.message_id == message_id: + return item return None - def get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]: + def get_latest_message(self, message_id_list: List[DatabaseMessages]) -> Optional[DatabaseMessages]: """ 获取消息列表中的最新消息 - + Args: message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...] - + Returns: 最新的消息字典,如果列表为空则返回None """ - return message_id_list[-1].get("message") if message_id_list else None + return message_id_list[-1] if message_id_list else None async def plan( self, mode: ChatMode = ChatMode.FOCUS, - loop_start_time:float = 0.0, + loop_start_time: float = 0.0, available_actions: Optional[Dict[str, ActionInfo]] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + ) -> Tuple[List[ActionPlannerInfo], Optional[DatabaseMessages]]: """ 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 """ - action = "no_action" # 默认动作 - reasoning = "规划器初始化默认" + action: str = "no_action" # 默认动作 + reasoning: str = "规划器初始化默认" action_data = {} current_available_actions: Dict[str, ActionInfo] = {} - target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量 + target_message: Optional[DatabaseMessages] = None # 初始化target_message变量 prompt: str = "" message_id_list: list = [] @@ -208,19 +212,21 @@ class ActionPlanner: # 如果获取的target_message为None,输出warning并重新plan if target_message is None: self.plan_retry_count += 1 - logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}") + logger.warning( + f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}" + ) # 仍有重试次数 if self.plan_retry_count < self.max_plan_retries: # 递归重新plan return await self.plan(mode, loop_start_time, available_actions) - logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message") + logger.error( + f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message" + ) target_message = self.get_latest_message(message_id_list) self.plan_retry_count = 0 # 重置计数器 else: logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id") - - if action != "no_action" and action != "reply" and action not in current_available_actions: logger.warning( f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'" @@ -244,38 +250,37 @@ class ActionPlanner: if mode == ChatMode.NORMAL and action in current_available_actions: is_parallel = current_available_actions[action].parallel_action - action_data["loop_start_time"] = loop_start_time actions = [ - { - "action_type": action, - "reasoning": reasoning, - "action_data": action_data, - "action_message": target_message, - "available_actions": available_actions, - } + ActionPlannerInfo( + action_type=action, + reasoning=reasoning, + action_data=action_data, + action_message=target_message, + available_actions=available_actions, + ) ] if action != "reply" and is_parallel: - actions.append({ - "action_type": "reply", - "action_message": target_message, - "available_actions": available_actions - }) + actions.append( + ActionPlannerInfo( + action_type="reply", + action_message=target_message, + available_actions=available_actions, + ) + ) - return actions,target_message - - + return actions, target_message async def build_planner_prompt( self, is_group_chat: bool, # Now passed as argument chat_target_info: Optional[dict], # Now passed as argument current_available_actions: Dict[str, ActionInfo], - refresh_time :bool = False, + refresh_time: bool = False, mode: ChatMode = ChatMode.FOCUS, - ) -> tuple[str, list]: # sourcery skip: use-join + ) -> tuple[str, List[DatabaseMessages]]: # sourcery skip: use-join """构建 Planner LLM 的提示词 (获取模板并填充数据)""" try: message_list_before_now = get_raw_msg_before_timestamp_with_chat( @@ -305,13 +310,12 @@ class ActionPlanner: actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" if refresh_time: self.last_obs_time_mark = time.time() - + mentioned_bonus = "" if global_config.chat.mentioned_bot_inevitable_reply: mentioned_bonus = "\n- 有人提到你" if global_config.chat.at_bot_inevitable_reply: mentioned_bonus = "\n- 有人提到你,或者at你" - if mode == ChatMode.FOCUS: no_action_block = """ @@ -332,7 +336,7 @@ class ActionPlanner: """ chat_context_description = "你现在正在一个群聊中" - chat_target_name = None + chat_target_name = None if not is_group_chat and chat_target_info: chat_target_name = ( chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方" @@ -388,7 +392,7 @@ class ActionPlanner: action_options_text=action_options_block, moderation_prompt=moderation_prompt_block, identity_block=identity_block, - plan_style = global_config.personality.plan_style + plan_style=global_config.personality.plan_style, ) return prompt, message_id_list except Exception as e: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 87ff7bdb..6ef225e4 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -9,6 +9,7 @@ from datetime import datetime from src.mais4u.mai_think import mai_thinking_manager from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import ActionPlannerInfo from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality from src.llm_models.utils_model import LLMRequest @@ -21,7 +22,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( build_readable_messages, get_raw_msg_before_timestamp_with_chat, - replace_user_references_sync, + replace_user_references, ) from src.chat.express.expression_selector import expression_selector from src.chat.memory_system.memory_activator import MemoryActivator @@ -157,12 +158,12 @@ class DefaultReplyer: extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - chosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List[ActionPlannerInfo]] = None, enable_tool: bool = True, from_plugin: bool = True, stream_id: Optional[str] = None, - reply_message: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]: + reply_message: Optional[DatabaseMessages] = None, + ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], Optional[List[int]]]: # sourcery skip: merge-nested-ifs """ 回复器 (Replier): 负责生成回复文本的核心逻辑。 @@ -181,7 +182,7 @@ class DefaultReplyer: """ prompt = None - selected_expressions = None + selected_expressions: Optional[List[int]] = None if available_actions is None: available_actions = {} try: @@ -374,7 +375,12 @@ class DefaultReplyer: ) if global_config.memory.enable_instant_memory: - asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history)) + chat_history_str = build_readable_messages( + messages=chat_history, + replace_bot_name=True, + timestamp_mode="normal" + ) + asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history_str)) instant_memory = await self.instant_memory.get_memory(target) logger.info(f"即时记忆:{instant_memory}") @@ -527,7 +533,7 @@ class DefaultReplyer: Returns: Tuple[str, str]: (核心对话prompt, 背景对话prompt) """ - core_dialogue_list = [] + core_dialogue_list: List[DatabaseMessages] = [] bot_id = str(global_config.bot.qq_account) # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话 @@ -559,7 +565,7 @@ class DefaultReplyer: if core_dialogue_list: # 检查最新五条消息中是否包含bot自己说的消息 latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list - has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) + has_bot_message = any(str(msg.user_info.user_id) == bot_id for msg in latest_5_messages) # logger.info(f"最新五条消息:{latest_5_messages}") # logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}") @@ -634,7 +640,7 @@ class DefaultReplyer: return mai_think async def build_actions_prompt( - self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None + self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None ) -> str: """构建动作提示""" @@ -646,20 +652,21 @@ class DefaultReplyer: action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += "\n" - choosen_action_descriptions = "" - if choosen_actions: - for action in choosen_actions: - action_name = action.get("action_type", "unknown_action") + chosen_action_descriptions = "" + if chosen_actions_info: + for action_plan_info in chosen_actions_info: + action_name = action_plan_info.action_type if action_name == "reply": continue - action_description = action.get("reason", "无描述") - reasoning = action.get("reasoning", "无原因") + if action := available_actions.get(action_name): + action_description = action.description or "无描述" + reasoning = action_plan_info.reasoning or "无原因" - choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n" + chosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n" - if choosen_action_descriptions: + if chosen_action_descriptions: action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n" - action_descriptions += choosen_action_descriptions + action_descriptions += chosen_action_descriptions return action_descriptions @@ -668,9 +675,9 @@ class DefaultReplyer: extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - chosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List[ActionPlannerInfo]] = None, enable_tool: bool = True, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional[DatabaseMessages] = None, ) -> Tuple[str, List[int]]: """ 构建回复器上下文 @@ -694,11 +701,11 @@ class DefaultReplyer: platform = chat_stream.platform if reply_message: - user_id = reply_message.get("user_id", "") + user_id = reply_message.user_info.user_id person = Person(platform=platform, user_id=user_id) person_name = person.person_name or user_id sender = person_name - target = reply_message.get("processed_plain_text") + target = reply_message.processed_plain_text else: person_name = "用户" sender = "用户" @@ -710,7 +717,7 @@ class DefaultReplyer: else: mood_prompt = "" - target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) + target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, @@ -774,11 +781,13 @@ class DefaultReplyer: logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s") expression_habits_block, selected_expressions = results_dict["expression_habits"] - relation_info = results_dict["relation_info"] - memory_block = results_dict["memory_block"] - tool_info = results_dict["tool_info"] - prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果 - actions_info = results_dict["actions_info"] + expression_habits_block: str + selected_expressions: List[int] + relation_info: str = results_dict["relation_info"] + memory_block: str = results_dict["memory_block"] + tool_info: str = results_dict["tool_info"] + prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果 + actions_info: str = results_dict["actions_info"] keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target) if extra_info: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 51ecb46d..51edd045 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -19,8 +19,8 @@ install(extra_lines=3) logger = get_logger("chat_message_builder") -def replace_user_references_sync( - content: str, +def replace_user_references( + content: Optional[str], platform: str, name_resolver: Optional[Callable[[str, str], str]] = None, replace_bot_name: bool = True, @@ -38,6 +38,8 @@ def replace_user_references_sync( Returns: str: 处理后的内容字符串 """ + if not content: + return "" if name_resolver is None: def default_resolver(platform: str, user_id: str) -> str: @@ -93,80 +95,6 @@ def replace_user_references_sync( return content -async def replace_user_references_async( - content: str, - platform: str, - name_resolver: Optional[Callable[[str, str], Any]] = None, - replace_bot_name: bool = True, -) -> str: - """ - 替换内容中的用户引用格式,包括回复和@格式 - - Args: - content: 要处理的内容字符串 - platform: 平台标识 - name_resolver: 名称解析函数,接收(platform, user_id)参数,返回用户名称 - 如果为None,则使用默认的person_info_manager - replace_bot_name: 是否将机器人的user_id替换为"机器人昵称(你)" - - Returns: - str: 处理后的内容字符串 - """ - if name_resolver is None: - - async def default_resolver(platform: str, user_id: str) -> str: - # 检查是否是机器人自己 - if replace_bot_name and user_id == global_config.bot.qq_account: - return f"{global_config.bot.nickname}(你)" - person = Person(platform=platform, user_id=user_id) - return person.person_name or user_id # type: ignore - - name_resolver = default_resolver - - # 处理回复格式 - reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" - match = re.search(reply_pattern, content) - if match: - aaa = match.group(1) - bbb = match.group(2) - try: - # 检查是否是机器人自己 - if replace_bot_name and bbb == global_config.bot.qq_account: - reply_person_name = f"{global_config.bot.nickname}(你)" - else: - reply_person_name = await name_resolver(platform, bbb) or aaa - content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1) - except Exception: - # 如果解析失败,使用原始昵称 - content = re.sub(reply_pattern, f"回复 {aaa}", content, count=1) - - # 处理@格式 - at_pattern = r"@<([^:<>]+):([^:<>]+)>" - at_matches = list(re.finditer(at_pattern, content)) - if at_matches: - new_content = "" - last_end = 0 - for m in at_matches: - new_content += content[last_end : m.start()] - aaa = m.group(1) - bbb = m.group(2) - try: - # 检查是否是机器人自己 - if replace_bot_name and bbb == global_config.bot.qq_account: - at_person_name = f"{global_config.bot.nickname}(你)" - else: - at_person_name = await name_resolver(platform, bbb) or aaa - new_content += f"@{at_person_name}" - except Exception: - # 如果解析失败,使用原始昵称 - new_content += f"@{aaa}" - last_end = m.end() - new_content += content[last_end:] - content = new_content - - return content - - def get_raw_msg_by_timestamp(timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"): """ 获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 @@ -498,7 +426,7 @@ def _build_readable_messages_internal( person_name = f"{global_config.bot.nickname}(你)" # 使用独立函数处理用户引用格式 - if content := replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name): + if content := replace_user_references(content, platform, replace_bot_name=replace_bot_name): detailed_messages_raw.append((timestamp, person_name, content, False)) if not detailed_messages_raw: @@ -658,7 +586,10 @@ async def build_readable_messages_with_list( 允许通过参数控制格式化行为。 """ formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( - convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate + [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages], + replace_bot_name, + timestamp_mode, + truncate, ) if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): @@ -725,19 +656,7 @@ def build_readable_messages( if not messages: return "" - copy_messages: List[MessageAndActionModel] = [ - MessageAndActionModel( - msg.time, - msg.user_info.user_id, - msg.user_info.platform, - msg.user_info.user_nickname, - msg.user_info.user_cardname, - msg.processed_plain_text, - msg.display_message, - msg.chat_info.platform, - ) - for msg in messages - ] + copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages] if show_actions and copy_messages: # 获取所有消息的时间范围 @@ -942,7 +861,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str: except Exception: return "?" - content = replace_user_references_sync(content, platform, anon_name_resolver, replace_bot_name=False) + content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False) header = f"{anon_name}说 " output_lines.append(header) @@ -996,22 +915,3 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: person_ids_set.add(person_id) return list(person_ids_set) # 将集合转换为列表返回 - - -def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]: - """ - 将 DatabaseMessages 列表转换为 MessageAndActionModel 列表。 - """ - return [ - MessageAndActionModel( - time=msg.time, - user_id=msg.user_info.user_id, - user_platform=msg.user_info.platform, - user_nickname=msg.user_info.user_nickname, - user_cardname=msg.user_info.user_cardname, - processed_plain_text=msg.processed_plain_text, - display_message=msg.display_message, - chat_info_platform=msg.chat_info.platform, - ) - for msg in message - ] diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 3528fe4b..472a9cdd 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -11,7 +11,6 @@ from collections import Counter from typing import Optional, Tuple, Dict, List, Any from src.common.logger import get_logger -from src.common.data_models.info_data_model import TargetPersonInfo from src.common.data_models.database_data_model import DatabaseMessages from src.common.message_repository import find_messages, count_messages from src.config.config import global_config, model_config @@ -641,6 +640,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: platform: str = chat_stream.platform user_id: str = user_info.user_id # type: ignore + from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题 + # Initialize target_info with basic info target_info = TargetPersonInfo( platform=platform, diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 59761d09..1f671890 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Optional, Any, Dict from dataclasses import dataclass, field from . import BaseDataModel @@ -157,3 +157,42 @@ class DatabaseMessages(BaseDataModel): # assert isinstance(self.interest_value, float) or self.interest_value is None, ( # "interest_value must be a float or None" # ) + def flatten(self) -> Dict[str, Any]: + """ + 将消息数据模型转换为字典格式,便于存储或传输 + """ + return { + "message_id": self.message_id, + "time": self.time, + "chat_id": self.chat_id, + "reply_to": self.reply_to, + "interest_value": self.interest_value, + "key_words": self.key_words, + "key_words_lite": self.key_words_lite, + "is_mentioned": self.is_mentioned, + "processed_plain_text": self.processed_plain_text, + "display_message": self.display_message, + "priority_mode": self.priority_mode, + "priority_info": self.priority_info, + "additional_config": self.additional_config, + "is_emoji": self.is_emoji, + "is_picid": self.is_picid, + "is_command": self.is_command, + "is_notify": self.is_notify, + "selected_expressions": self.selected_expressions, + "user_id": self.user_info.user_id, + "user_nickname": self.user_info.user_nickname, + "user_cardname": self.user_info.user_cardname, + "user_platform": self.user_info.platform, + "chat_info_group_id": self.group_info.group_id if self.group_info else None, + "chat_info_group_name": self.group_info.group_name if self.group_info else None, + "chat_info_group_platform": self.group_info.group_platform if self.group_info else None, + "chat_info_stream_id": self.chat_info.stream_id, + "chat_info_platform": self.chat_info.platform, + "chat_info_create_time": self.chat_info.create_time, + "chat_info_last_active_time": self.chat_info.last_active_time, + "chat_info_user_platform": self.chat_info.user_info.platform, + "chat_info_user_id": self.chat_info.user_info.user_id, + "chat_info_user_nickname": self.chat_info.user_info.user_nickname, + "chat_info_user_cardname": self.chat_info.user_info.user_cardname, + } diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index ae3678d1..0f7b1f95 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,12 +1,25 @@ from dataclasses import dataclass, field -from typing import Optional - +from typing import Optional, Dict, TYPE_CHECKING from . import BaseDataModel +if TYPE_CHECKING: + from .database_data_model import DatabaseMessages + from src.plugin_system.base.component_types import ActionInfo + + @dataclass class TargetPersonInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) person_id: Optional[str] = None - person_name: Optional[str] = None \ No newline at end of file + person_name: Optional[str] = None + + +@dataclass +class ActionPlannerInfo(BaseDataModel): + action_type: str = field(default_factory=str) + reasoning: Optional[str] = None + action_data: Optional[Dict] = None + action_message: Optional["DatabaseMessages"] = None + available_actions: Optional[Dict[str, "ActionInfo"]] = None diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py index 0fa87ba0..8e0b7786 100644 --- a/src/common/data_models/message_data_model.py +++ b/src/common/data_models/message_data_model.py @@ -1,10 +1,15 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING from dataclasses import dataclass, field from . import BaseDataModel +if TYPE_CHECKING: + from .database_data_model import DatabaseMessages + + @dataclass class MessageAndActionModel(BaseDataModel): + chat_id: str = field(default_factory=str) time: float = field(default_factory=float) user_id: str = field(default_factory=str) user_platform: str = field(default_factory=str) @@ -15,3 +20,17 @@ class MessageAndActionModel(BaseDataModel): chat_info_platform: str = field(default_factory=str) is_action_record: bool = field(default=False) action_name: Optional[str] = None + + @classmethod + def from_DatabaseMessages(cls, message: "DatabaseMessages"): + return cls( + chat_id=message.chat_id, + time=message.time, + user_id=message.user_info.user_id, + user_platform=message.user_info.platform, + user_nickname=message.user_info.user_nickname, + user_cardname=message.user_info.user_cardname, + processed_plain_text=message.processed_plain_text, + display_message=message.display_message, + chat_info_platform=message.chat_info.platform, + ) diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index 93a41a3d..d253d29c 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -47,10 +47,13 @@ logger = get_logger("Gemini客户端") # gemini_thinking参数(默认范围) # 不同模型的思考预算范围配置 THINKING_BUDGET_LIMITS = { - "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, - "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, - "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, + "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, + "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, + "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, } +# 思维预算特殊值 +THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定 +THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用) gemini_safe_settings = [ SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE), @@ -91,9 +94,7 @@ def _convert_messages( for item in message.content: if isinstance(item, tuple): image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower() - content.append( - Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}") - ) + content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}")) elif isinstance(item, str): content.append(Part.from_text(text=item)) else: @@ -336,47 +337,40 @@ class GeminiClient(BaseClient): api_key=api_provider.api_key, ) # 这里和openai不一样,gemini会自己决定自己是否需要retry - # 思维预算特殊值 - THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定 - THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用) - @staticmethod - def clamp_thinking_budget(tb: int, model_id: str): + def clamp_thinking_budget(tb: int, model_id: str) -> int: """ 按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本) """ limits = None - matched_key = None # 优先尝试精确匹配 if model_id in THINKING_BUDGET_LIMITS: limits = THINKING_BUDGET_LIMITS[model_id] - matched_key = model_id else: # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先 sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True) for key in sorted_keys: # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界) - if model_id == key or model_id.startswith(key + "-"): - limits = THINKING_BUDGET_LIMITS[key] - matched_key = key - break + if model_id == key or model_id.startswith(f"{key}-"): + limits = THINKING_BUDGET_LIMITS[key] + break # 特殊值处理 - if tb == GeminiClient.THINKING_BUDGET_AUTO: - return GeminiClient.THINKING_BUDGET_AUTO - if tb == GeminiClient.THINKING_BUDGET_DISABLED: + if tb == THINKING_BUDGET_AUTO: + return THINKING_BUDGET_AUTO + if tb == THINKING_BUDGET_DISABLED: if limits and limits.get("can_disable", False): - return GeminiClient.THINKING_BUDGET_DISABLED - return limits["min"] if limits else GeminiClient.THINKING_BUDGET_AUTO + return THINKING_BUDGET_DISABLED + return limits["min"] if limits else THINKING_BUDGET_AUTO # 已知模型裁剪到范围 if limits: - return max(limits["min"], min(tb, limits["max"])) + return max(limits["min"], min(tb, limits["max"])) # 未知模型,返回动态模式 logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。") - return GeminiClient.THINKING_BUDGET_AUTO + return THINKING_BUDGET_AUTO async def get_response( self, @@ -424,15 +418,13 @@ class GeminiClient(BaseClient): # 将tool_options转换为Gemini API所需的格式 tools = _convert_tool_options(tool_options) if tool_options else None - tb = GeminiClient.THINKING_BUDGET_AUTO - #空处理 + tb = THINKING_BUDGET_AUTO + # 空处理 if extra_params and "thinking_budget" in extra_params: try: tb = int(extra_params["thinking_budget"]) except (ValueError, TypeError): - logger.warning( - f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}" - ) + logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}") # 裁剪到模型支持的范围 tb = self.clamp_thinking_budget(tb, model_info.model_identifier) diff --git a/src/main.py b/src/main.py index cad85a0e..51d83c13 100644 --- a/src/main.py +++ b/src/main.py @@ -13,6 +13,7 @@ from src.common.logger import get_logger from src.individuality.individuality import get_individuality, Individuality from src.common.server import get_global_server, Server from src.mood.mood_manager import mood_manager +from src.chat.knowledge import lpmm_start_up from rich.traceback import install from src.migrate_helper.migrate import check_and_run_migrations # from src.api.main import start_api_server @@ -83,6 +84,9 @@ class MainSystem: # 启动API服务器 # start_api_server() # logger.info("API服务器启动成功") + + # 启动LPMM + lpmm_start_up() # 加载所有actions,包括默认的和插件的 plugin_manager.load_all_plugins() @@ -104,7 +108,6 @@ class MainSystem: logger.info("情绪管理器初始化成功") # 初始化聊天管理器 - await get_chat_manager()._initialize() asyncio.create_task(get_chat_manager()._auto_save_task()) diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 0fe759bd..9122dc77 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -3,6 +3,7 @@ import asyncio import json import time import random +import math from json_repair import repair_json from typing import Union, Optional @@ -16,6 +17,7 @@ from src.config.config import global_config, model_config logger = get_logger("person_info") + def get_person_id(platform: str, user_id: Union[int, str]) -> str: """获取唯一id""" if "-" in platform: @@ -24,6 +26,7 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str: key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() + def get_person_id_by_person_name(person_name: str) -> str: """根据用户名获取用户ID""" try: @@ -33,7 +36,8 @@ def get_person_id_by_person_name(person_name: str) -> str: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}") return "" -def is_person_known(person_id: str = None,user_id: str = None,platform: str = None,person_name: str = None) -> bool: + +def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore if person_id: person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) return person.is_known if person else False @@ -47,89 +51,84 @@ def is_person_known(person_id: str = None,user_id: str = None,platform: str = No return person.is_known if person else False else: return False - - -def get_catagory_from_memory(memory_point:str) -> str: + + +def get_category_from_memory(memory_point: str) -> Optional[str]: """从记忆点中获取分类""" # 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类 if not isinstance(memory_point, str): return None parts = memory_point.split(":", 1) - if len(parts) > 1: - return parts[0].strip() - else: - return None - -def get_weight_from_memory(memory_point:str) -> float: + return parts[0].strip() if len(parts) > 1 else None + + +def get_weight_from_memory(memory_point: str) -> float: """从记忆点中获取权重""" # 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重 if not isinstance(memory_point, str): - return None + return -math.inf parts = memory_point.rsplit(":", 1) - if len(parts) > 1: - try: - return float(parts[-1].strip()) - except Exception: - return None - else: - return None - -def get_memory_content_from_memory(memory_point:str) -> str: + if len(parts) <= 1: + return -math.inf + try: + return float(parts[-1].strip()) + except Exception: + return -math.inf + + +def get_memory_content_from_memory(memory_point: str) -> str: """从记忆点中获取记忆内容""" # 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容 if not isinstance(memory_point, str): - return None + return "" parts = memory_point.split(":") - if len(parts) > 2: - return ":".join(parts[1:-1]).strip() - else: - return None - - + return ":".join(parts[1:-1]).strip() if len(parts) > 2 else "" + + def calculate_string_similarity(s1: str, s2: str) -> float: """ 计算两个字符串的相似度 - + Args: s1: 第一个字符串 s2: 第二个字符串 - + Returns: float: 相似度,范围0-1,1表示完全相同 """ if s1 == s2: return 1.0 - + if not s1 or not s2: return 0.0 - + # 计算Levenshtein距离 - - + distance = levenshtein_distance(s1, s2) max_len = max(len(s1), len(s2)) - + # 计算相似度:1 - (编辑距离 / 最大长度) similarity = 1 - (distance / max_len if max_len > 0 else 0) return similarity + def levenshtein_distance(s1: str, s2: str) -> int: """ 计算两个字符串的编辑距离 - + Args: s1: 第一个字符串 s2: 第二个字符串 - + Returns: int: 编辑距离 """ if len(s1) < len(s2): return levenshtein_distance(s2, s1) - + if len(s2) == 0: return len(s1) - + previous_row = range(len(s2) + 1) for i, c1 in enumerate(s1): current_row = [i + 1] @@ -139,44 +138,45 @@ def levenshtein_distance(s1: str, s2: str) -> int: substitutions = previous_row[j] + (c1 != c2) current_row.append(min(insertions, deletions, substitutions)) previous_row = current_row - + return previous_row[-1] + class Person: @classmethod def register_person(cls, platform: str, user_id: str, nickname: str): """ 注册新用户的类方法 必须输入 platform、user_id 和 nickname 参数 - + Args: platform: 平台名称 user_id: 用户ID nickname: 用户昵称 - + Returns: Person: 新注册的Person实例 """ if not platform or not user_id or not nickname: logger.error("注册用户失败:platform、user_id 和 nickname 都是必需参数") return None - + # 生成唯一的person_id person_id = get_person_id(platform, user_id) - + if is_person_known(person_id=person_id): logger.debug(f"用户 {nickname} 已存在") return Person(person_id=person_id) - + # 创建Person实例 person = cls.__new__(cls) - + # 设置基本属性 person.person_id = person_id person.platform = platform person.user_id = user_id person.nickname = nickname - + # 初始化默认值 person.is_known = True # 注册后立即标记为已认识 person.person_name = nickname # 使用nickname作为初始person_name @@ -185,34 +185,34 @@ class Person: person.know_since = time.time() person.last_know = time.time() person.memory_points = [] - + # 初始化性格特征相关字段 person.attitude_to_me = 0 person.attitude_to_me_confidence = 1 - + person.neuroticism = 5 person.neuroticism_confidence = 1 - + person.friendly_value = 50 person.friendly_value_confidence = 1 - + person.rudeness = 50 person.rudeness_confidence = 1 - + person.conscientiousness = 50 person.conscientiousness_confidence = 1 - + person.likeness = 50 person.likeness_confidence = 1 - + # 同步到数据库 person.sync_to_database() - + logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}") - + return person - - def __init__(self, platform: str = "", user_id: str = "",person_id: str = "",person_name: str = ""): + + def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""): if platform == global_config.bot.platform and user_id == global_config.bot.qq_account: self.is_known = True self.person_id = get_person_id(platform, user_id) @@ -221,10 +221,10 @@ class Person: self.nickname = global_config.bot.nickname self.person_name = global_config.bot.nickname return - + self.user_id = "" self.platform = "" - + if person_id: self.person_id = person_id elif person_name: @@ -232,7 +232,7 @@ class Person: if not self.person_id: self.is_known = False logger.warning(f"根据用户名 {person_name} 获取用户ID时,不存在用户{person_name}") - return + return elif platform and user_id: self.person_id = get_person_id(platform, user_id) self.user_id = user_id @@ -240,17 +240,16 @@ class Person: else: logger.error("Person 初始化失败,缺少必要参数") raise ValueError("Person 初始化失败,缺少必要参数") - + if not is_person_known(person_id=self.person_id): self.is_known = False logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识") self.person_name = f"未知用户{self.person_id[:4]}" return # raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识") - - + self.is_known = False - + # 初始化默认值 self.nickname = "" self.person_name: Optional[str] = None @@ -259,47 +258,47 @@ class Person: self.know_since = None self.last_know = None self.memory_points = [] - + # 初始化性格特征相关字段 - self.attitude_to_me:float = 0 - self.attitude_to_me_confidence:float = 1 - - self.neuroticism:float = 5 - self.neuroticism_confidence:float = 1 - - self.friendly_value:float = 50 - self.friendly_value_confidence:float = 1 - - self.rudeness:float = 50 - self.rudeness_confidence:float = 1 - - self.conscientiousness:float = 50 - self.conscientiousness_confidence:float = 1 - - self.likeness:float = 50 - self.likeness_confidence:float = 1 - + self.attitude_to_me: float = 0 + self.attitude_to_me_confidence: float = 1 + + self.neuroticism: float = 5 + self.neuroticism_confidence: float = 1 + + self.friendly_value: float = 50 + self.friendly_value_confidence: float = 1 + + self.rudeness: float = 50 + self.rudeness_confidence: float = 1 + + self.conscientiousness: float = 50 + self.conscientiousness_confidence: float = 1 + + self.likeness: float = 50 + self.likeness_confidence: float = 1 + # 从数据库加载数据 self.load_from_database() - + def del_memory(self, category: str, memory_content: str, similarity_threshold: float = 0.95): """ 删除指定分类和记忆内容的记忆点 - + Args: category: 记忆分类 memory_content: 要删除的记忆内容 similarity_threshold: 相似度阈值,默认0.95(95%) - + Returns: int: 删除的记忆点数量 """ if not self.memory_points: return 0 - + deleted_count = 0 memory_points_to_keep = [] - + for memory_point in self.memory_points: # 跳过None值 if memory_point is None: @@ -310,80 +309,76 @@ class Person: # 格式不正确,保留原样 memory_points_to_keep.append(memory_point) continue - + memory_category = parts[0].strip() memory_text = parts[1].strip() memory_weight = parts[2].strip() - + # 检查分类是否匹配 if memory_category != category: memory_points_to_keep.append(memory_point) continue - + # 计算记忆内容的相似度 similarity = calculate_string_similarity(memory_content, memory_text) - + # 如果相似度达到阈值,则删除(不添加到保留列表) if similarity >= similarity_threshold: deleted_count += 1 logger.debug(f"删除记忆点: {memory_point} (相似度: {similarity:.4f})") else: memory_points_to_keep.append(memory_point) - + # 更新memory_points self.memory_points = memory_points_to_keep - + # 同步到数据库 if deleted_count > 0: self.sync_to_database() logger.info(f"成功删除 {deleted_count} 个记忆点,分类: {category}") - + return deleted_count - - - def get_all_category(self): category_list = [] for memory in self.memory_points: if memory is None: continue - category = get_catagory_from_memory(memory) + category = get_category_from_memory(memory) if category and category not in category_list: category_list.append(category) return category_list - - - def get_memory_list_by_category(self,category:str): + + def get_memory_list_by_category(self, category: str): memory_list = [] for memory in self.memory_points: if memory is None: continue - if get_catagory_from_memory(memory) == category: + if get_category_from_memory(memory) == category: memory_list.append(memory) return memory_list - - def get_random_memory_by_category(self,category:str,num:int=1): + + def get_random_memory_by_category(self, category: str, num: int = 1): memory_list = self.get_memory_list_by_category(category) if len(memory_list) < num: return memory_list return random.sample(memory_list, num) - + def load_from_database(self): """从数据库加载个人信息数据""" try: # 查询数据库中的记录 record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id) - + if record: - self.user_id = record.user_id if record.user_id else "" - self.platform = record.platform if record.platform else "" - self.is_known = record.is_known if record.is_known else False - self.nickname = record.nickname if record.nickname else "" - self.person_name = record.person_name if record.person_name else self.nickname - self.name_reason = record.name_reason if record.name_reason else None - self.know_times = record.know_times if record.know_times else 0 - + self.user_id = record.user_id or "" + self.platform = record.platform or "" + self.is_known = record.is_known or False + self.nickname = record.nickname or "" + self.person_name = record.person_name or self.nickname + self.name_reason = record.name_reason or None + self.know_times = record.know_times or 0 + # 处理points字段(JSON格式的列表) if record.memory_points: try: @@ -398,53 +393,53 @@ class Person: self.memory_points = [] else: self.memory_points = [] - + # 加载性格特征相关字段 if record.attitude_to_me and not isinstance(record.attitude_to_me, str): self.attitude_to_me = record.attitude_to_me - + if record.attitude_to_me_confidence is not None: self.attitude_to_me_confidence = float(record.attitude_to_me_confidence) - + if record.friendly_value is not None: self.friendly_value = float(record.friendly_value) - + if record.friendly_value_confidence is not None: self.friendly_value_confidence = float(record.friendly_value_confidence) - + if record.rudeness is not None: self.rudeness = float(record.rudeness) - + if record.rudeness_confidence is not None: self.rudeness_confidence = float(record.rudeness_confidence) - + if record.neuroticism and not isinstance(record.neuroticism, str): self.neuroticism = float(record.neuroticism) - + if record.neuroticism_confidence is not None: self.neuroticism_confidence = float(record.neuroticism_confidence) - + if record.conscientiousness is not None: self.conscientiousness = float(record.conscientiousness) - + if record.conscientiousness_confidence is not None: self.conscientiousness_confidence = float(record.conscientiousness_confidence) - + if record.likeness is not None: self.likeness = float(record.likeness) - + if record.likeness_confidence is not None: self.likeness_confidence = float(record.likeness_confidence) - + logger.debug(f"已从数据库加载用户 {self.person_id} 的信息") else: self.sync_to_database() logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建") - + except Exception as e: logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}") # 出错时保持默认值 - + def sync_to_database(self): """将所有属性同步回数据库""" if not self.is_known: @@ -452,34 +447,38 @@ class Person: try: # 准备数据 data = { - 'person_id': self.person_id, - 'is_known': self.is_known, - 'platform': self.platform, - 'user_id': self.user_id, - 'nickname': self.nickname, - 'person_name': self.person_name, - 'name_reason': self.name_reason, - 'know_times': self.know_times, - 'know_since': self.know_since, - 'last_know': self.last_know, - 'memory_points': json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False) if self.memory_points else json.dumps([], ensure_ascii=False), - 'attitude_to_me': self.attitude_to_me, - 'attitude_to_me_confidence': self.attitude_to_me_confidence, - 'friendly_value': self.friendly_value, - 'friendly_value_confidence': self.friendly_value_confidence, - 'rudeness': self.rudeness, - 'rudeness_confidence': self.rudeness_confidence, - 'neuroticism': self.neuroticism, - 'neuroticism_confidence': self.neuroticism_confidence, - 'conscientiousness': self.conscientiousness, - 'conscientiousness_confidence': self.conscientiousness_confidence, - 'likeness': self.likeness, - 'likeness_confidence': self.likeness_confidence, + "person_id": self.person_id, + "is_known": self.is_known, + "platform": self.platform, + "user_id": self.user_id, + "nickname": self.nickname, + "person_name": self.person_name, + "name_reason": self.name_reason, + "know_times": self.know_times, + "know_since": self.know_since, + "last_know": self.last_know, + "memory_points": json.dumps( + [point for point in self.memory_points if point is not None], ensure_ascii=False + ) + if self.memory_points + else json.dumps([], ensure_ascii=False), + "attitude_to_me": self.attitude_to_me, + "attitude_to_me_confidence": self.attitude_to_me_confidence, + "friendly_value": self.friendly_value, + "friendly_value_confidence": self.friendly_value_confidence, + "rudeness": self.rudeness, + "rudeness_confidence": self.rudeness_confidence, + "neuroticism": self.neuroticism, + "neuroticism_confidence": self.neuroticism_confidence, + "conscientiousness": self.conscientiousness, + "conscientiousness_confidence": self.conscientiousness_confidence, + "likeness": self.likeness, + "likeness_confidence": self.likeness_confidence, } - + # 检查记录是否存在 record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id) - + if record: # 更新现有记录 for field, value in data.items(): @@ -491,10 +490,10 @@ class Person: # 创建新记录 PersonInfo.create(**data) logger.debug(f"已创建用户 {self.person_id} 的信息到数据库") - + except Exception as e: logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}") - + def build_relationship(self): if not self.is_known: return "" @@ -505,22 +504,21 @@ class Person: nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})" relation_info = "" - + attitude_info = "" if self.attitude_to_me: if self.attitude_to_me > 8: attitude_info = f"{self.person_name}对你的态度十分好," elif self.attitude_to_me > 5: attitude_info = f"{self.person_name}对你的态度较好," - - + if self.attitude_to_me < -8: attitude_info = f"{self.person_name}对你的态度十分恶劣," elif self.attitude_to_me < -4: attitude_info = f"{self.person_name}对你的态度不好," elif self.attitude_to_me < 0: attitude_info = f"{self.person_name}对你的态度一般," - + neuroticism_info = "" if self.neuroticism: if self.neuroticism > 8: @@ -533,29 +531,28 @@ class Person: neuroticism_info = f"{self.person_name}的情绪比较稳定," else: neuroticism_info = f"{self.person_name}的情绪非常稳定,毫无波动" - + points_text = "" category_list = self.get_all_category() for category in category_list: - random_memory = self.get_random_memory_by_category(category,1)[0] + random_memory = self.get_random_memory_by_category(category, 1)[0] if random_memory: points_text = f"有关 {category} 的记忆:{get_memory_content_from_memory(random_memory)}" break - + points_info = "" if points_text: points_info = f"你还记得有关{self.person_name}的最近记忆:{points_text}" - + if not (nickname_str or attitude_info or neuroticism_info or points_info): return "" relation_info = f"{self.person_name}:{nickname_str}{attitude_info}{neuroticism_info}{points_info}" - + return relation_info class PersonInfoManager: def __init__(self): - self.person_name_list = {} self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") try: @@ -580,8 +577,6 @@ class PersonInfoManager: logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)") except Exception as e: logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") - - @staticmethod def _extract_json_from_text(text: str) -> dict: @@ -717,6 +712,6 @@ class PersonInfoManager: person.sync_to_database() self.person_name_list[person_id] = unique_nickname return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"} - + person_info_manager = PersonInfoManager() diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 9bf484f0..7d2591ff 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -3,7 +3,8 @@ import traceback import os import pickle import random -from typing import List, Dict, Any +import asyncio +from typing import List, Dict, Any, TYPE_CHECKING from src.config.config import global_config from src.common.logger import get_logger from src.person_info.relationship_manager import get_relationship_manager @@ -15,7 +16,9 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_before_timestamp_with_chat, num_new_messages_since, ) -import asyncio + +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger("relationship_builder") @@ -429,7 +432,7 @@ class RelationshipBuilder: if dropped_count > 0: logger.debug(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段") - processed_messages = [] + processed_messages: List["DatabaseMessages"] = [] # 对筛选后的消息段进行排序,确保时间顺序 segments_to_process.sort(key=lambda x: x["start_time"]) @@ -449,17 +452,18 @@ class RelationshipBuilder: # 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识 if processed_messages: # 创建一个特殊的间隔消息 - gap_message = { - "time": start_time - 0.1, # 稍微早于段开始时间 - "user_id": "system", - "user_platform": "system", - "user_nickname": "系统", - "user_cardname": "", - "display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...", - "is_action_record": True, - "chat_info_platform": segment_messages[0].chat_info.platform or "", - "chat_id": chat_id, - } + gap_message = DatabaseMessages( + time=start_time - 0.1, + user_id="system", + user_platform="system", + user_nickname="系统", + user_cardname="", + display_message=f"...(中间省略一些消息){start_date} 之后的消息如下...", + is_action_record=True, + chat_info_platform=segment_messages[0].chat_info.platform or "", + chat_id=chat_id, + ) + processed_messages.append(gap_message) # 添加该段的所有消息 @@ -467,11 +471,11 @@ class RelationshipBuilder: if processed_messages: # 按时间排序所有消息(包括间隔标识) - processed_messages.sort(key=lambda x: x["time"]) + processed_messages.sort(key=lambda x: x.time) logger.debug(f"为 {person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新") relationship_manager = get_relationship_manager() - + build_frequency = 0.3 * global_config.relationship.relation_frequency if random.random() < build_frequency: # 调用原有的更新方法 diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 916162a8..151446b6 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -3,16 +3,18 @@ import traceback from json_repair import repair_json from datetime import datetime -from typing import List +from typing import List, TYPE_CHECKING from src.common.logger import get_logger -from src.common.data_models.database_data_model import DatabaseMessages from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from .person_info import Person +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + logger = get_logger("relation") @@ -177,7 +179,7 @@ class RelationshipManager: return person - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]): + async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List["DatabaseMessages"]): """更新用户印象 Args: @@ -192,8 +194,6 @@ class RelationshipManager: # nickname = person.nickname know_times: float = person.know_times - user_messages = bot_engaged_messages - # 匿名化消息 # 创建用户名称映射 name_mapping = {} @@ -201,13 +201,14 @@ class RelationshipManager: user_count = 1 # 遍历消息,构建映射 - for msg in user_messages: + for msg in bot_engaged_messages: if msg.user_info.user_id == "system": continue try: user_id = msg.user_info.user_id platform = msg.chat_info.platform - assert isinstance(user_id, str) and isinstance(platform, str) + assert user_id, "用户ID不能为空" + assert platform, "平台不能为空" msg_person = Person(user_id=user_id, platform=platform) except Exception as e: @@ -233,7 +234,7 @@ class RelationshipManager: current_user = chr(ord(current_user) + 1) readable_messages = build_readable_messages( - messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True + messages=bot_engaged_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True ) for original_name, mapped_name in name_mapping.items(): diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 3ffbc715..49e78e95 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -9,7 +9,7 @@ """ import traceback -from typing import Tuple, Any, Dict, List, Optional +from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING from rich.traceback import install from src.common.logger import get_logger from src.chat.replyer.default_generator import DefaultReplyer @@ -18,6 +18,10 @@ from src.chat.utils.utils import process_llm_response from src.chat.replyer.replyer_manager import replyer_manager from src.plugin_system.base.component_types import ActionInfo +if TYPE_CHECKING: + from src.common.data_models.info_data_model import ActionPlannerInfo + from src.common.data_models.database_data_model import DatabaseMessages + install(extra_lines=3) logger = get_logger("generator_api") @@ -73,11 +77,11 @@ async def generate_reply( chat_stream: Optional[ChatStream] = None, chat_id: Optional[str] = None, action_data: Optional[Dict[str, Any]] = None, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - choosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List["ActionPlannerInfo"]] = None, enable_tool: bool = False, enable_splitter: bool = True, enable_chinese_typo: bool = True, @@ -85,7 +89,7 @@ async def generate_reply( request_type: str = "generator_api", from_plugin: bool = True, return_expressions: bool = False, -) -> Tuple[bool, List[Tuple[str, Any]], Optional[Tuple[str, List[Dict[str, Any]]]]]: +) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]: """生成回复 Args: @@ -96,7 +100,7 @@ async def generate_reply( extra_info: 额外信息,用于补充上下文 reply_reason: 回复原因 available_actions: 可用动作 - choosen_actions: 已选动作 + chosen_actions: 已选动作 enable_tool: 是否启用工具调用 enable_splitter: 是否启用消息分割器 enable_chinese_typo: 是否启用错字生成器 @@ -110,16 +114,14 @@ async def generate_reply( try: # 获取回复器 logger.debug("[GeneratorAPI] 开始生成回复") - replyer = get_replyer( - chat_stream, chat_id, request_type=request_type - ) + replyer = get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") - return False, [], None + return False, [], None, None if not extra_info and action_data: extra_info = action_data.get("extra_info", "") - + if not reply_reason and action_data: reply_reason = action_data.get("reason", "") @@ -127,7 +129,7 @@ async def generate_reply( success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context( extra_info=extra_info, available_actions=available_actions, - chosen_actions=choosen_actions, + chosen_actions=chosen_actions, enable_tool=enable_tool, reply_message=reply_message, reply_reason=reply_reason, @@ -136,7 +138,7 @@ async def generate_reply( ) if not success: logger.warning("[GeneratorAPI] 回复生成失败") - return False, [], None + return False, [], None, None assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况 if content := llm_response_dict.get("content", ""): reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) @@ -144,28 +146,34 @@ async def generate_reply( reply_set = [] logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项") - if return_prompt: - if return_expressions: - return success, reply_set, (prompt, selected_expressions) - else: - return success, reply_set, prompt - else: - if return_expressions: - return success, reply_set, (None, selected_expressions) - else: - return success, reply_set, None - + # if return_prompt: + # if return_expressions: + # return success, reply_set, prompt, selected_expressions + # else: + # return success, reply_set, prompt, None + # else: + # if return_expressions: + # return success, reply_set, (None, selected_expressions) + # else: + # return success, reply_set, None + return ( + success, + reply_set, + prompt if return_prompt else None, + selected_expressions if return_expressions else None, + ) + except ValueError as ve: raise ve except UserWarning as uw: logger.warning(f"[GeneratorAPI] 中断了生成: {uw}") - return False, [], None + return False, [], None, None except Exception as e: logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") logger.error(traceback.format_exc()) - return False, [], None + return False, [], None, None async def rewrite_reply( diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 700042de..4bdab41e 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -21,15 +21,17 @@ import traceback import time -from typing import Optional, Union, Dict, Any, List -from src.common.logger import get_logger +from typing import Optional, Union, Dict, Any, List, TYPE_CHECKING -# 导入依赖 +from src.common.logger import get_logger +from src.config.config import global_config from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.message_receive.message import MessageSending, MessageRecv from maim_message import Seg, UserInfo -from src.config.config import global_config + +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger("send_api") @@ -46,10 +48,10 @@ async def _send_to_target( display_message: str = "", typing: bool = False, set_reply: bool = False, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, storage_message: bool = True, show_log: bool = True, - selected_expressions:List[int] = None, + selected_expressions: Optional[List[int]] = None, ) -> bool: """向指定目标发送消息的内部实现 @@ -70,7 +72,7 @@ async def _send_to_target( if set_reply and not reply_message: logger.warning("[SendAPI] 使用引用回复,但未提供回复消息") return False - + if show_log: logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}") @@ -98,13 +100,13 @@ async def _send_to_target( message_segment = Seg(type=message_type, data=content) # type: ignore if reply_message: - anchor_message = message_dict_to_message_recv(reply_message) + anchor_message = message_dict_to_message_recv(reply_message.flatten()) if anchor_message: anchor_message.update_chat_stream(target_stream) assert anchor_message.message_info.user_info, "用户信息缺失" reply_to_platform_id = ( f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" - ) + ) else: reply_to_platform_id = "" anchor_message = None @@ -192,12 +194,11 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa } message_recv = MessageRecv(message_dict_recv) - + logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}") return message_recv - # ============================================================================= # 公共API函数 - 预定义类型的发送函数 # ============================================================================= @@ -208,9 +209,9 @@ async def text_to_stream( stream_id: str, typing: bool = False, set_reply: bool = False, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, storage_message: bool = True, - selected_expressions:List[int] = None, + selected_expressions: Optional[List[int]] = None, ) -> bool: """向指定流发送文本消息 @@ -237,7 +238,13 @@ async def text_to_stream( ) -async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: +async def emoji_to_stream( + emoji_base64: str, + stream_id: str, + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, +) -> bool: """向指定流发送表情包 Args: @@ -248,10 +255,25 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo Returns: bool: 是否发送成功 """ - return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message) + return await _send_to_target( + "emoji", + emoji_base64, + stream_id, + "", + typing=False, + storage_message=storage_message, + set_reply=set_reply, + reply_message=reply_message, + ) -async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: +async def image_to_stream( + image_base64: str, + stream_id: str, + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, +) -> bool: """向指定流发送图片 Args: @@ -262,11 +284,25 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo Returns: bool: 是否发送成功 """ - return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message) + return await _send_to_target( + "image", + image_base64, + stream_id, + "", + typing=False, + storage_message=storage_message, + set_reply=set_reply, + reply_message=reply_message, + ) async def command_to_stream( - command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + command: Union[str, dict], + stream_id: str, + storage_message: bool = True, + display_message: str = "", + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """向指定流发送命令 @@ -279,7 +315,14 @@ async def command_to_stream( bool: 是否发送成功 """ return await _send_to_target( - "command", command, stream_id, display_message, typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message + "command", + command, + stream_id, + display_message, + typing=False, + storage_message=storage_message, + set_reply=set_reply, + reply_message=reply_message, ) @@ -289,7 +332,7 @@ async def custom_to_stream( stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: Optional["DatabaseMessages"] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 174b6fea..03bbc0d6 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -2,13 +2,15 @@ import time import asyncio from abc import ABC, abstractmethod -from typing import Tuple, Optional, Dict, Any +from typing import Tuple, Optional, TYPE_CHECKING from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream -from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType +from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType from src.plugin_system.apis import send_api, database_api, message_api +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger("base_action") @@ -206,7 +208,11 @@ class BaseAction(ABC): return False, f"等待新消息失败: {str(e)}" async def send_text( - self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False + self, + content: str, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + typing: bool = False, ) -> bool: """发送文本消息 @@ -229,7 +235,9 @@ class BaseAction(ABC): typing=typing, ) - async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_emoji( + self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None + ) -> bool: """发送表情包 Args: @@ -242,9 +250,13 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 缺少聊天ID") return False - return await send_api.emoji_to_stream(emoji_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message) + return await send_api.emoji_to_stream( + emoji_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message + ) - async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_image( + self, image_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None + ) -> bool: """发送图片 Args: @@ -257,9 +269,18 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 缺少聊天ID") return False - return await send_api.image_to_stream(image_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message) + return await send_api.image_to_stream( + image_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message + ) - async def send_custom(self, message_type: str, content: str, typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_custom( + self, + message_type: str, + content: str, + typing: bool = False, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + ) -> bool: """发送自定义类型消息 Args: @@ -308,7 +329,13 @@ class BaseAction(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + self, + command_name: str, + args: Optional[dict] = None, + display_message: str = "", + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """发送命令消息 diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 35fed909..633eba34 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod -from typing import Dict, Tuple, Optional, Any +from typing import Dict, Tuple, Optional, TYPE_CHECKING from src.common.logger import get_logger from src.plugin_system.base.component_types import CommandInfo, ComponentType from src.chat.message_receive.message import MessageRecv from src.plugin_system.apis import send_api +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + logger = get_logger("base_command") @@ -84,7 +87,13 @@ class BaseCommand(ABC): return current - async def send_text(self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool: + async def send_text( + self, + content: str, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + storage_message: bool = True, + ) -> bool: """发送回复消息 Args: @@ -100,10 +109,22 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, set_reply=set_reply,reply_message=reply_message,storage_message=storage_message) + return await send_api.text_to_stream( + text=content, + stream_id=chat_stream.stream_id, + set_reply=set_reply, + reply_message=reply_message, + storage_message=storage_message, + ) async def send_type( - self, message_type: str, content: str, display_message: str = "", typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + self, + message_type: str, + content: str, + display_message: str = "", + typing: bool = False, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """发送指定类型的回复消息到当前聊天环境 @@ -134,7 +155,13 @@ class BaseCommand(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None + self, + command_name: str, + args: Optional[dict] = None, + display_message: str = "", + storage_message: bool = True, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, ) -> bool: """发送命令消息 @@ -177,7 +204,9 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 发送命令时出错: {e}") return False - async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: + async def send_emoji( + self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None + ) -> bool: """发送表情包 Args: @@ -191,9 +220,17 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message) + return await send_api.emoji_to_stream( + emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message + ) - async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool: + async def send_image( + self, + image_base64: str, + set_reply: bool = False, + reply_message: Optional["DatabaseMessages"] = None, + storage_message: bool = True, + ) -> bool: """发送图片 Args: @@ -207,7 +244,13 @@ class BaseCommand(ABC): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.image_to_stream(image_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message,storage_message=storage_message) + return await send_api.image_to_stream( + image_base64, + chat_stream.stream_id, + set_reply=set_reply, + reply_message=reply_message, + storage_message=storage_message, + ) @classmethod def get_command_info(cls) -> "CommandInfo": diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index fd3d811b..fcbdc918 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -2,7 +2,7 @@ from typing import Dict, Any from src.common.logger import get_logger from src.config.config import global_config -from src.chat.knowledge.knowledge_lib import qa_manager +from src.chat.knowledge import qa_manager from src.plugin_system import BaseTool, ToolParamType logger = get_logger("lpmm_get_knowledge_tool") diff --git a/src/plugins/built_in/relation/relation.py b/src/plugins/built_in/relation/relation.py index 24193651..15fb59bd 100644 --- a/src/plugins/built_in/relation/relation.py +++ b/src/plugins/built_in/relation/relation.py @@ -1,20 +1,13 @@ -import random +import json +from json_repair import repair_json from typing import Tuple -# 导入新插件系统 -from src.plugin_system import BaseAction, ActionActivationType, ChatMode - -# 导入依赖的系统组件 from src.common.logger import get_logger - -# 导入API模块 - 标准Python包方式 -from src.plugin_system.apis import emoji_api, llm_api, message_api -# NoReplyAction已集成到heartFC_chat.py中,不再需要导入 from src.config.config import global_config from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -import json -from json_repair import repair_json +from src.plugin_system import BaseAction, ActionActivationType +from src.plugin_system.apis import llm_api logger = get_logger("relation") @@ -39,10 +32,9 @@ def init_prompt(): {{ "category": "分类名称" }} """, - "relation_category" + "relation_category", ) - - + Prompt( """ 以下是有关{category}的现有记忆: @@ -73,7 +65,7 @@ def init_prompt(): 现在,请你根据情况选出合适的修改方式,并输出json,不要输出其他内容: """, - "relation_category_update" + "relation_category_update", ) @@ -98,17 +90,14 @@ class BuildRelationAction(BaseAction): """ # 动作参数定义 - action_parameters = { - "person_name":"需要了解或记忆的人的名称", - "impression":"需要了解的对某人的记忆或印象" - } + action_parameters = {"person_name": "需要了解或记忆的人的名称", "impression": "需要了解的对某人的记忆或印象"} # 动作使用场景 action_require = [ "了解对于某人的记忆,并添加到你对对方的印象中", "对方与有明确提到有关其自身的事件", "对方有提到其个人信息,包括喜好,身份,等等", - "对方希望你记住对方的信息" + "对方希望你记住对方的信息", ] # 关联类型 @@ -129,9 +118,7 @@ class BuildRelationAction(BaseAction): if not person.is_known: logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆") return False, f"用户 {person_name} 不存在,跳过添加记忆" - - category_list = person.get_all_category() if not category_list: category_list_str = "无分类" @@ -142,9 +129,8 @@ class BuildRelationAction(BaseAction): "relation_category", category_list=category_list_str, memory_point=impression, - person_name=person.person_name + person_name=person.person_name, ) - if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") @@ -161,84 +147,76 @@ class BuildRelationAction(BaseAction): success, category, _, _ = await llm_api.generate_with_model( prompt, model_config=chat_model_config, request_type="relation.category" ) - - category_data = json.loads(repair_json(category)) category = category_data.get("category", "") if not category: logger.warning(f"{self.log_prefix} LLM未给出分类,跳过添加记忆") return False, "LLM未给出分类,跳过添加记忆" - - + # 第二部分:更新记忆 - + memory_list = person.get_memory_list_by_category(category) if not memory_list: logger.info(f"{self.log_prefix} {person.person_name} 的 {category} 的记忆为空,进行创建") person.memory_points.append(f"{category}:{impression}:1.0") person.sync_to_database() - + return True, f"未找到分类为{category}的记忆点,进行添加" - + memory_list_str = "" memory_list_id = {} - id = 1 - for memory in memory_list: + for id, memory in enumerate(memory_list, start=1): memory_content = get_memory_content_from_memory(memory) memory_list_str += f"{id}. {memory_content}\n" memory_list_id[id] = memory - id += 1 - prompt = await global_prompt_manager.format_prompt( "relation_category_update", category=category, memory_list=memory_list_str, memory_point=impression, - person_name=person.person_name + person_name=person.person_name, ) - + if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") else: logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") - chat_model_config = models.get("utils") + chat_model_config = models.get("utils") success, update_memory, _, _ = await llm_api.generate_with_model( - prompt, model_config=chat_model_config, request_type="relation.category.update" + prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore ) - + update_memory_data = json.loads(repair_json(update_memory)) new_memory = update_memory_data.get("new_memory", "") memory_id = update_memory_data.get("memory_id", "") integrate_memory = update_memory_data.get("integrate_memory", "") - + if new_memory: # 新记忆 person.memory_points.append(f"{category}:{new_memory}:1.0") person.sync_to_database() - + return True, f"为{person.person_name}新增记忆点: {new_memory}" elif memory_id and integrate_memory: # 现存或冲突记忆 memory = memory_list_id[memory_id] memory_content = get_memory_content_from_memory(memory) - del_count = person.del_memory(category,memory_content) - + del_count = person.del_memory(category, memory_content) + if del_count > 0: logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}") memory_weight = get_weight_from_memory(memory) person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}") person.sync_to_database() - + return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}" - + else: logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}") return False, f"删除{person.person_name}的记忆点失败: {memory_content}" - - return True, "关系动作执行成功" @@ -248,4 +226,4 @@ class BuildRelationAction(BaseAction): # 还缺一个关系的太多遗忘和对应的提取 -init_prompt() \ No newline at end of file +init_prompt() diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index 92640af6..d83fc762 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -2,7 +2,7 @@ from src.plugin_system.apis.plugin_register_api import register_plugin from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo from src.common.logger import get_logger -from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode +from src.plugin_system.base.base_action import BaseAction, ActionActivationType from src.plugin_system.base.config_types import ConfigField from typing import Tuple, List, Type diff --git a/template/model_config_template.toml b/template/model_config_template.toml index 0d756314..4c32e876 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "1.3.0" +version = "1.3.1" # 配置文件版本号迭代规则同bot_config.toml