From 71a2a4282be026da410c4294bcc5c6bd67ad44e4 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Mon, 10 Nov 2025 12:27:54 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E4=BC=98=E5=8C=96=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E6=8F=90=E5=8F=96=E5=92=8C=E8=81=8A=E5=A4=A9=E5=8E=8B?= =?UTF-8?q?=E7=BC=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/utils/chat_history_summarizer.py | 21 +- src/chat/utils/utils.py | 11 +- src/memory_system/memory_retrieval.py | 180 ++++++++++++------ .../retrieval_tools/query_chat_history.py | 114 ++++++----- .../retrieval_tools/query_jargon.py | 25 +-- 5 files changed, 212 insertions(+), 139 deletions(-) diff --git a/src/chat/utils/chat_history_summarizer.py b/src/chat/utils/chat_history_summarizer.py index 4c984892..07553460 100644 --- a/src/chat/utils/chat_history_summarizer.py +++ b/src/chat/utils/chat_history_summarizer.py @@ -269,7 +269,16 @@ class ChatHistorySummarizer: ) # 使用LLM压缩聊天内容 - theme, keywords, summary = await self._compress_with_llm(original_text) + success, theme, keywords, summary = await self._compress_with_llm(original_text) + + if not success: + logger.warning( + f"{self.log_prefix} LLM压缩失败,不存储到数据库 | 消息数: {len(messages)}" + ) + # 清空当前批次,避免重复处理 + self.current_batch = None + return + logger.info( f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)} 字" ) @@ -297,12 +306,12 @@ class ChatHistorySummarizer: # 出错时也清空批次,避免重复处理 self.current_batch = None - async def _compress_with_llm(self, original_text: str) -> tuple[str, List[str], str]: + async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]: """ 使用LLM压缩聊天内容 Returns: - tuple[str, List[str], str]: (主题, 关键词列表, 概括) + tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括) """ prompt = f"""请对以下聊天记录进行概括,提取以下信息: @@ -353,13 +362,13 @@ class ChatHistorySummarizer: if isinstance(keywords, str): keywords = [keywords] - return theme, keywords, summary + return True, theme, keywords, summary except Exception as e: logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") - # 返回默认值 - return "未命名对话", [], "压缩失败,无法生成概括" + # 返回失败标志和默认值 + return False, "未命名对话", [], "压缩失败,无法生成概括" async def _store_to_database( self, diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index d85f8143..cb1559dc 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -221,14 +221,17 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: while i < len(text): char = text[i] if char in separators: - # 检查分割条件:如果空格左右都是英文字母,则不分割(仅对空格应用此规则) + # 检查分割条件:如果空格左右都是英文字母、数字,或数字和英文之间,则不分割(仅对空格应用此规则) can_split = True if 0 < i < len(text) - 1: prev_char = text[i - 1] next_char = text[i + 1] - # 只对空格应用"不分割两个英文之间的空格"规则 - if char == ' ' and is_english_letter(prev_char) and is_english_letter(next_char): - can_split = False + # 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则 + if char == ' ': + prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char) + next_is_alnum = next_char.isdigit() or is_english_letter(next_char) + if prev_is_alnum and next_is_alnum: + can_split = False if can_split: # 只有当当前段不为空时才添加 diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index fd7a2daf..0b608d07 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -2,6 +2,7 @@ import time import json import re import random +import asyncio from typing import List, Dict, Any, Optional, Tuple from src.common.logger import get_logger from src.config.config import global_config, model_config @@ -430,6 +431,100 @@ def _store_thinking_back( logger.error(f"存储思考过程失败: {e}") +def _get_max_iterations_by_question_count(question_count: int) -> int: + """根据问题数量获取最大迭代次数 + + Args: + question_count: 问题数量 + + Returns: + int: 最大迭代次数 + """ + if question_count == 1: + return 5 + elif question_count == 2: + return 3 + else: # 3个或以上 + return 1 + + +async def _process_single_question( + question: str, + chat_id: str, + context: str, + max_iterations: int +) -> Optional[str]: + """处理单个问题的查询(包含缓存检查逻辑) + + Args: + question: 要查询的问题 + chat_id: 聊天ID + context: 上下文信息 + max_iterations: 最大迭代次数 + + Returns: + Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None + """ + logger.info(f"开始处理问题: {question}") + + # 先检查thinking_back数据库中是否有现成答案 + cached_result = _query_thinking_back(chat_id, question) + should_requery = False + + if cached_result: + cached_found_answer, cached_answer = cached_result + + # 根据found_answer的值决定是否重新查询 + if cached_found_answer: # found_answer == 1 (True) + # found_answer == 1:20%概率重新查询 + if random.random() < 0.2: + should_requery = True + logger.info(f"found_answer=1,触发20%概率重新查询,问题: {question[:50]}...") + else: # found_answer == 0 (False) + # found_answer == 0:40%概率重新查询 + if random.random() < 0.4: + should_requery = True + logger.info(f"found_answer=0,触发40%概率重新查询,问题: {question[:50]}...") + + # 如果不需要重新查询,使用缓存答案 + if not should_requery: + if cached_answer: + logger.info(f"从thinking_back缓存中获取答案,问题: {question[:50]}...") + return f"问题:{question}\n答案:{cached_answer}" + else: + # 缓存中没有答案,需要查询 + should_requery = True + + # 如果没有缓存答案或需要重新查询,使用ReAct Agent查询 + if not cached_result or should_requery: + if should_requery: + logger.info(f"概率触发重新查询,使用ReAct Agent查询,问题: {question[:50]}...") + else: + logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...") + + found_answer, answer, thinking_steps = await _react_agent_solve_question( + question=question, + chat_id=chat_id, + max_iterations=max_iterations, + timeout=30.0 + ) + + # 存储到数据库 + _store_thinking_back( + chat_id=chat_id, + question=question, + context=context, + found_answer=found_answer, + answer=answer, + thinking_steps=thinking_steps + ) + + if found_answer and answer: + return f"问题:{question}\n答案:{answer}" + + return None + + async def build_memory_retrieval_prompt( message: str, sender: str, @@ -498,68 +593,31 @@ async def build_memory_retrieval_prompt( logger.info(f"解析到 {len(questions)} 个问题: {questions}") - # 第二步:对每个问题查询答案 + # 第二步:根据问题数量确定最大迭代次数 + max_iterations = _get_max_iterations_by_question_count(len(questions)) + logger.info(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations}") + + # 并行处理所有问题 + question_tasks = [ + _process_single_question( + question=question, + chat_id=chat_id, + context=message, + max_iterations=max_iterations + ) + for question in questions + ] + + # 并行执行所有查询任务 + results = await asyncio.gather(*question_tasks, return_exceptions=True) + + # 收集所有有效结果 all_results = [] - for question in questions: - logger.info(f"开始处理问题: {question}") - - # 先检查thinking_back数据库中是否有现成答案 - cached_result = _query_thinking_back(chat_id, question) - should_requery = False - - if cached_result: - cached_found_answer, cached_answer = cached_result - - # 根据found_answer的值决定是否重新查询 - if cached_found_answer: # found_answer == 1 (True) - # found_answer == 1:20%概率重新查询 - if random.random() < 0.2: - should_requery = True - logger.info(f"found_answer=1,触发20%概率重新查询,问题: {question[:50]}...") - else: - # 使用缓存答案 - if cached_answer: - logger.info(f"从thinking_back缓存中获取答案(found_answer=1),问题: {question[:50]}...") - all_results.append(f"问题:{question}\n答案:{cached_answer}") - continue # 跳过ReAct Agent查询 - else: # found_answer == 0 (False) - # found_answer == 0:40%概率重新查询 - if random.random() < 0.4: - should_requery = True - logger.info(f"found_answer=0,触发40%概率重新查询,问题: {question[:50]}...") - else: - # 使用缓存答案(即使found_answer=0,也可能有部分答案) - if cached_answer: - logger.info(f"从thinking_back缓存中获取答案(found_answer=0),问题: {question[:50]}...") - all_results.append(f"问题:{question}\n答案:{cached_answer}") - continue # 跳过ReAct Agent查询 - - # 如果没有缓存答案或需要重新查询,使用ReAct Agent查询 - if not cached_result or should_requery: - if should_requery: - logger.info(f"概率触发重新查询,使用ReAct Agent查询,问题: {question[:50]}...") - else: - logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...") - - found_answer, answer, thinking_steps = await _react_agent_solve_question( - question=question, - chat_id=chat_id, - max_iterations=5, - timeout=30.0 - ) - - # 存储到数据库 - _store_thinking_back( - chat_id=chat_id, - question=question, - context=message, # 只存储前500字符作为上下文 - found_answer=found_answer, - answer=answer, - thinking_steps=thinking_steps - ) - - if found_answer and answer: - all_results.append(f"问题:{question}\n答案:{answer}") + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"处理问题 '{questions[i]}' 时发生异常: {result}") + elif result is not None: + all_results.append(result) end_time = time.time() diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index da5371ff..4be8ccab 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -9,6 +9,7 @@ from src.common.logger import get_logger from src.config.config import model_config from src.common.database.database_model import ChatHistory from src.llm_models.utils_model import LLMRequest +from src.chat.utils.utils import parse_keywords_string from .tool_registry import register_memory_retrieval_tool from .tool_utils import parse_datetime_to_timestamp, parse_time_range @@ -18,51 +19,46 @@ logger = get_logger("memory_retrieval_tools") async def query_chat_history( chat_id: str, keyword: Optional[str] = None, - time_point: Optional[str] = None, time_range: Optional[str] = None ) -> str: """根据时间或关键词在chat_history表中查询聊天记录概述 Args: chat_id: 聊天ID - keyword: 关键词(可选) - time_point: 时间点,格式:YYYY-MM-DD HH:MM:SS(可选) - time_range: 时间范围,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"(可选) + keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔) + time_range: 时间范围或时间点,格式: + - 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS" + - 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录) Returns: str: 查询结果 """ try: # 检查参数 - if not keyword and not time_point and not time_range: - return "未指定查询参数(需要提供keyword、time_point或time_range之一)" + if not keyword and not time_range: + return "未指定查询参数(需要提供keyword或time_range之一)" # 构建查询条件 query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) # 时间过滤条件 - time_conditions = [] - if time_point: - # 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time) - target_timestamp = parse_datetime_to_timestamp(time_point) - time_conditions.append( - (ChatHistory.start_time <= target_timestamp) & - (ChatHistory.end_time >= target_timestamp) - ) - elif time_range: - # 时间范围:查询与时间范围有交集的记录 - start_timestamp, end_timestamp = parse_time_range(time_range) - # 交集条件:start_time < end_timestamp AND end_time > start_timestamp - time_conditions.append( - (ChatHistory.start_time < end_timestamp) & - (ChatHistory.end_time > start_timestamp) - ) - - if time_conditions: - # 合并所有时间条件(OR关系) - time_filter = time_conditions[0] - for condition in time_conditions[1:]: - time_filter = time_filter | condition + if time_range: + # 判断是时间点还是时间范围 + if " - " in time_range: + # 时间范围:查询与时间范围有交集的记录 + start_timestamp, end_timestamp = parse_time_range(time_range) + # 交集条件:start_time < end_timestamp AND end_time > start_timestamp + time_filter = ( + (ChatHistory.start_time < end_timestamp) & + (ChatHistory.end_time > start_timestamp) + ) + else: + # 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time) + target_timestamp = parse_datetime_to_timestamp(time_range) + time_filter = ( + (ChatHistory.start_time <= target_timestamp) & + (ChatHistory.end_time >= target_timestamp) + ) query = query.where(time_filter) # 执行查询 @@ -73,7 +69,17 @@ async def query_chat_history( # 如果有关键词,进一步过滤 if keyword: - keyword_lower = keyword.lower() + # 解析多个关键词(支持空格、逗号等分隔符) + keywords_list = parse_keywords_string(keyword) + if not keywords_list: + keywords_list = [keyword.strip()] if keyword.strip() else [] + + # 转换为小写以便匹配 + keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()] + + if not keywords_lower: + return "关键词为空" + filtered_records = [] for record in records: @@ -82,25 +88,32 @@ async def query_chat_history( summary = (record.summary or "").lower() original_text = (record.original_text or "").lower() - # 解析keywords JSON - keywords_list = [] + # 解析record中的keywords JSON + record_keywords_list = [] if record.keywords: try: keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords if isinstance(keywords_data, list): - keywords_list = [str(k).lower() for k in keywords_data] + record_keywords_list = [str(k).lower() for k in keywords_data] except (json.JSONDecodeError, TypeError, ValueError): pass - # 检查是否包含关键词 - if (keyword_lower in theme or - keyword_lower in summary or - keyword_lower in original_text or - any(keyword_lower in k for k in keywords_list)): + # 检查是否包含任意一个关键词(OR关系) + matched = False + for kw in keywords_lower: + if (kw in theme or + kw in summary or + kw in original_text or + any(kw in k for k in record_keywords_list)): + matched = True + break + + if matched: filtered_records.append(record) if not filtered_records: - return f"未找到包含关键词'{keyword}'的聊天记录概述" + keywords_str = "、".join(keywords_list) + return f"未找到包含关键词'{keywords_str}'的聊天记录概述" records = filtered_records @@ -146,11 +159,18 @@ async def query_chat_history( query_desc = [] if keyword: - query_desc.append(f"关键词:{keyword}") - if time_point: - query_desc.append(f"时间点:{time_point}") + # 解析关键词列表用于显示 + keywords_list = parse_keywords_string(keyword) + if keywords_list: + keywords_str = "、".join(keywords_list) + query_desc.append(f"关键词:{keywords_str}") + else: + query_desc.append(f"关键词:{keyword}") if time_range: - query_desc.append(f"时间范围:{time_range}") + if " - " in time_range: + query_desc.append(f"时间范围:{time_range}") + else: + query_desc.append(f"时间点:{time_range}") query_info = ",".join(query_desc) if query_desc else "聊天记录概述" @@ -201,19 +221,13 @@ def register_tool(): { "name": "keyword", "type": "string", - "description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)", - "required": False - }, - { - "name": "time_point", - "type": "string", - "description": "时间点,格式:YYYY-MM-DD HH:MM:SS(可选,与time_range二选一)。用于查询包含该时间点的聊天记录概述", + "description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)", "required": False }, { "name": "time_range", "type": "string", - "description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(可选,与time_point二选一)。用于查询与时间范围有交集的聊天记录概述", + "description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)", "required": False } ], diff --git a/src/memory_system/retrieval_tools/query_jargon.py b/src/memory_system/retrieval_tools/query_jargon.py index f8accf08..998973be 100644 --- a/src/memory_system/retrieval_tools/query_jargon.py +++ b/src/memory_system/retrieval_tools/query_jargon.py @@ -12,8 +12,7 @@ logger = get_logger("memory_retrieval_tools") async def query_jargon( keyword: str, chat_id: str, - fuzzy: bool = False, - search_all: bool = False + fuzzy: bool = False ) -> str: """根据关键词在jargon库中查询 @@ -21,7 +20,6 @@ async def query_jargon( keyword: 关键词(黑话/俚语/缩写) chat_id: 聊天ID fuzzy: 是否使用模糊搜索,默认False(精确匹配) - search_all: 是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global) Returns: str: 查询结果 @@ -31,11 +29,10 @@ async def query_jargon( if not content: return "关键词为空" - # 根据参数执行搜索 - search_chat_id = None if search_all else chat_id + # 执行搜索(仅搜索当前会话或全局) results = search_jargon( keyword=content, - chat_id=search_chat_id, + chat_id=chat_id, limit=1, case_sensitive=False, fuzzy=fuzzy @@ -46,15 +43,13 @@ async def query_jargon( translation = result.get("translation", "").strip() meaning = result.get("meaning", "").strip() search_type = "模糊搜索" if fuzzy else "精确匹配" - search_scope = "全库" if search_all else "当前会话或全局" - output = f"“{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}”" - logger.info(f"在jargon库中找到匹配({search_scope},{search_type}): {content}") + output = f'"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"' + logger.info(f"在jargon库中找到匹配(当前会话或全局,{search_type}): {content}") return output # 未命中 search_type = "模糊搜索" if fuzzy else "精确匹配" - search_scope = "全库" if search_all else "当前会话或全局" - logger.info(f"在jargon库中未找到匹配({search_scope},{search_type}): {content}") + logger.info(f"在jargon库中未找到匹配(当前会话或全局,{search_type}): {content}") return f"未在jargon库中找到'{content}'的解释" except Exception as e: @@ -66,7 +61,7 @@ def register_tool(): """注册工具""" register_memory_retrieval_tool( name="query_jargon", - description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon,可以设置为搜索全库。", + description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。仅搜索当前会话或全局jargon。", parameters=[ { "name": "keyword", @@ -79,12 +74,6 @@ def register_tool(): "type": "boolean", "description": "是否使用模糊搜索(部分匹配),默认False(精确匹配)。当精确匹配找不到时,可以尝试使用模糊搜索。", "required": False - }, - { - "name": "search_all", - "type": "boolean", - "description": "是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global的jargon)。当在当前会话中找不到时,可以尝试搜索全库。", - "required": False } ], execute_func=query_jargon