From 7a3f260cc3e6da8458ce10e8b2b6a74ea368e3cb Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sun, 9 Nov 2025 14:02:29 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E6=B7=BB=E5=8A=A0ReAct=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E6=8F=90=E5=8F=96=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/deep_think/plugin.py | 2 +- src/chat/replyer/group_generator.py | 26 +- src/chat/replyer/prompt/replyer_prompt.py | 2 +- src/jargon/jargon_miner.py | 180 +++++- src/memory_system/memory_retrieval.py | 585 ++++++++++++++++++ src/memory_system/memory_utils.py | 37 +- src/memory_system/retrieval_tools/README.md | 155 +++++ src/memory_system/retrieval_tools/__init__.py | 30 + .../retrieval_tools/query_chat_history.py | 221 +++++++ .../retrieval_tools/query_jargon.py | 92 +++ .../retrieval_tools/tool_registry.py | 114 ++++ .../retrieval_tools/tool_utils.py | 64 ++ 12 files changed, 1463 insertions(+), 45 deletions(-) create mode 100644 src/memory_system/memory_retrieval.py create mode 100644 src/memory_system/retrieval_tools/README.md create mode 100644 src/memory_system/retrieval_tools/__init__.py create mode 100644 src/memory_system/retrieval_tools/query_chat_history.py create mode 100644 src/memory_system/retrieval_tools/query_jargon.py create mode 100644 src/memory_system/retrieval_tools/tool_registry.py create mode 100644 src/memory_system/retrieval_tools/tool_utils.py diff --git a/plugins/deep_think/plugin.py b/plugins/deep_think/plugin.py index 19f177bc..5d9debfe 100644 --- a/plugins/deep_think/plugin.py +++ b/plugins/deep_think/plugin.py @@ -9,7 +9,7 @@ from src.plugin_system.base.base_tool import BaseTool, ToolParamType # 导入依赖的系统组件 from src.common.logger import get_logger -from src.plugins.built_in.relation.relation import BuildRelationAction +# from src.plugins.built_in.relation.relation import BuildRelationAction from src.plugin_system.apis import llm_api logger = get_logger("relation_actions") diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 6f0a944d..ec4dd587 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -37,10 +37,12 @@ from src.plugin_system.apis import llm_api from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt +from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt init_lpmm_prompt() init_replyer_prompt() init_rewrite_prompt() +init_memory_retrieval_prompt() logger = get_logger("replyer") @@ -289,16 +291,8 @@ class DefaultReplyer: async def build_question_block(self) -> str: """构建问题块""" - # if not global_config.question.enable_question: - # return "" - questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_stream.stream_id) - questions_str = "" - for question in questions: - questions_str += f"- {question.question}\n" - if questions_str: - return f"你在聊天中,有以下问题想要得到解答:\n{questions_str}" - else: - return "" + # 问题跟踪功能已移除,返回空字符串 + return "" async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: @@ -807,7 +801,7 @@ class DefaultReplyer: show_actions=True, ) - # 并行执行五个构建任务 + # 并行执行九个构建任务 task_results = await asyncio.gather( self._time_and_run_task( self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" @@ -821,6 +815,12 @@ class DefaultReplyer: self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"), self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"), self._time_and_run_task(self.build_question_block(), "question_block"), + self._time_and_run_task( + build_memory_retrieval_prompt( + chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor + ), + "memory_retrieval", + ), ) # 任务名称中英文映射 @@ -835,6 +835,7 @@ class DefaultReplyer: "personality_prompt": "人格信息", "mood_state_prompt": "情绪状态", "question_block": "问题", + "memory_retrieval": "记忆检索", } # 处理结果 @@ -865,6 +866,7 @@ class DefaultReplyer: actions_info: str = results_dict["actions_info"] personality_prompt: str = results_dict["personality_prompt"] question_block: str = results_dict["question_block"] + memory_retrieval: str = results_dict["memory_retrieval"] keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target) mood_state_prompt: str = results_dict["mood_state_prompt"] @@ -922,6 +924,7 @@ class DefaultReplyer: keywords_reaction_prompt=keywords_reaction_prompt, moderation_prompt=moderation_prompt_block, question_block=question_block, + memory_retrieval=memory_retrieval, chat_prompt=chat_prompt_block, ), selected_expressions @@ -1150,7 +1153,6 @@ class DefaultReplyer: logger.error(f"获取知识库内容时发生异常: {str(e)}") return "" - def weighted_sample_no_replacement(items, weights, k) -> list: """ 加权且不放回地随机抽取k个元素。 diff --git a/src/chat/replyer/prompt/replyer_prompt.py b/src/chat/replyer/prompt/replyer_prompt.py index 9e5d145a..32cf84b3 100644 --- a/src/chat/replyer/prompt/replyer_prompt.py +++ b/src/chat/replyer/prompt/replyer_prompt.py @@ -11,7 +11,7 @@ def init_replyer_prompt(): Prompt( """{knowledge_prompt}{tool_info_block}{extra_info_block} -{expression_habits_block}{memory_block}{question_block} +{expression_habits_block}{memory_block}{question_block}{memory_retrieval} 你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片: {time_block} diff --git a/src/jargon/jargon_miner.py b/src/jargon/jargon_miner.py index a8c88cb7..554d886c 100644 --- a/src/jargon/jargon_miner.py +++ b/src/jargon/jargon_miner.py @@ -1,13 +1,14 @@ import time import json import asyncio -from typing import List +from typing import List, Dict, Optional from json_repair import repair_json +from peewee import fn from src.common.logger import get_logger from src.common.database.database_model import Jargon from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config +from src.config.config import model_config, global_config from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( build_anonymous_messages, @@ -21,28 +22,27 @@ logger = get_logger("jargon") def _init_prompt() -> None: prompt_str = """ -**聊天内容** +**聊天内容,其中的SELF是你自己的发言** {chat_str} 请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。 - 必须为对话中真实出现过的短词或短语 -- 必须是你无法理解含义的词语,或者出现频率较高的词语 +- 必须是你无法理解含义的词语,没有明确含义的词语 - 请不要选择有明确含义,或者含义清晰的词语 -- 必须是这几种类别之一:英文或中文缩写、中文拼音短语、字母数字混合 -- 排除:人名、@、明显的表情/图片占位、纯标点、常规功能词(如的、了、呢、啊等) +- 必须是这几种类别之一:英文或中文缩写、中文拼音短语 +- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等) - 每个词条长度建议 2-8 个字符(不强制),尽量短小 - 合并重复项,去重 分类规则,type必须根据规则填写: -- p(拼音缩写):由字母或字母和汉字构成的,用汉语拼音简写词,或汉语拼音首字母的简写词,例如:nb、yyds、xswl -- c(中文缩写):中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷 +- p(拼音缩写):由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl - e(英文缩写):英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API -- x(谐音梗):谐音梗,用谐音词概括一个词汇或含义,例如:好似,难崩 +- c(中文缩写):中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷 以 JSON 数组输出,元素为对象(严格按以下结构): [ - {{"content": "词条", "raw_content": "包含该词条的完整对话原文", "type": "p"}}, - {{"content": "词条2", "raw_content": "包含该词条的完整对话原文", "type": "c"}} + {{"content": "词条", "raw_content": "包含该词条的完整对话上下文原文", "type": "p"}}, + {{"content": "词条2", "raw_content": "包含该词条的完整对话上下文原文", "type": "c"}} ] 现在请输出: @@ -57,7 +57,7 @@ def _init_inference_prompts() -> None: **词条内容** {content} -**词条出现的上下文(raw_content)** +**词条出现的上下文(raw_content)其中的SELF是你自己的发言** {raw_content_list} 请根据以上词条内容和上下文,推断这个词条的含义。 @@ -66,8 +66,8 @@ def _init_inference_prompts() -> None: 以 JSON 格式输出: {{ - "meaning": "含义说明", - "translation": "翻译或解释" + "meaning": "详细含义说明(包含使用场景、来源、具体解释等)", + "translation": "原文(用一个词语写明这个词的实际含义)" }} """ Prompt(prompt1_str, "jargon_inference_with_context_prompt") @@ -83,8 +83,8 @@ def _init_inference_prompts() -> None: 以 JSON 格式输出: {{ - "meaning": "含义说明", - "translation": "翻译或解释" + "meaning": "详细含义说明(包含使用场景、来源、具体解释等)", + "translation": "原文(用一个词语写明这个词的实际含义)" }} """ Prompt(prompt2_str, "jargon_inference_content_only_prompt") @@ -117,7 +117,7 @@ _init_inference_prompts() def _should_infer_meaning(jargon_obj: Jargon) -> bool: """ 判断是否需要进行含义推断 - 在 count 达到 5, 10, 20, 40, 60, 100 时进行推断 + 在 count 达到 3,6, 10, 20, 40, 60, 100 时进行推断 并且count必须大于last_inference_count,避免重启后重复判定 如果is_complete为True,不再进行推断 """ @@ -128,8 +128,8 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool: count = jargon_obj.count or 0 last_inference = jargon_obj.last_inference_count or 0 - # 阈值列表:5, 10, 20, 40, 60, 100 - thresholds = [5, 10, 20, 40, 60, 100] + # 阈值列表:3,6, 10, 20, 40, 60, 100 + thresholds = [3,6, 10, 20, 40, 60, 100] if count < thresholds[0]: return False @@ -165,6 +165,11 @@ class JargonMiner: model_set=model_config.model_task_config.utils, request_type="jargon.extract", ) + + # 初始化stream_name作为类属性,避免重复提取 + chat_manager = get_chat_manager() + stream_name = chat_manager.get_stream_name(self.chat_id) + self.stream_name = stream_name if stream_name else self.chat_id async def _infer_meaning_by_id(self, jargon_id: int) -> None: """通过ID加载对象并推断""" @@ -255,12 +260,14 @@ class JargonMiner: except Exception as e: logger.error(f"jargon {content} 推断2解析失败: {e}") return - logger.info(f"jargon {content} 推断2提示词: {prompt2}") - logger.info(f"jargon {content} 推断2结果: {response2}") - # logger.info(f"jargon {content} 推断2结果: {inference2}") - logger.info(f"jargon {content} 推断1提示词: {prompt1}") - logger.info(f"jargon {content} 推断1结果: {response1}") - # logger.info(f"jargon {content} 推断1结果: {inference1}") + + if global_config.debug.show_jargon_prompt: + logger.info(f"jargon {content} 推断2提示词: {prompt2}") + logger.info(f"jargon {content} 推断2结果: {response2}") + # logger.info(f"jargon {content} 推断2结果: {inference2}") + logger.info(f"jargon {content} 推断1提示词: {prompt1}") + logger.info(f"jargon {content} 推断1结果: {response1}") + # logger.info(f"jargon {content} 推断1结果: {inference1}") # 步骤3: 比较两个推断结果 prompt3 = await global_prompt_manager.format_prompt( @@ -269,7 +276,8 @@ class JargonMiner: inference2=json.dumps(inference2, ensure_ascii=False), ) - logger.info(f"jargon {content} 比较提示词: {prompt3}") + if global_config.debug.show_jargon_prompt: + logger.info(f"jargon {content} 比较提示词: {prompt3}") response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3) if not response3: @@ -317,6 +325,20 @@ class JargonMiner: jargon_obj.save() logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}") + # 固定输出推断结果,格式化为可读形式 + if is_jargon: + # 是黑话,输出格式:[聊天名]xxx (translation)的含义是 xxxxxxxxxxx + translation = jargon_obj.translation or "未知" + meaning = jargon_obj.meaning or "无详细说明" + is_global = jargon_obj.is_global + if is_global: + logger.info(f"[通用黑话]{content} ({translation})的含义是 {meaning}") + else: + logger.info(f"[{self.stream_name}]{content} ({translation})的含义是 {meaning}") + else: + # 不是黑话,输出格式:[聊天名]xxx 不是黑话 + logger.info(f"[{self.stream_name}]{content} 不是黑话") + except Exception as e: logger.error(f"jargon推断失败: {e}") import traceback @@ -371,8 +393,9 @@ class JargonMiner: if not response: return - logger.info(f"jargon提取提示词: {prompt}") - logger.info(f"jargon提取结果: {response}") + if global_config.debug.show_jargon_prompt: + logger.info(f"jargon提取提示词: {prompt}") + logger.info(f"jargon提取结果: {response}") # 解析为JSON entries: List[dict] = [] @@ -404,6 +427,8 @@ class JargonMiner: raw_content_list = [] if isinstance(raw_content_value, list): raw_content_list = [str(rc).strip() for rc in raw_content_value if str(rc).strip()] + # 去重 + raw_content_list = list(dict.fromkeys(raw_content_list)) elif isinstance(raw_content_value, str): raw_content_str = raw_content_value.strip() if raw_content_str: @@ -585,10 +610,20 @@ class JargonMiner: logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}") continue - if saved or updated or merged: - logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,合并为global {merged} 条,chat_id={self.chat_id}") + # 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) + if uniq_entries: + # 收集所有提取的jargon内容 + jargon_list = [entry["content"] for entry in uniq_entries] + jargon_str = ",".join(jargon_list) + + # 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色) + logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}") + # 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口 self.last_learning_time = extraction_end_time + + if saved or updated or merged: + logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,合并为global {merged} 条,chat_id={self.chat_id}") except Exception as e: logger.error(f"JargonMiner 运行失败: {e}") @@ -611,3 +646,88 @@ async def extract_and_store_jargon(chat_id: str) -> None: await miner.run_once() +def search_jargon( + keyword: str, + chat_id: Optional[str] = None, + limit: int = 10, + case_sensitive: bool = False, + fuzzy: bool = True +) -> List[Dict[str, str]]: + """ + 搜索jargon,支持大小写不敏感和模糊搜索 + + Args: + keyword: 搜索关键词 + chat_id: 可选的聊天ID,如果提供则优先搜索该聊天或global的jargon + limit: 返回结果数量限制,默认10 + case_sensitive: 是否大小写敏感,默认False(不敏感) + fuzzy: 是否模糊搜索,默认True(使用LIKE匹配) + + Returns: + List[Dict[str, str]]: 包含content, translation, meaning的字典列表 + """ + if not keyword or not keyword.strip(): + return [] + + keyword = keyword.strip() + + # 构建查询 + query = Jargon.select( + Jargon.content, + Jargon.translation, + Jargon.meaning + ) + + # 构建搜索条件 + if case_sensitive: + # 大小写敏感 + if fuzzy: + # 模糊搜索 + search_condition = Jargon.content.contains(keyword) + else: + # 精确匹配 + search_condition = (Jargon.content == keyword) + else: + # 大小写不敏感 + if fuzzy: + # 模糊搜索(使用LOWER函数) + search_condition = fn.LOWER(Jargon.content).contains(keyword.lower()) + else: + # 精确匹配(使用LOWER函数) + search_condition = (fn.LOWER(Jargon.content) == keyword.lower()) + + query = query.where(search_condition) + + # 如果提供了chat_id,优先搜索该聊天或global的jargon + if chat_id: + query = query.where( + (Jargon.chat_id == chat_id) | Jargon.is_global + ) + + # 只返回有translation或meaning的记录 + query = query.where( + ( + (Jargon.translation.is_null(False)) & (Jargon.translation != "") + ) | ( + (Jargon.meaning.is_null(False)) & (Jargon.meaning != "") + ) + ) + + # 按count降序排序,优先返回出现频率高的 + query = query.order_by(Jargon.count.desc()) + + # 限制结果数量 + query = query.limit(limit) + + # 执行查询并返回结果 + results = [] + for jargon in query: + results.append({ + "content": jargon.content or "", + "translation": jargon.translation or "", + "meaning": jargon.meaning or "" + }) + + return results + + diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py new file mode 100644 index 00000000..b3207459 --- /dev/null +++ b/src/memory_system/memory_retrieval.py @@ -0,0 +1,585 @@ +import time +import json +import re +from typing import List, Dict, Any, Optional, Tuple +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.plugin_system.apis import llm_api +from src.common.database.database_model import ThinkingBack +from json_repair import repair_json +from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools + +logger = get_logger("memory_retrieval") + + +def init_memory_retrieval_prompt(): + """初始化记忆检索相关的 prompt 模板和工具""" + # 首先注册所有工具 + init_all_tools() + + # 第一步:问题生成prompt + Prompt( + """ +你是一个专门检测是否需要回忆的助手。你的名字是{bot_name}。现在是{time_now}。 +群里正在进行的聊天内容: +{chat_history} + +{recent_query_history} + +现在,{sender}发送了内容:{target_message},你想要回复ta。 +请仔细分析聊天内容,考虑以下几点: +1. 对话中是否提到了过去发生的事情、人物、事件或信息 +2. 是否有需要回忆的内容(比如"之前说过"、"上次"、"以前"等) +3. 是否有需要查找历史信息的问题 + +重要提示: +- 如果"最近已查询的问题和结果"中已经包含了类似的问题,请避免重复生成相同或相似的问题 +- 如果之前已经查询过某个问题但未找到答案,可以尝试用不同的方式提问或更具体的问题 +- 如果之前已经查询过某个问题并找到了答案,可以直接参考已有结果,不需要重复查询 + +如果你认为需要从记忆中检索信息来回答,请根据上下文提出一个或多个具体的问题。 +问题格式示例: +- "xxx在前几天干了什么" +- "xxx是什么" +- "xxxx和xxx的关系是什么" +- "xxx在某个时间点发生了什么" + +请输出JSON格式的问题数组。如果不需要检索记忆,则输出空数组[]。 + +输出格式示例: +```json +[ + "张三在前几天干了什么", + "自然选择是什么", + "李四和王五的关系是什么" +] +``` + +请只输出JSON数组,不要输出其他内容: +""", + name="memory_retrieval_question_prompt", + ) + + # 第二步:ReAct Agent prompt(工具描述会在运行时动态生成) + Prompt( + """ +你是一个记忆检索助手,需要通过思考(Think)、行动(Action)、观察(Observation)的循环来回答问题。 + +当前问题:{question} +已收集的信息: +{collected_info} + +你可以使用以下工具来查询信息: +{tools_description} + +请按照以下格式输出你的思考过程: +```json +{{ + "thought": "你的思考过程,分析当前情况,决定下一步行动", + "action": "要执行的动作,格式为:工具名(参数)", + "action_type": {action_types_list}, + "action_params": {{参数名: 参数值}} 或 null +}} +``` + +你可以选择以下动作: +1. 如果已经收集到足够的信息可以回答问题,请设置action_type为"final_answer",并在thought中说明答案。 +2. 如果经过多次查询后,确认无法找到相关信息或答案,请设置action_type为"no_answer",并在thought中说明原因。 + +请只输出JSON,不要输出其他内容: +""", + name="memory_retrieval_react_prompt", + ) + + +def _parse_react_response(response: str) -> Optional[Dict[str, Any]]: + """解析ReAct Agent的响应 + + Args: + response: LLM返回的响应 + + Returns: + Dict[str, Any]: 解析后的动作信息,如果解析失败返回None + """ + try: + # 尝试提取JSON(可能包含在```json代码块中) + json_pattern = r"```json\s*(.*?)\s*```" + matches = re.findall(json_pattern, response, re.DOTALL) + + if matches: + json_str = matches[0] + else: + # 尝试直接解析整个响应 + json_str = response.strip() + + # 修复可能的JSON错误 + repaired_json = repair_json(json_str) + + # 解析JSON + action_info = json.loads(repaired_json) + + if not isinstance(action_info, dict): + logger.warning(f"解析的JSON不是对象格式: {action_info}") + return None + + return action_info + + except Exception as e: + logger.error(f"解析ReAct响应失败: {e}, 响应内容: {response[:200]}...") + return None + + +async def _react_agent_solve_question( + question: str, + chat_id: str, + max_iterations: int = 5, + timeout: float = 30.0 +) -> Tuple[bool, str, List[Dict[str, Any]]]: + """使用ReAct架构的Agent来解决问题 + + Args: + question: 要回答的问题 + chat_id: 聊天ID + max_iterations: 最大迭代次数 + timeout: 超时时间(秒) + + Returns: + Tuple[bool, str, List[Dict[str, Any]]]: (是否找到答案, 答案内容, 思考步骤列表) + """ + start_time = time.time() + collected_info = "" + thinking_steps = [] + + for iteration in range(max_iterations): + # 检查超时 + if time.time() - start_time > timeout: + logger.warning(f"ReAct Agent超时,已迭代{iteration}次") + break + + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}") + logger.info(f"ReAct Agent 已收集信息: {collected_info if collected_info else '暂无信息'}") + + # 获取工具注册器 + tool_registry = get_tool_registry() + + # 构建prompt(动态生成工具描述) + prompt = await global_prompt_manager.format_prompt( + "memory_retrieval_react_prompt", + question=question, + collected_info=collected_info if collected_info else "暂无信息", + tools_description=tool_registry.get_tools_description(), + action_types_list=tool_registry.get_action_types_list(), + ) + + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 Prompt: {prompt}") + + # 调用LLM + success, response, reasoning_content, model_name = await llm_api.generate_with_model( + prompt, + model_config=model_config.model_task_config.tool_use, + ) + + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM响应: {response}") + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM推理: {reasoning_content}") + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM模型: {model_name}") + + if not success: + logger.error(f"ReAct Agent LLM调用失败: {response}") + break + + # 解析响应 + action_info = _parse_react_response(response) + if not action_info: + logger.warning(f"无法解析ReAct响应,迭代{iteration + 1}") + break + + thought = action_info.get("thought", "") + action_type = action_info.get("action_type", "") + action_params = action_info.get("action_params", {}) + + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考: {thought}") + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作类型: {action_type}") + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作参数: {action_params}") + + # 记录思考步骤 + step = { + "iteration": iteration + 1, + "thought": thought, + "action_type": action_type, + "action_params": action_params, + "observation": "" + } + + # 执行动作 + if action_type == "final_answer": + # Agent认为已经找到答案 + answer = thought # 使用thought作为答案 + step["observation"] = "找到答案" + thinking_steps.append(step) + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到最终答案: {answer}") + return True, answer, thinking_steps + + elif action_type == "no_answer": + # Agent确认无法找到答案 + answer = thought # 使用thought说明无法找到答案的原因 + step["observation"] = "确认无法找到答案" + thinking_steps.append(step) + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 确认无法找到答案: {answer}") + return False, answer, thinking_steps + + # 使用工具注册器执行工具 + tool_registry = get_tool_registry() + tool = tool_registry.get_tool(action_type) + + if tool: + try: + # 准备工具参数(需要添加chat_id如果工具需要) + tool_params = action_params.copy() + + # 如果工具函数签名需要chat_id,添加它 + import inspect + sig = inspect.signature(tool.execute_func) + if "chat_id" in sig.parameters: + tool_params["chat_id"] = chat_id + + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 执行工具: {action_type}({tool_params})") + + # 执行工具 + observation = await tool.execute(**tool_params) + step["observation"] = observation + + # 构建收集信息的描述 + param_str = ", ".join([f"{k}={v}" for k, v in action_params.items()]) + collected_info += f"\n查询{action_type}({param_str})的结果:{observation}\n" + + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具执行结果: {observation}") + except Exception as e: + error_msg = f"工具执行失败: {str(e)}" + step["observation"] = error_msg + logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 {error_msg}") + else: + step["observation"] = f"未知的工具类型: {action_type}" + logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 未知的工具类型: {action_type}") + + thinking_steps.append(step) + + # 如果观察结果为空或无效,继续下一轮 + if step["observation"] and "无有效信息" not in step["observation"] and "未找到" not in step["observation"]: + # 有有效信息,继续思考 + pass + + # 达到最大迭代次数或超时,但Agent没有明确返回final_answer + # 这种情况下,即使收集到了一些信息,也不认为找到了答案 + # 只有Agent明确返回final_answer时,才认为找到了答案 + if collected_info: + logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}...") + return False, collected_info, thinking_steps + else: + return False, "未找到相关信息", thinking_steps + + +def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0) -> str: + """获取最近一段时间内的查询历史 + + Args: + chat_id: 聊天ID + time_window_seconds: 时间窗口(秒),默认1小时 + + Returns: + str: 格式化的查询历史字符串 + """ + try: + current_time = time.time() + start_time = current_time - time_window_seconds + + # 查询最近时间窗口内的记录,按更新时间倒序 + records = ( + ThinkingBack.select() + .where( + (ThinkingBack.chat_id == chat_id) & + (ThinkingBack.update_time >= start_time) + ) + .order_by(ThinkingBack.update_time.desc()) + .limit(8) # 最多返回10条最近的记录 + ) + + if not records.exists(): + return "" + + history_lines = [] + history_lines.append("最近已查询的问题和结果:") + + for record in records: + status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" + answer_preview = "" + if record.answer: + # 截取答案前100字符 + answer_preview = record.answer[:100] + if len(record.answer) > 100: + answer_preview += "..." + + history_lines.append(f"- 问题:{record.question}") + history_lines.append(f" 状态:{status}") + if answer_preview: + history_lines.append(f" 答案:{answer_preview}") + history_lines.append("") # 空行分隔 + + return "\n".join(history_lines) + + except Exception as e: + logger.error(f"获取查询历史失败: {e}") + return "" + + +def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, str]]: + """从thinking_back数据库中查询是否有现成的答案 + + Args: + chat_id: 聊天ID + question: 问题 + + Returns: + Optional[Tuple[bool, str]]: 如果找到答案,返回(True, answer),否则返回None + """ + try: + # 查询相同chat_id和问题,且found_answer为True的记录 + # 按更新时间倒序,获取最新的答案 + records = ( + ThinkingBack.select() + .where( + (ThinkingBack.chat_id == chat_id) & + (ThinkingBack.question == question) & + (ThinkingBack.found_answer == 1) + ) + .order_by(ThinkingBack.update_time.desc()) + .limit(1) + ) + + if records.exists(): + record = records.get() + logger.info(f"在thinking_back中找到现成答案,问题: {question[:50]}...") + return True, record.answer or "" + + return None + + except Exception as e: + logger.error(f"查询thinking_back失败: {e}") + return None + + +def _store_thinking_back( + chat_id: str, + question: str, + context: str, + found_answer: bool, + answer: str, + thinking_steps: List[Dict[str, Any]] +) -> None: + """存储或更新思考过程到数据库(如果已存在则更新,否则创建) + + Args: + chat_id: 聊天ID + question: 问题 + context: 上下文信息 + found_answer: 是否找到答案 + answer: 答案内容 + thinking_steps: 思考步骤列表 + """ + try: + now = time.time() + + # 先查询是否已存在相同chat_id和问题的记录 + existing = ( + ThinkingBack.select() + .where( + (ThinkingBack.chat_id == chat_id) & + (ThinkingBack.question == question) + ) + .order_by(ThinkingBack.update_time.desc()) + .limit(1) + ) + + if existing.exists(): + # 更新现有记录 + record = existing.get() + record.context = context + record.found_answer = found_answer + record.answer = answer + record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False) + record.update_time = now + record.save() + logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...") + else: + # 创建新记录 + ThinkingBack.create( + chat_id=chat_id, + question=question, + context=context, + found_answer=found_answer, + answer=answer, + thinking_steps=json.dumps(thinking_steps, ensure_ascii=False), + create_time=now, + update_time=now + ) + logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...") + except Exception as e: + logger.error(f"存储思考过程失败: {e}") + + +async def build_memory_retrieval_prompt( + message: str, + sender: str, + target: str, + chat_stream, + tool_executor, +) -> str: + """构建记忆检索提示 + 使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案 + + Args: + message: 聊天历史记录 + sender: 发送者名称 + target: 目标消息内容 + chat_stream: 聊天流对象 + tool_executor: 工具执行器(保留参数以兼容接口) + + Returns: + str: 记忆检索结果字符串 + """ + start_time = time.time() + + logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}") + try: + time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + bot_name = global_config.bot.nickname + chat_id = chat_stream.stream_id + + # 获取最近查询历史(最近1小时内的查询) + recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=600.0) + if not recent_query_history: + recent_query_history = "最近没有查询记录。" + + # 第一步:生成问题 + question_prompt = await global_prompt_manager.format_prompt( + "memory_retrieval_question_prompt", + bot_name=bot_name, + time_now=time_now, + chat_history=message, + recent_query_history=recent_query_history, + sender=sender, + target_message=target, + ) + + success, response, reasoning_content, model_name = await llm_api.generate_with_model( + question_prompt, + model_config=model_config.model_task_config.tool_use, + ) + + logger.info(f"记忆检索问题生成提示词: {question_prompt}") + logger.info(f"记忆检索问题生成响应: {response}") + logger.info(f"记忆检索问题生成推理: {reasoning_content}") + logger.info(f"记忆检索问题生成模型: {model_name}") + + if not success: + logger.error(f"LLM生成问题失败: {response}") + return "" + + # 解析问题列表 + questions = _parse_questions_json(response) + + if not questions: + logger.debug("模型认为不需要检索记忆或解析失败") + return "" + + logger.info(f"解析到 {len(questions)} 个问题: {questions}") + + # 第二步:对每个问题查询答案 + all_results = [] + for question in questions: + logger.info(f"开始处理问题: {question}") + + # 先检查thinking_back数据库中是否有现成答案 + cached_result = _query_thinking_back(chat_id, question) + if cached_result: + found_answer, answer = cached_result + if found_answer and answer: + logger.info(f"从thinking_back缓存中获取答案,问题: {question[:50]}...") + all_results.append(f"问题:{question}\n答案:{answer}") + continue # 跳过ReAct Agent查询 + + # 如果没有缓存答案,使用ReAct Agent查询 + logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...") + found_answer, answer, thinking_steps = await _react_agent_solve_question( + question=question, + chat_id=chat_id, + max_iterations=5, + timeout=30.0 + ) + + # 存储到数据库 + _store_thinking_back( + chat_id=chat_id, + question=question, + context=message, # 只存储前500字符作为上下文 + found_answer=found_answer, + answer=answer, + thinking_steps=thinking_steps + ) + + if found_answer and answer: + all_results.append(f"问题:{question}\n答案:{answer}") + + end_time = time.time() + + if all_results: + retrieved_memory = "\n\n".join(all_results) + logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒") + return f"你回忆起了以下信息:\n{retrieved_memory}\n请在回复时参考这些回忆的信息。\n" + else: + logger.debug("所有问题均未找到答案") + return "" + + except Exception as e: + logger.error(f"记忆检索时发生异常: {str(e)}") + return "" + + +def _parse_questions_json(response: str) -> List[str]: + """解析问题JSON + + Args: + response: LLM返回的响应 + + Returns: + List[str]: 问题列表 + """ + try: + # 尝试提取JSON(可能包含在```json代码块中) + json_pattern = r"```json\s*(.*?)\s*```" + matches = re.findall(json_pattern, response, re.DOTALL) + + if matches: + json_str = matches[0] + else: + # 尝试直接解析整个响应 + json_str = response.strip() + + # 修复可能的JSON错误 + repaired_json = repair_json(json_str) + + # 解析JSON + questions = json.loads(repaired_json) + + if not isinstance(questions, list): + logger.warning(f"解析的JSON不是数组格式: {questions}") + return [] + + # 确保所有元素都是字符串 + questions = [q for q in questions if isinstance(q, str) and q.strip()] + + return questions + + except Exception as e: + logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") + return [] diff --git a/src/memory_system/memory_utils.py b/src/memory_system/memory_utils.py index bc5347bb..0033dbc6 100644 --- a/src/memory_system/memory_utils.py +++ b/src/memory_system/memory_utils.py @@ -11,6 +11,7 @@ from typing import List, Tuple, Optional from src.common.database.database_model import MemoryChest as MemoryChestModel from src.common.logger import get_logger from json_repair import repair_json +from src.config.config import global_config logger = get_logger("memory_utils") @@ -354,4 +355,38 @@ def find_most_similar_memory_by_chat_id(target_title: str, target_chat_id: str, except Exception as e: logger.error(f"查找最相似记忆时出错: {e}") - return None \ No newline at end of file + return None + + + +def compute_merge_similarity_threshold() -> float: + """ + 根据当前记忆数量占比动态计算合并相似度阈值。 + + 规则:占比越高,阈值越低。 + - < 60%: 0.80(更严格,避免早期误合并) + - < 80%: 0.70 + - < 100%: 0.60 + - < 120%: 0.50 + - >= 120%: 0.45(最宽松,加速收敛) + """ + try: + current_count = MemoryChestModel.select().count() + max_count = max(1, int(global_config.memory.max_memory_number)) + percentage = current_count / max_count + + if percentage < 0.6: + return 0.70 + elif percentage < 0.8: + return 0.60 + elif percentage < 1.0: + return 0.50 + elif percentage < 1.5: + return 0.40 + elif percentage < 2: + return 0.30 + else: + return 0.25 + except Exception: + # 发生异常时使用保守阈值 + return 0.70 \ No newline at end of file diff --git a/src/memory_system/retrieval_tools/README.md b/src/memory_system/retrieval_tools/README.md new file mode 100644 index 00000000..427e4cc9 --- /dev/null +++ b/src/memory_system/retrieval_tools/README.md @@ -0,0 +1,155 @@ +# 记忆检索工具模块 + +这个模块提供了统一的工具注册和管理系统,用于记忆检索功能。 + +## 目录结构 + +``` +retrieval_tools/ +├── __init__.py # 模块导出 +├── tool_registry.py # 工具注册系统 +├── tool_utils.py # 工具函数库(共用函数) +├── query_jargon.py # 查询jargon工具 +├── query_chat_history.py # 查询聊天历史工具 +└── README.md # 本文件 +``` + +## 模块说明 + +### `tool_registry.py` +包含工具注册系统的核心类: +- `MemoryRetrievalTool`: 工具基类 +- `MemoryRetrievalToolRegistry`: 工具注册器 +- `register_memory_retrieval_tool()`: 便捷注册函数 +- `get_tool_registry()`: 获取注册器实例 + +### `tool_utils.py` +包含所有工具共用的工具函数: +- `parse_datetime_to_timestamp()`: 解析时间字符串为时间戳 +- `parse_time_range()`: 解析时间范围字符串 + +### 工具文件 +每个工具都有独立的文件: +- `query_jargon.py`: 根据关键词在jargon库中查询 +- `query_chat_history.py`: 根据时间或关键词在chat_history中查询(支持查询时间点事件、时间范围事件、关键词搜索) + +## 如何添加新工具 + +1. 创建新的工具文件,例如 `query_new_tool.py`: + +```python +""" +新工具 - 工具实现 +""" + +from src.common.logger import get_logger +from .tool_registry import register_memory_retrieval_tool +from .tool_utils import parse_datetime_to_timestamp # 如果需要使用工具函数 + +logger = get_logger("memory_retrieval_tools") + + +async def query_new_tool(param1: str, param2: str, chat_id: str) -> str: + """新工具的实现 + + Args: + param1: 参数1 + param2: 参数2 + chat_id: 聊天ID + + Returns: + str: 查询结果 + """ + try: + # 实现逻辑 + return "结果" + except Exception as e: + logger.error(f"新工具执行失败: {e}") + return f"查询失败: {str(e)}" + + +def register_tool(): + """注册工具""" + register_memory_retrieval_tool( + name="query_new_tool", + description="新工具的描述", + parameters=[ + { + "name": "param1", + "type": "string", + "description": "参数1的描述", + "required": True + }, + { + "name": "param2", + "type": "string", + "description": "参数2的描述", + "required": True + } + ], + execute_func=query_new_tool + ) +``` + +2. 在 `__init__.py` 中导入并注册新工具: + +```python +from .query_new_tool import register_tool as register_query_new_tool + +def init_all_tools(): + """初始化并注册所有记忆检索工具""" + register_query_jargon() + register_query_chat_history() + register_query_new_tool() # 添加新工具 +``` + +3. 工具会自动: + - 出现在 ReAct Agent 的 prompt 中 + - 在动作类型列表中可用 + - 被 ReAct Agent 自动调用 + +## 使用示例 + +```python +from src.memory_system.retrieval_tools import init_all_tools, get_tool_registry + +# 初始化所有工具 +init_all_tools() + +# 获取工具注册器 +registry = get_tool_registry() + +# 获取特定工具 +tool = registry.get_tool("query_chat_history") + +# 执行工具(查询时间点事件) +result = await tool.execute(time_point="2025-01-15 14:30:00", chat_id="chat123") + +# 或者查询关键词 +result = await tool.execute(keyword="小丑AI", chat_id="chat123") + +# 或者查询时间范围 +result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:00", chat_id="chat123") +``` + +## 现有工具说明 + +### query_jargon +根据关键词在jargon库中查询黑话/俚语/缩写的含义 +- 参数:`keyword` (必填) - 关键词 + +### query_chat_history +根据时间或关键词在chat_history中查询相关聊天记录。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息 +- 参数: + - `keyword` (可选) - 关键词,用于搜索消息内容 + - `time_point` (可选) - 时间点,格式:YYYY-MM-DD HH:MM:SS,用于查询某个时间点附近发生了什么(与time_range二选一) + - `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(与time_point二选一) + +## 注意事项 + +- 所有工具函数必须是异步函数(`async def`) +- 如果工具函数签名需要 `chat_id` 参数,系统会自动添加(通过函数签名检测) +- 工具参数定义中的 `required` 字段用于生成 prompt 描述 +- 工具执行失败时应返回错误信息字符串,而不是抛出异常 +- 共用函数放在 `tool_utils.py` 中,避免代码重复 + diff --git a/src/memory_system/retrieval_tools/__init__.py b/src/memory_system/retrieval_tools/__init__.py new file mode 100644 index 00000000..2bb5623c --- /dev/null +++ b/src/memory_system/retrieval_tools/__init__.py @@ -0,0 +1,30 @@ +""" +记忆检索工具模块 +提供统一的工具注册和管理系统 +""" + +from .tool_registry import ( + MemoryRetrievalTool, + MemoryRetrievalToolRegistry, + register_memory_retrieval_tool, + get_tool_registry, +) + +# 导入所有工具的注册函数 +from .query_jargon import register_tool as register_query_jargon +from .query_chat_history import register_tool as register_query_chat_history + + +def init_all_tools(): + """初始化并注册所有记忆检索工具""" + register_query_jargon() + register_query_chat_history() + + +__all__ = [ + "MemoryRetrievalTool", + "MemoryRetrievalToolRegistry", + "register_memory_retrieval_tool", + "get_tool_registry", + "init_all_tools", +] diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py new file mode 100644 index 00000000..da5371ff --- /dev/null +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -0,0 +1,221 @@ +""" +根据时间或关键词在chat_history中查询 - 工具实现 +从ChatHistory表的聊天记录概述库中查询 +""" + +import json +from typing import Optional +from src.common.logger import get_logger +from src.config.config import model_config +from src.common.database.database_model import ChatHistory +from src.llm_models.utils_model import LLMRequest +from .tool_registry import register_memory_retrieval_tool +from .tool_utils import parse_datetime_to_timestamp, parse_time_range + +logger = get_logger("memory_retrieval_tools") + + +async def query_chat_history( + chat_id: str, + keyword: Optional[str] = None, + time_point: Optional[str] = None, + time_range: Optional[str] = None +) -> str: + """根据时间或关键词在chat_history表中查询聊天记录概述 + + Args: + chat_id: 聊天ID + keyword: 关键词(可选) + time_point: 时间点,格式:YYYY-MM-DD HH:MM:SS(可选) + time_range: 时间范围,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"(可选) + + Returns: + str: 查询结果 + """ + try: + # 检查参数 + if not keyword and not time_point and not time_range: + return "未指定查询参数(需要提供keyword、time_point或time_range之一)" + + # 构建查询条件 + query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) + + # 时间过滤条件 + time_conditions = [] + if time_point: + # 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time) + target_timestamp = parse_datetime_to_timestamp(time_point) + time_conditions.append( + (ChatHistory.start_time <= target_timestamp) & + (ChatHistory.end_time >= target_timestamp) + ) + elif time_range: + # 时间范围:查询与时间范围有交集的记录 + start_timestamp, end_timestamp = parse_time_range(time_range) + # 交集条件:start_time < end_timestamp AND end_time > start_timestamp + time_conditions.append( + (ChatHistory.start_time < end_timestamp) & + (ChatHistory.end_time > start_timestamp) + ) + + if time_conditions: + # 合并所有时间条件(OR关系) + time_filter = time_conditions[0] + for condition in time_conditions[1:]: + time_filter = time_filter | condition + query = query.where(time_filter) + + # 执行查询 + records = list(query.order_by(ChatHistory.start_time.desc()).limit(50)) + + if not records: + return "未找到相关聊天记录概述" + + # 如果有关键词,进一步过滤 + if keyword: + keyword_lower = keyword.lower() + filtered_records = [] + + for record in records: + # 在theme、keywords、summary、original_text中搜索 + theme = (record.theme or "").lower() + summary = (record.summary or "").lower() + original_text = (record.original_text or "").lower() + + # 解析keywords JSON + keywords_list = [] + if record.keywords: + try: + keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords + if isinstance(keywords_data, list): + keywords_list = [str(k).lower() for k in keywords_data] + except (json.JSONDecodeError, TypeError, ValueError): + pass + + # 检查是否包含关键词 + if (keyword_lower in theme or + keyword_lower in summary or + keyword_lower in original_text or + any(keyword_lower in k for k in keywords_list)): + filtered_records.append(record) + + if not filtered_records: + return f"未找到包含关键词'{keyword}'的聊天记录概述" + + records = filtered_records + + # 构建结果文本 + results = [] + for record in records[:10]: # 最多返回10条记录 + result_parts = [] + + # 添加主题 + if record.theme: + result_parts.append(f"主题:{record.theme}") + + # 添加时间范围 + from datetime import datetime + start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S") + end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S") + result_parts.append(f"时间:{start_str} - {end_str}") + + # 添加概括(优先使用summary,如果没有则使用original_text的前200字符) + if record.summary: + result_parts.append(f"概括:{record.summary}") + elif record.original_text: + text_preview = record.original_text[:200] + if len(record.original_text) > 200: + text_preview += "..." + result_parts.append(f"内容:{text_preview}") + + results.append("\n".join(result_parts)) + + if not results: + return "未找到相关聊天记录概述" + + # 如果只有一条记录,直接返回 + if len(results) == 1: + return results[0] + + # 多条记录,使用LLM总结 + try: + llm_request = LLMRequest( + model_set=model_config.model_task_config.utils_small, + request_type="chat_history_analysis" + ) + + query_desc = [] + if keyword: + query_desc.append(f"关键词:{keyword}") + if time_point: + query_desc.append(f"时间点:{time_point}") + if time_range: + query_desc.append(f"时间范围:{time_range}") + + query_info = ",".join(query_desc) if query_desc else "聊天记录概述" + + combined_results = "\n\n---\n\n".join(results) + + analysis_prompt = f"""请根据以下聊天记录概述,总结与查询条件相关的信息。请输出一段平文本,不要有特殊格式。 +查询条件:{query_info} + +聊天记录概述: +{combined_results} + +请仔细分析聊天记录概述,提取与查询条件相关的信息并给出总结。如果概述中没有相关信息,输出"无有效信息"即可,不要输出其他内容。 + +总结:""" + + response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async( + prompt=analysis_prompt, + temperature=0.3, + max_tokens=512 + ) + + logger.info(f"查询聊天历史概述提示词: {analysis_prompt}") + logger.info(f"查询聊天历史概述响应: {response}") + logger.info(f"查询聊天历史概述推理: {reasoning}") + logger.info(f"查询聊天历史概述模型: {model_name}") + + if "无有效信息" in response: + return "无有效信息" + + return response + + except Exception as llm_error: + logger.error(f"LLM分析聊天记录概述失败: {llm_error}") + # 如果LLM分析失败,返回前3条记录的摘要 + return "\n\n---\n\n".join(results[:3]) + + except Exception as e: + logger.error(f"查询聊天历史概述失败: {e}") + return f"查询失败: {str(e)}" + + +def register_tool(): + """注册工具""" + register_memory_retrieval_tool( + name="query_chat_history", + description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述", + parameters=[ + { + "name": "keyword", + "type": "string", + "description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)", + "required": False + }, + { + "name": "time_point", + "type": "string", + "description": "时间点,格式:YYYY-MM-DD HH:MM:SS(可选,与time_range二选一)。用于查询包含该时间点的聊天记录概述", + "required": False + }, + { + "name": "time_range", + "type": "string", + "description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(可选,与time_point二选一)。用于查询与时间范围有交集的聊天记录概述", + "required": False + } + ], + execute_func=query_chat_history + ) diff --git a/src/memory_system/retrieval_tools/query_jargon.py b/src/memory_system/retrieval_tools/query_jargon.py new file mode 100644 index 00000000..f8accf08 --- /dev/null +++ b/src/memory_system/retrieval_tools/query_jargon.py @@ -0,0 +1,92 @@ +""" +根据关键词在jargon库中查询 - 工具实现 +""" + +from src.common.logger import get_logger +from src.jargon.jargon_miner import search_jargon +from .tool_registry import register_memory_retrieval_tool + +logger = get_logger("memory_retrieval_tools") + + +async def query_jargon( + keyword: str, + chat_id: str, + fuzzy: bool = False, + search_all: bool = False +) -> str: + """根据关键词在jargon库中查询 + + Args: + keyword: 关键词(黑话/俚语/缩写) + chat_id: 聊天ID + fuzzy: 是否使用模糊搜索,默认False(精确匹配) + search_all: 是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global) + + Returns: + str: 查询结果 + """ + try: + content = str(keyword).strip() + if not content: + return "关键词为空" + + # 根据参数执行搜索 + search_chat_id = None if search_all else chat_id + results = search_jargon( + keyword=content, + chat_id=search_chat_id, + limit=1, + case_sensitive=False, + fuzzy=fuzzy + ) + + if results: + result = results[0] + translation = result.get("translation", "").strip() + meaning = result.get("meaning", "").strip() + search_type = "模糊搜索" if fuzzy else "精确匹配" + search_scope = "全库" if search_all else "当前会话或全局" + output = f"“{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}”" + logger.info(f"在jargon库中找到匹配({search_scope},{search_type}): {content}") + return output + + # 未命中 + search_type = "模糊搜索" if fuzzy else "精确匹配" + search_scope = "全库" if search_all else "当前会话或全局" + logger.info(f"在jargon库中未找到匹配({search_scope},{search_type}): {content}") + return f"未在jargon库中找到'{content}'的解释" + + except Exception as e: + logger.error(f"查询jargon失败: {e}") + return f"查询失败: {str(e)}" + + +def register_tool(): + """注册工具""" + register_memory_retrieval_tool( + name="query_jargon", + description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon,可以设置为搜索全库。", + parameters=[ + { + "name": "keyword", + "type": "string", + "description": "关键词(黑话/俚语/缩写),支持模糊搜索", + "required": True + }, + { + "name": "fuzzy", + "type": "boolean", + "description": "是否使用模糊搜索(部分匹配),默认False(精确匹配)。当精确匹配找不到时,可以尝试使用模糊搜索。", + "required": False + }, + { + "name": "search_all", + "type": "boolean", + "description": "是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global的jargon)。当在当前会话中找不到时,可以尝试搜索全库。", + "required": False + } + ], + execute_func=query_jargon + ) + diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py new file mode 100644 index 00000000..920a1bb6 --- /dev/null +++ b/src/memory_system/retrieval_tools/tool_registry.py @@ -0,0 +1,114 @@ +""" +工具注册系统 +提供统一的工具注册和管理接口 +""" + +from typing import List, Dict, Any, Optional, Callable, Awaitable +from src.common.logger import get_logger + +logger = get_logger("memory_retrieval_tools") + + +class MemoryRetrievalTool: + """记忆检索工具基类""" + + def __init__( + self, + name: str, + description: str, + parameters: List[Dict[str, Any]], + execute_func: Callable[..., Awaitable[str]] + ): + """ + 初始化工具 + + Args: + name: 工具名称 + description: 工具描述 + parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}] + execute_func: 执行函数,必须是异步函数 + """ + self.name = name + self.description = description + self.parameters = parameters + self.execute_func = execute_func + + def get_tool_description(self) -> str: + """获取工具的文本描述,用于prompt""" + param_descriptions = [] + for param in self.parameters: + param_name = param.get("name", "") + param_type = param.get("type", "string") + param_desc = param.get("description", "") + required = param.get("required", True) + required_str = "必填" if required else "可选" + param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}") + + params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数" + return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}" + + async def execute(self, **kwargs) -> str: + """执行工具""" + return await self.execute_func(**kwargs) + + +class MemoryRetrievalToolRegistry: + """工具注册器""" + + def __init__(self): + self.tools: Dict[str, MemoryRetrievalTool] = {} + + def register_tool(self, tool: MemoryRetrievalTool) -> None: + """注册工具""" + self.tools[tool.name] = tool + logger.info(f"注册记忆检索工具: {tool.name}") + + def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]: + """获取工具""" + return self.tools.get(name) + + def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]: + """获取所有工具""" + return self.tools.copy() + + def get_tools_description(self) -> str: + """获取所有工具的描述,用于prompt""" + descriptions = [] + for i, tool in enumerate(self.tools.values(), 1): + descriptions.append(f"{i}. {tool.get_tool_description()}") + return "\n".join(descriptions) + + def get_action_types_list(self) -> str: + """获取所有动作类型的列表,用于prompt""" + action_types = [tool.name for tool in self.tools.values()] + action_types.append("final_answer") + action_types.append("no_answer") + return " 或 ".join([f'"{at}"' for at in action_types]) + + +# 全局工具注册器实例 +_tool_registry = MemoryRetrievalToolRegistry() + + +def register_memory_retrieval_tool( + name: str, + description: str, + parameters: List[Dict[str, Any]], + execute_func: Callable[..., Awaitable[str]] +) -> None: + """注册记忆检索工具的便捷函数 + + Args: + name: 工具名称 + description: 工具描述 + parameters: 参数定义列表 + execute_func: 执行函数 + """ + tool = MemoryRetrievalTool(name, description, parameters, execute_func) + _tool_registry.register_tool(tool) + + +def get_tool_registry() -> MemoryRetrievalToolRegistry: + """获取工具注册器实例""" + return _tool_registry + diff --git a/src/memory_system/retrieval_tools/tool_utils.py b/src/memory_system/retrieval_tools/tool_utils.py new file mode 100644 index 00000000..d0ca334f --- /dev/null +++ b/src/memory_system/retrieval_tools/tool_utils.py @@ -0,0 +1,64 @@ +""" +工具函数库 +包含所有工具共用的工具函数 +""" + +from datetime import datetime +from typing import Tuple + + +def parse_datetime_to_timestamp(value: str) -> float: + """ + 接受多种常见格式并转换为时间戳(秒) + 支持示例: + - 2025-09-29 + - 2025-09-29 00:00:00 + - 2025/09/29 00:00 + - 2025-09-29T00:00:00 + """ + value = value.strip() + fmts = [ + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y/%m/%d %H:%M:%S", + "%Y/%m/%d %H:%M", + "%Y-%m-%d", + "%Y/%m/%d", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M", + ] + last_err = None + for fmt in fmts: + try: + dt = datetime.strptime(value, fmt) + return dt.timestamp() + except Exception as e: + last_err = e + raise ValueError(f"无法解析时间: {value} ({last_err})") + + +def parse_time_range(time_range: str) -> Tuple[float, float]: + """ + 解析时间范围字符串,返回开始和结束时间戳 + + Args: + time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS" + + Returns: + Tuple[float, float]: (开始时间戳, 结束时间戳) + """ + if " - " not in time_range: + raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}") + + parts = time_range.split(" - ", 1) + if len(parts) != 2: + raise ValueError(f"时间范围格式错误: {time_range}") + + start_str = parts[0].strip() + end_str = parts[1].strip() + + start_timestamp = parse_datetime_to_timestamp(start_str) + end_timestamp = parse_datetime_to_timestamp(end_str) + + return start_timestamp, end_timestamp +