fix:优化记忆提取和聊天压缩

pull/1356/head
SengokuCola 2025-11-10 12:27:54 +08:00
parent 10cd2474af
commit 71a2a4282b
5 changed files with 212 additions and 139 deletions

View File

@ -269,7 +269,16 @@ class ChatHistorySummarizer:
)
# 使用LLM压缩聊天内容
theme, keywords, summary = await self._compress_with_llm(original_text)
success, theme, keywords, summary = await self._compress_with_llm(original_text)
if not success:
logger.warning(
f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}"
)
# 清空当前批次,避免重复处理
self.current_batch = None
return
logger.info(
f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)}"
)
@ -297,12 +306,12 @@ class ChatHistorySummarizer:
# 出错时也清空批次,避免重复处理
self.current_batch = None
async def _compress_with_llm(self, original_text: str) -> tuple[str, List[str], str]:
async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]:
"""
使用LLM压缩聊天内容
Returns:
tuple[str, List[str], str]: (主题, 关键词列表, 概括)
tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括)
"""
prompt = f"""请对以下聊天记录进行概括,提取以下信息:
@ -353,13 +362,13 @@ class ChatHistorySummarizer:
if isinstance(keywords, str):
keywords = [keywords]
return theme, keywords, summary
return True, theme, keywords, summary
except Exception as e:
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
# 返回默认值
return "未命名对话", [], "压缩失败,无法生成概括"
# 返回失败标志和默认值
return False, "未命名对话", [], "压缩失败,无法生成概括"
async def _store_to_database(
self,

View File

@ -221,14 +221,17 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
while i < len(text):
char = text[i]
if char in separators:
# 检查分割条件:如果空格左右都是英文字母,则不分割(仅对空格应用此规则)
# 检查分割条件:如果空格左右都是英文字母、数字,或数字和英文之间,则不分割(仅对空格应用此规则)
can_split = True
if 0 < i < len(text) - 1:
prev_char = text[i - 1]
next_char = text[i + 1]
# 只对空格应用"不分割两个英文之间的空格"规则
if char == ' ' and is_english_letter(prev_char) and is_english_letter(next_char):
can_split = False
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
if char == ' ':
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum:
can_split = False
if can_split:
# 只有当当前段不为空时才添加

View File

@ -2,6 +2,7 @@ import time
import json
import re
import random
import asyncio
from typing import List, Dict, Any, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@ -430,6 +431,100 @@ def _store_thinking_back(
logger.error(f"存储思考过程失败: {e}")
def _get_max_iterations_by_question_count(question_count: int) -> int:
"""根据问题数量获取最大迭代次数
Args:
question_count: 问题数量
Returns:
int: 最大迭代次数
"""
if question_count == 1:
return 5
elif question_count == 2:
return 3
else: # 3个或以上
return 1
async def _process_single_question(
    question: str,
    chat_id: str,
    context: str,
    max_iterations: int
) -> Optional[str]:
    """Answer a single question, preferring the thinking_back cache.

    Cache policy: a cached hit with found_answer=1 is re-queried with 20%
    probability, a cached miss (found_answer=0) with 40% probability;
    otherwise the cached answer (if any) is returned directly. On a cache
    miss or a triggered re-query, the ReAct agent is invoked and its result
    stored back to the database.

    Args:
        question: The question to resolve.
        chat_id: Chat session identifier used for cache lookup/storage.
        context: Context text stored alongside the agent's result.
        max_iterations: Maximum ReAct iterations for this question.

    Returns:
        Optional[str]: Formatted "问题:...\\n答案:..." string when an answer
        is found, otherwise None.
    """
    logger.info(f"开始处理问题: {question}")

    # Check the thinking_back cache for an existing answer first.
    cached_result = _query_thinking_back(chat_id, question)
    should_requery = False

    if cached_result:
        cached_found_answer, cached_answer = cached_result
        # Decide probabilistically whether to refresh the cached entry.
        if cached_found_answer:  # found_answer == 1 (True)
            if random.random() < 0.2:
                should_requery = True
                logger.info(f"found_answer=1触发20%概率重新查询,问题: {question[:50]}...")
        else:  # found_answer == 0 (False)
            if random.random() < 0.4:
                should_requery = True
                logger.info(f"found_answer=0触发40%概率重新查询,问题: {question[:50]}...")

        if not should_requery:
            if cached_answer:
                # Serve the cached answer without invoking the agent.
                logger.info(f"从thinking_back缓存中获取答案问题: {question[:50]}...")
                return f"问题:{question}\n答案:{cached_answer}"
            # Cache row exists but carries no answer text — fall through
            # to a fresh agent query.
            should_requery = True

    # No usable cache entry (or a re-query was triggered): run the ReAct
    # agent. NOTE: when there was no cached row at all, should_requery is
    # still False here, which selects the "no cached answer" log message.
    if should_requery:
        logger.info(f"概率触发重新查询使用ReAct Agent查询问题: {question[:50]}...")
    else:
        logger.info(f"未找到缓存答案使用ReAct Agent查询问题: {question[:50]}...")

    found_answer, answer, thinking_steps = await _react_agent_solve_question(
        question=question,
        chat_id=chat_id,
        max_iterations=max_iterations,
        timeout=30.0
    )

    # Persist the agent's outcome (even failures) so future calls can hit
    # the cache and apply the probabilistic refresh policy.
    _store_thinking_back(
        chat_id=chat_id,
        question=question,
        context=context,
        found_answer=found_answer,
        answer=answer,
        thinking_steps=thinking_steps
    )

    if found_answer and answer:
        return f"问题:{question}\n答案:{answer}"
    return None
async def build_memory_retrieval_prompt(
message: str,
sender: str,
@ -498,68 +593,31 @@ async def build_memory_retrieval_prompt(
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
# 第二步:对每个问题查询答案
# 第二步:根据问题数量确定最大迭代次数
max_iterations = _get_max_iterations_by_question_count(len(questions))
logger.info(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations}")
# 并行处理所有问题
question_tasks = [
_process_single_question(
question=question,
chat_id=chat_id,
context=message,
max_iterations=max_iterations
)
for question in questions
]
# 并行执行所有查询任务
results = await asyncio.gather(*question_tasks, return_exceptions=True)
# 收集所有有效结果
all_results = []
for question in questions:
logger.info(f"开始处理问题: {question}")
# 先检查thinking_back数据库中是否有现成答案
cached_result = _query_thinking_back(chat_id, question)
should_requery = False
if cached_result:
cached_found_answer, cached_answer = cached_result
# 根据found_answer的值决定是否重新查询
if cached_found_answer: # found_answer == 1 (True)
# found_answer == 120%概率重新查询
if random.random() < 0.2:
should_requery = True
logger.info(f"found_answer=1触发20%概率重新查询,问题: {question[:50]}...")
else:
# 使用缓存答案
if cached_answer:
logger.info(f"从thinking_back缓存中获取答案found_answer=1问题: {question[:50]}...")
all_results.append(f"问题:{question}\n答案:{cached_answer}")
continue # 跳过ReAct Agent查询
else: # found_answer == 0 (False)
# found_answer == 040%概率重新查询
if random.random() < 0.4:
should_requery = True
logger.info(f"found_answer=0触发40%概率重新查询,问题: {question[:50]}...")
else:
# 使用缓存答案即使found_answer=0也可能有部分答案
if cached_answer:
logger.info(f"从thinking_back缓存中获取答案found_answer=0问题: {question[:50]}...")
all_results.append(f"问题:{question}\n答案:{cached_answer}")
continue # 跳过ReAct Agent查询
# 如果没有缓存答案或需要重新查询使用ReAct Agent查询
if not cached_result or should_requery:
if should_requery:
logger.info(f"概率触发重新查询使用ReAct Agent查询问题: {question[:50]}...")
else:
logger.info(f"未找到缓存答案使用ReAct Agent查询问题: {question[:50]}...")
found_answer, answer, thinking_steps = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
max_iterations=5,
timeout=30.0
)
# 存储到数据库
_store_thinking_back(
chat_id=chat_id,
question=question,
context=message, # 只存储前500字符作为上下文
found_answer=found_answer,
answer=answer,
thinking_steps=thinking_steps
)
if found_answer and answer:
all_results.append(f"问题:{question}\n答案:{answer}")
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"处理问题 '{questions[i]}' 时发生异常: {result}")
elif result is not None:
all_results.append(result)
end_time = time.time()

View File

@ -9,6 +9,7 @@ from src.common.logger import get_logger
from src.config.config import model_config
from src.common.database.database_model import ChatHistory
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import parse_keywords_string
from .tool_registry import register_memory_retrieval_tool
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
@ -18,51 +19,46 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_point: Optional[str] = None,
time_range: Optional[str] = None
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
Args:
chat_id: 聊天ID
keyword: 关键词可选
time_point: 时间点格式YYYY-MM-DD HH:MM:SS可选
time_range: 时间范围格式"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"可选
keyword: 关键词可选支持多个关键词可用空格逗号等分隔
time_range: 时间范围或时间点格式
- 时间范围"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
- 时间点"YYYY-MM-DD HH:MM:SS"查询包含该时间点的记录
Returns:
str: 查询结果
"""
try:
# 检查参数
if not keyword and not time_point and not time_range:
return "未指定查询参数需要提供keyword、time_point或time_range之一"
if not keyword and not time_range:
return "未指定查询参数需要提供keyword或time_range之一"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
time_conditions = []
if time_point:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_point)
time_conditions.append(
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
elif time_range:
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_conditions.append(
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
if time_conditions:
# 合并所有时间条件OR关系
time_filter = time_conditions[0]
for condition in time_conditions[1:]:
time_filter = time_filter | condition
if time_range:
# 判断是时间点还是时间范围
if " - " in time_range:
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
query = query.where(time_filter)
# 执行查询
@ -73,7 +69,17 @@ async def query_chat_history(
# 如果有关键词,进一步过滤
if keyword:
keyword_lower = keyword.lower()
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if not keywords_lower:
return "关键词为空"
filtered_records = []
for record in records:
@ -82,25 +88,32 @@ async def query_chat_history(
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析keywords JSON
keywords_list = []
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(keywords_data, list):
keywords_list = [str(k).lower() for k in keywords_data]
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 检查是否包含关键词
if (keyword_lower in theme or
keyword_lower in summary or
keyword_lower in original_text or
any(keyword_lower in k for k in keywords_list)):
# 检查是否包含任意一个关键词OR关系
matched = False
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
matched = True
break
if matched:
filtered_records.append(record)
if not filtered_records:
return f"未找到包含关键词'{keyword}'的聊天记录概述"
keywords_str = "".join(keywords_list)
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
records = filtered_records
@ -146,11 +159,18 @@ async def query_chat_history(
query_desc = []
if keyword:
query_desc.append(f"关键词:{keyword}")
if time_point:
query_desc.append(f"时间点:{time_point}")
# 解析关键词列表用于显示
keywords_list = parse_keywords_string(keyword)
if keywords_list:
keywords_str = "".join(keywords_list)
query_desc.append(f"关键词:{keywords_str}")
else:
query_desc.append(f"关键词:{keyword}")
if time_range:
query_desc.append(f"时间范围:{time_range}")
if " - " in time_range:
query_desc.append(f"时间范围:{time_range}")
else:
query_desc.append(f"时间点:{time_range}")
query_info = "".join(query_desc) if query_desc else "聊天记录概述"
@ -201,19 +221,13 @@ def register_tool():
{
"name": "keyword",
"type": "string",
"description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)",
"required": False
},
{
"name": "time_point",
"type": "string",
"description": "时间点格式YYYY-MM-DD HH:MM:SS可选与time_range二选一。用于查询包含该时间点的聊天记录概述",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)",
"required": False
},
{
"name": "time_range",
"type": "string",
"description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'可选与time_point二选一。用于查询与时间范围有交集的聊天记录概述",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
}
],

View File

@ -12,8 +12,7 @@ logger = get_logger("memory_retrieval_tools")
async def query_jargon(
keyword: str,
chat_id: str,
fuzzy: bool = False,
search_all: bool = False
fuzzy: bool = False
) -> str:
"""根据关键词在jargon库中查询
@ -21,7 +20,6 @@ async def query_jargon(
keyword: 关键词黑话/俚语/缩写
chat_id: 聊天ID
fuzzy: 是否使用模糊搜索默认False精确匹配
search_all: 是否搜索全库不限chat_id默认False仅搜索当前会话或global
Returns:
str: 查询结果
@ -31,11 +29,10 @@ async def query_jargon(
if not content:
return "关键词为空"
# 根据参数执行搜索
search_chat_id = None if search_all else chat_id
# 执行搜索(仅搜索当前会话或全局)
results = search_jargon(
keyword=content,
chat_id=search_chat_id,
chat_id=chat_id,
limit=1,
case_sensitive=False,
fuzzy=fuzzy
@ -46,15 +43,13 @@ async def query_jargon(
translation = result.get("translation", "").strip()
meaning = result.get("meaning", "").strip()
search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局"
output = f"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"
logger.info(f"在jargon库中找到匹配{search_scope}{search_type}: {content}")
output = f'"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"'
logger.info(f"在jargon库中找到匹配当前会话或全局{search_type}: {content}")
return output
# 未命中
search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局"
logger.info(f"在jargon库中未找到匹配{search_scope}{search_type}: {content}")
logger.info(f"在jargon库中未找到匹配当前会话或全局{search_type}: {content}")
return f"未在jargon库中找到'{content}'的解释"
except Exception as e:
@ -66,7 +61,7 @@ def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon可以设置为搜索全库",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。仅搜索当前会话或全局jargon",
parameters=[
{
"name": "keyword",
@ -79,12 +74,6 @@ def register_tool():
"type": "boolean",
"description": "是否使用模糊搜索部分匹配默认False精确匹配。当精确匹配找不到时可以尝试使用模糊搜索。",
"required": False
},
{
"name": "search_all",
"type": "boolean",
"description": "是否搜索全库不限chat_id默认False仅搜索当前会话或global的jargon。当在当前会话中找不到时可以尝试搜索全库。",
"required": False
}
],
execute_func=query_jargon