fix:优化记忆提取和聊天压缩

pull/1356/head
SengokuCola 2025-11-10 12:27:54 +08:00
parent 10cd2474af
commit 71a2a4282b
5 changed files with 212 additions and 139 deletions

View File

@ -269,7 +269,16 @@ class ChatHistorySummarizer:
) )
# 使用LLM压缩聊天内容 # 使用LLM压缩聊天内容
theme, keywords, summary = await self._compress_with_llm(original_text) success, theme, keywords, summary = await self._compress_with_llm(original_text)
if not success:
logger.warning(
f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}"
)
# 清空当前批次,避免重复处理
self.current_batch = None
return
logger.info( logger.info(
f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)}" f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)}"
) )
@ -297,12 +306,12 @@ class ChatHistorySummarizer:
# 出错时也清空批次,避免重复处理 # 出错时也清空批次,避免重复处理
self.current_batch = None self.current_batch = None
async def _compress_with_llm(self, original_text: str) -> tuple[str, List[str], str]: async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]:
""" """
使用LLM压缩聊天内容 使用LLM压缩聊天内容
Returns: Returns:
tuple[str, List[str], str]: (主题, 关键词列表, 概括) tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括)
""" """
prompt = f"""请对以下聊天记录进行概括,提取以下信息: prompt = f"""请对以下聊天记录进行概括,提取以下信息:
@ -353,13 +362,13 @@ class ChatHistorySummarizer:
if isinstance(keywords, str): if isinstance(keywords, str):
keywords = [keywords] keywords = [keywords]
return theme, keywords, summary return True, theme, keywords, summary
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
# 返回默认值 # 返回失败标志和默认值
return "未命名对话", [], "压缩失败,无法生成概括" return False, "未命名对话", [], "压缩失败,无法生成概括"
async def _store_to_database( async def _store_to_database(
self, self,

View File

@ -221,13 +221,16 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
while i < len(text): while i < len(text):
char = text[i] char = text[i]
if char in separators: if char in separators:
# 检查分割条件:如果空格左右都是英文字母,则不分割(仅对空格应用此规则) # 检查分割条件:如果空格左右都是英文字母、数字,或数字和英文之间,则不分割(仅对空格应用此规则)
can_split = True can_split = True
if 0 < i < len(text) - 1: if 0 < i < len(text) - 1:
prev_char = text[i - 1] prev_char = text[i - 1]
next_char = text[i + 1] next_char = text[i + 1]
# 只对空格应用"不分割两个英文之间的空格"规则 # 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
if char == ' ' and is_english_letter(prev_char) and is_english_letter(next_char): if char == ' ':
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum:
can_split = False can_split = False
if can_split: if can_split:

View File

@ -2,6 +2,7 @@ import time
import json import json
import re import re
import random import random
import asyncio
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
@ -430,6 +431,100 @@ def _store_thinking_back(
logger.error(f"存储思考过程失败: {e}") logger.error(f"存储思考过程失败: {e}")
def _get_max_iterations_by_question_count(question_count: int) -> int:
"""根据问题数量获取最大迭代次数
Args:
question_count: 问题数量
Returns:
int: 最大迭代次数
"""
if question_count == 1:
return 5
elif question_count == 2:
return 3
else: # 3个或以上
return 1
async def _process_single_question(
    question: str,
    chat_id: str,
    context: str,
    max_iterations: int
) -> Optional[str]:
    """Answer a single question, consulting the thinking_back cache first.

    The cache lookup (`_query_thinking_back`) yields either a falsy value or a
    `(found_answer, answer)` pair. A cached entry is reused unless a random
    draw asks for a refresh (20% when the entry was marked as found, 40% when
    it was not) or the entry carries no answer text. In every other case the
    question is sent to `_react_agent_solve_question` and the outcome is
    written back through `_store_thinking_back`.

    Args:
        question: The question to look up.
        chat_id: Chat identifier used for both the cache lookup and the query.
        context: Context text stored alongside the new result.
        max_iterations: Iteration cap forwarded to the ReAct agent.

    Returns:
        Optional[str]: A formatted "question / answer" string when an answer
        is available, otherwise None.
    """
    logger.info(f"开始处理问题: {question}")

    # Look for an existing answer in the thinking_back store.
    cache_hit = _query_thinking_back(chat_id, question)
    requery = False

    if cache_hit:
        cached_found_answer, cached_answer = cache_hit

        # The refresh probability depends on whether the cached entry was
        # marked as having found an answer.
        if cached_found_answer:
            requery = random.random() < 0.2
            if requery:
                logger.info(f"found_answer=1触发20%概率重新查询,问题: {question[:50]}...")
        else:
            requery = random.random() < 0.4
            if requery:
                logger.info(f"found_answer=0触发40%概率重新查询,问题: {question[:50]}...")

        if not requery:
            if cached_answer:
                logger.info(f"从thinking_back缓存中获取答案问题: {question[:50]}...")
                return f"问题:{question}\n答案:{cached_answer}"
            # The cached entry has no answer text, so a fresh query is needed.
            requery = True

    # Reaching this point means there was no cache entry at all, or a re-query
    # was requested above (the cached-answer path has already returned).
    if requery:
        logger.info(f"概率触发重新查询使用ReAct Agent查询问题: {question[:50]}...")
    else:
        logger.info(f"未找到缓存答案使用ReAct Agent查询问题: {question[:50]}...")

    found_answer, answer, thinking_steps = await _react_agent_solve_question(
        question=question,
        chat_id=chat_id,
        max_iterations=max_iterations,
        timeout=30.0
    )

    # Persist the outcome regardless of whether an answer was found.
    _store_thinking_back(
        chat_id=chat_id,
        question=question,
        context=context,
        found_answer=found_answer,
        answer=answer,
        thinking_steps=thinking_steps
    )

    if found_answer and answer:
        return f"问题:{question}\n答案:{answer}"

    return None
async def build_memory_retrieval_prompt( async def build_memory_retrieval_prompt(
message: str, message: str,
sender: str, sender: str,
@ -498,68 +593,31 @@ async def build_memory_retrieval_prompt(
logger.info(f"解析到 {len(questions)} 个问题: {questions}") logger.info(f"解析到 {len(questions)} 个问题: {questions}")
# 第二步:对每个问题查询答案 # 第二步:根据问题数量确定最大迭代次数
max_iterations = _get_max_iterations_by_question_count(len(questions))
logger.info(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations}")
# 并行处理所有问题
question_tasks = [
_process_single_question(
question=question,
chat_id=chat_id,
context=message,
max_iterations=max_iterations
)
for question in questions
]
# 并行执行所有查询任务
results = await asyncio.gather(*question_tasks, return_exceptions=True)
# 收集所有有效结果
all_results = [] all_results = []
for question in questions: for i, result in enumerate(results):
logger.info(f"开始处理问题: {question}") if isinstance(result, Exception):
logger.error(f"处理问题 '{questions[i]}' 时发生异常: {result}")
# 先检查thinking_back数据库中是否有现成答案 elif result is not None:
cached_result = _query_thinking_back(chat_id, question) all_results.append(result)
should_requery = False
if cached_result:
cached_found_answer, cached_answer = cached_result
# 根据found_answer的值决定是否重新查询
if cached_found_answer: # found_answer == 1 (True)
# found_answer == 120%概率重新查询
if random.random() < 0.2:
should_requery = True
logger.info(f"found_answer=1触发20%概率重新查询,问题: {question[:50]}...")
else:
# 使用缓存答案
if cached_answer:
logger.info(f"从thinking_back缓存中获取答案found_answer=1问题: {question[:50]}...")
all_results.append(f"问题:{question}\n答案:{cached_answer}")
continue # 跳过ReAct Agent查询
else: # found_answer == 0 (False)
# found_answer == 040%概率重新查询
if random.random() < 0.4:
should_requery = True
logger.info(f"found_answer=0触发40%概率重新查询,问题: {question[:50]}...")
else:
# 使用缓存答案即使found_answer=0也可能有部分答案
if cached_answer:
logger.info(f"从thinking_back缓存中获取答案found_answer=0问题: {question[:50]}...")
all_results.append(f"问题:{question}\n答案:{cached_answer}")
continue # 跳过ReAct Agent查询
# 如果没有缓存答案或需要重新查询使用ReAct Agent查询
if not cached_result or should_requery:
if should_requery:
logger.info(f"概率触发重新查询使用ReAct Agent查询问题: {question[:50]}...")
else:
logger.info(f"未找到缓存答案使用ReAct Agent查询问题: {question[:50]}...")
found_answer, answer, thinking_steps = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
max_iterations=5,
timeout=30.0
)
# 存储到数据库
_store_thinking_back(
chat_id=chat_id,
question=question,
context=message, # 只存储前500字符作为上下文
found_answer=found_answer,
answer=answer,
thinking_steps=thinking_steps
)
if found_answer and answer:
all_results.append(f"问题:{question}\n答案:{answer}")
end_time = time.time() end_time = time.time()

View File

@ -9,6 +9,7 @@ from src.common.logger import get_logger
from src.config.config import model_config from src.config.config import model_config
from src.common.database.database_model import ChatHistory from src.common.database.database_model import ChatHistory
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import parse_keywords_string
from .tool_registry import register_memory_retrieval_tool from .tool_registry import register_memory_retrieval_tool
from .tool_utils import parse_datetime_to_timestamp, parse_time_range from .tool_utils import parse_datetime_to_timestamp, parse_time_range
@ -18,51 +19,46 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history( async def query_chat_history(
chat_id: str, chat_id: str,
keyword: Optional[str] = None, keyword: Optional[str] = None,
time_point: Optional[str] = None,
time_range: Optional[str] = None time_range: Optional[str] = None
) -> str: ) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述 """根据时间或关键词在chat_history表中查询聊天记录概述
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
keyword: 关键词可选 keyword: 关键词可选支持多个关键词可用空格逗号等分隔
time_point: 时间点格式YYYY-MM-DD HH:MM:SS可选 time_range: 时间范围或时间点格式
time_range: 时间范围格式"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"可选 - 时间范围"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
- 时间点"YYYY-MM-DD HH:MM:SS"查询包含该时间点的记录
Returns: Returns:
str: 查询结果 str: 查询结果
""" """
try: try:
# 检查参数 # 检查参数
if not keyword and not time_point and not time_range: if not keyword and not time_range:
return "未指定查询参数需要提供keyword、time_point或time_range之一" return "未指定查询参数需要提供keyword或time_range之一"
# 构建查询条件 # 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件 # 时间过滤条件
time_conditions = [] if time_range:
if time_point: # 判断是时间点还是时间范围
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time if " - " in time_range:
target_timestamp = parse_datetime_to_timestamp(time_point)
time_conditions.append(
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
elif time_range:
# 时间范围:查询与时间范围有交集的记录 # 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range) start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp # 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_conditions.append( time_filter = (
(ChatHistory.start_time < end_timestamp) & (ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp) (ChatHistory.end_time > start_timestamp)
) )
else:
if time_conditions: # 时间点查询包含该时间点的记录start_time <= time_point <= end_time
# 合并所有时间条件OR关系 target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = time_conditions[0] time_filter = (
for condition in time_conditions[1:]: (ChatHistory.start_time <= target_timestamp) &
time_filter = time_filter | condition (ChatHistory.end_time >= target_timestamp)
)
query = query.where(time_filter) query = query.where(time_filter)
# 执行查询 # 执行查询
@ -73,7 +69,17 @@ async def query_chat_history(
# 如果有关键词,进一步过滤 # 如果有关键词,进一步过滤
if keyword: if keyword:
keyword_lower = keyword.lower() # 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if not keywords_lower:
return "关键词为空"
filtered_records = [] filtered_records = []
for record in records: for record in records:
@ -82,25 +88,32 @@ async def query_chat_history(
summary = (record.summary or "").lower() summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower() original_text = (record.original_text or "").lower()
# 解析keywords JSON # 解析record中的keywords JSON
keywords_list = [] record_keywords_list = []
if record.keywords: if record.keywords:
try: try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
keywords_list = [str(k).lower() for k in keywords_data] record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError): except (json.JSONDecodeError, TypeError, ValueError):
pass pass
# 检查是否包含关键词 # 检查是否包含任意一个关键词OR关系
if (keyword_lower in theme or matched = False
keyword_lower in summary or for kw in keywords_lower:
keyword_lower in original_text or if (kw in theme or
any(keyword_lower in k for k in keywords_list)): kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
matched = True
break
if matched:
filtered_records.append(record) filtered_records.append(record)
if not filtered_records: if not filtered_records:
return f"未找到包含关键词'{keyword}'的聊天记录概述" keywords_str = "".join(keywords_list)
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
records = filtered_records records = filtered_records
@ -146,11 +159,18 @@ async def query_chat_history(
query_desc = [] query_desc = []
if keyword: if keyword:
# 解析关键词列表用于显示
keywords_list = parse_keywords_string(keyword)
if keywords_list:
keywords_str = "".join(keywords_list)
query_desc.append(f"关键词:{keywords_str}")
else:
query_desc.append(f"关键词:{keyword}") query_desc.append(f"关键词:{keyword}")
if time_point:
query_desc.append(f"时间点:{time_point}")
if time_range: if time_range:
if " - " in time_range:
query_desc.append(f"时间范围:{time_range}") query_desc.append(f"时间范围:{time_range}")
else:
query_desc.append(f"时间点:{time_range}")
query_info = "".join(query_desc) if query_desc else "聊天记录概述" query_info = "".join(query_desc) if query_desc else "聊天记录概述"
@ -201,19 +221,13 @@ def register_tool():
{ {
"name": "keyword", "name": "keyword",
"type": "string", "type": "string",
"description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)", "description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)",
"required": False
},
{
"name": "time_point",
"type": "string",
"description": "时间点格式YYYY-MM-DD HH:MM:SS可选与time_range二选一。用于查询包含该时间点的聊天记录概述",
"required": False "required": False
}, },
{ {
"name": "time_range", "name": "time_range",
"type": "string", "type": "string",
"description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'可选与time_point二选一。用于查询与时间范围有交集的聊天记录概述", "description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False "required": False
} }
], ],

View File

@ -12,8 +12,7 @@ logger = get_logger("memory_retrieval_tools")
async def query_jargon( async def query_jargon(
keyword: str, keyword: str,
chat_id: str, chat_id: str,
fuzzy: bool = False, fuzzy: bool = False
search_all: bool = False
) -> str: ) -> str:
"""根据关键词在jargon库中查询 """根据关键词在jargon库中查询
@ -21,7 +20,6 @@ async def query_jargon(
keyword: 关键词黑话/俚语/缩写 keyword: 关键词黑话/俚语/缩写
chat_id: 聊天ID chat_id: 聊天ID
fuzzy: 是否使用模糊搜索默认False精确匹配 fuzzy: 是否使用模糊搜索默认False精确匹配
search_all: 是否搜索全库不限chat_id默认False仅搜索当前会话或global
Returns: Returns:
str: 查询结果 str: 查询结果
@ -31,11 +29,10 @@ async def query_jargon(
if not content: if not content:
return "关键词为空" return "关键词为空"
# 根据参数执行搜索 # 执行搜索(仅搜索当前会话或全局)
search_chat_id = None if search_all else chat_id
results = search_jargon( results = search_jargon(
keyword=content, keyword=content,
chat_id=search_chat_id, chat_id=chat_id,
limit=1, limit=1,
case_sensitive=False, case_sensitive=False,
fuzzy=fuzzy fuzzy=fuzzy
@ -46,15 +43,13 @@ async def query_jargon(
translation = result.get("translation", "").strip() translation = result.get("translation", "").strip()
meaning = result.get("meaning", "").strip() meaning = result.get("meaning", "").strip()
search_type = "模糊搜索" if fuzzy else "精确匹配" search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局" output = f'"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"'
output = f"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}" logger.info(f"在jargon库中找到匹配当前会话或全局{search_type}: {content}")
logger.info(f"在jargon库中找到匹配{search_scope}{search_type}: {content}")
return output return output
# 未命中 # 未命中
search_type = "模糊搜索" if fuzzy else "精确匹配" search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局" logger.info(f"在jargon库中未找到匹配当前会话或全局{search_type}: {content}")
logger.info(f"在jargon库中未找到匹配{search_scope}{search_type}: {content}")
return f"未在jargon库中找到'{content}'的解释" return f"未在jargon库中找到'{content}'的解释"
except Exception as e: except Exception as e:
@ -66,7 +61,7 @@ def register_tool():
"""注册工具""" """注册工具"""
register_memory_retrieval_tool( register_memory_retrieval_tool(
name="query_jargon", name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon可以设置为搜索全库", description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。仅搜索当前会话或全局jargon",
parameters=[ parameters=[
{ {
"name": "keyword", "name": "keyword",
@ -79,12 +74,6 @@ def register_tool():
"type": "boolean", "type": "boolean",
"description": "是否使用模糊搜索部分匹配默认False精确匹配。当精确匹配找不到时可以尝试使用模糊搜索。", "description": "是否使用模糊搜索部分匹配默认False精确匹配。当精确匹配找不到时可以尝试使用模糊搜索。",
"required": False "required": False
},
{
"name": "search_all",
"type": "boolean",
"description": "是否搜索全库不限chat_id默认False仅搜索当前会话或global的jargon。当在当前会话中找不到时可以尝试搜索全库。",
"required": False
} }
], ],
execute_func=query_jargon execute_func=query_jargon