feat:记忆查询能力提升

pull/1394/head
SengokuCola 2025-11-28 13:33:56 +08:00
parent 76d8659910
commit 804be2fa96
3 changed files with 258 additions and 128 deletions

1
.gitignore vendored
View File

@ -25,6 +25,7 @@ run_na.bat
run_all_in_wt.bat
run.bat
log_debug/
NapCat.Shell.Windows.OneKey
run_amds.bat
run_none.bat
docs-mai/

View File

@ -76,6 +76,8 @@ def init_memory_retrieval_prompt():
- "xxxx和xxx的关系是什么"
- "xxx在某个时间点发生了什么"
问题要说明前因后果和上下文使其全面且精准
输出格式示例需要检索时
```json
{{

View File

@ -1,5 +1,5 @@
"""
根据时间或关键词在chat_history中查询 - 工具实现
根据关键词或参与人在chat_history中查询记忆 - 工具实现
从ChatHistory表的聊天记录概述库中查询
"""
@ -9,177 +9,175 @@ from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.chat.utils.utils import parse_keywords_string
from .tool_registry import register_memory_retrieval_tool
from ..memory_utils import parse_datetime_to_timestamp, parse_time_range
from datetime import datetime
logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
async def search_chat_history(
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None, fuzzy: bool = True
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args:
chat_id: 聊天ID
keyword: 关键词可选支持多个关键词可用空格逗号等分隔
time_range: 时间范围或时间点格式
- 时间范围"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
- 时间点"YYYY-MM-DD HH:MM:SS"查询包含该时间点的记录
participant: 参与人昵称可选
fuzzy: 是否使用模糊匹配模式默认True
- True: 模糊匹配只要包含任意一个关键词即匹配OR关系
- False: 全匹配必须包含所有关键词才匹配AND关系
Returns:
str: 查询结果
str: 查询结果包含记忆idtheme和keywords
"""
try:
# 检查参数
if not keyword and not time_range:
return "未指定查询参数需要提供keyword或time_range之一)"
if not keyword and not participant:
return "未指定查询参数需要提供keyword或participant之一)"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
if time_range:
# 判断是时间点还是时间范围
if " - " in time_range:
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
query = query.where(time_filter)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
# 如果有关键词,进一步过滤
if keyword:
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
filtered_records = []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
for record in records:
participant_matched = True # 如果没有participant条件默认为True
keyword_matched = True # 如果没有keyword条件默认为True
if not keywords_lower:
return "关键词为空"
filtered_records = []
for record in records:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
# 检查参与人匹配
if participant:
participant_matched = False
participants_list = []
if record.participants:
try:
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
participants_data = (
json.loads(record.participants) if isinstance(record.participants, str) else record.participants
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
if isinstance(participants_data, list):
participants_list = [str(p).lower() for p in participants_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 根据匹配模式检查关键词
matched = False
if fuzzy:
# 模糊匹配只要包含任意一个关键词即匹配OR关系
for kw in keywords_lower:
if (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
):
matched = True
break
else:
# 全匹配必须包含所有关键词才匹配AND关系
matched = True
for kw in keywords_lower:
kw_matched = (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
)
if not kw_matched:
matched = False
break
participant_lower = participant.lower().strip()
if participant_lower and any(participant_lower in p for p in participants_list):
participant_matched = True
if matched:
filtered_records.append(record)
# 检查关键词匹配
if keyword:
keyword_matched = False
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
if not filtered_records:
keywords_str = "".join(keywords_list)
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if keywords_lower:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 根据匹配模式检查关键词
if fuzzy:
# 模糊匹配只要包含任意一个关键词即匹配OR关系
for kw in keywords_lower:
if (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
):
keyword_matched = True
break
else:
# 全匹配必须包含所有关键词才匹配AND关系
keyword_matched = True
for kw in keywords_lower:
kw_matched = (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
)
if not kw_matched:
keyword_matched = False
break
# 两者都匹配如果同时有participant和keyword需要两者都匹配如果只有一个条件只需要该条件匹配
matched = participant_matched and keyword_matched
if matched:
filtered_records.append(record)
if not filtered_records:
if keyword and participant:
keywords_str = "".join(parse_keywords_string(keyword) if keyword else [])
return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
elif keyword:
keywords_str = "".join(parse_keywords_string(keyword))
match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
if time_range:
return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
else:
return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"
records = filtered_records
# 如果没有记录(可能是时间范围查询但没有匹配的记录)
if not records:
if time_range:
return "未找到指定时间范围内的聊天记录概述"
return f"未找到{match_mode}'{keywords_str}'的聊天记录"
elif participant:
return f"未找到参与人包含'{participant}'的聊天记录"
else:
return "未找到相关聊天记录概述"
return "未找到相关聊天记录"
# 对即将返回的记录增加使用计数
records_to_use = records[:3]
for record in records_to_use:
try:
ChatHistory.update(count=ChatHistory.count + 1).where(ChatHistory.id == record.id).execute()
record.count = (record.count or 0) + 1
except Exception as update_error:
logger.error(f"更新聊天记录概述计数失败: {update_error}")
# 构建结果文本
# 构建结果文本返回id、theme和keywords
results = []
for record in records_to_use: # 最多返回3条记录
for record in filtered_records[:20]: # 最多返回20条记录
result_parts = []
# 添加记忆ID
result_parts.append(f"记忆ID{record.id}")
# 添加主题
if record.theme:
result_parts.append(f"主题:{record.theme}")
else:
result_parts.append("主题:(无)")
# 添加时间范围
from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}")
# 添加概括优先使用summary如果没有则使用original_text的前200字符
if record.summary:
result_parts.append(f"概括:{record.summary}")
elif record.original_text:
text_preview = record.original_text[:200]
if len(record.original_text) > 200:
text_preview += "..."
result_parts.append(f"内容:{text_preview}")
# 添加关键词
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data])
result_parts.append(f"关键词:{keywords_str}")
else:
result_parts.append("关键词:(无)")
except (json.JSONDecodeError, TypeError, ValueError):
result_parts.append("关键词:(无)")
else:
result_parts.append("关键词:(无)")
results.append("\n".join(result_parts))
if not results:
return "未找到相关聊天记录概述"
return "未找到相关聊天记录"
response_text = "\n\n---\n\n".join(results)
if len(records) > len(records_to_use):
omitted_count = len(records) - len(records_to_use)
response_text += f"\n\n(还有{omitted_count}历史记录已省略)"
if len(filtered_records) > 20:
omitted_count = len(filtered_records) - 20
response_text += f"\n\n(还有{omitted_count}记录已省略可使用记忆ID查询详细信息)"
return response_text
except Exception as e:
@ -187,11 +185,125 @@ async def query_chat_history(
return f"查询失败: {str(e)}"
def _parse_memory_ids(memory_ids) -> list:
    """Turn the raw ``memory_ids`` tool argument into a list of unique ints.

    Accepts a comma-separated string (half-width ``,`` or full-width ``,``),
    a single int, or a list/tuple/set of ids, because tool-call arguments do
    not always arrive as strings. Order of first appearance is preserved.

    Raises:
        ValueError: if any non-empty token is not an integer.
    """
    if isinstance(memory_ids, (list, tuple, set)):
        tokens = [str(item) for item in memory_ids]
    else:
        tokens = str(memory_ids).replace(",", ",").split(",")
    ids = [int(token.strip()) for token in tokens if token.strip()]
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(ids))


def _load_json_list(raw) -> list:
    """Decode a column that stores a JSON list; return ``[]`` on any failure.

    The value may already be a decoded object, in which case it is used as
    is. Anything that is empty, undecodable, or not a list yields ``[]``
    (best-effort: a malformed column must not break the whole lookup).
    """
    if not raw:
        return []
    try:
        data = json.loads(raw) if isinstance(raw, str) else raw
    except (json.JSONDecodeError, TypeError, ValueError):
        return []
    return data if isinstance(data, list) else []


def _format_memory_detail(record) -> str:
    """Render one ChatHistory record as a multi-line detail text."""
    parts = [f"记忆ID{record.id}"]

    if record.theme:
        parts.append(f"主题:{record.theme}")

    start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
    end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
    parts.append(f"时间:{start_str} - {end_str}")

    # Join list items with an explicit separator so adjacent names/keywords
    # stay distinguishable in the output.
    participants = _load_json_list(record.participants)
    if participants:
        parts.append(f"参与人:{'、'.join(str(p) for p in participants)}")

    keywords = _load_json_list(record.keywords)
    if keywords:
        parts.append(f"关键词:{'、'.join(str(k) for k in keywords)}")

    if record.summary:
        parts.append(f"概括:{record.summary}")

    key_points = _load_json_list(record.key_point)
    if key_points:
        key_point_str = "\n".join(f"  - {str(kp)}" for kp in key_points)
        parts.append(f"关键信息点:\n{key_point_str}")

    if record.original_text:
        parts.append(f"原文内容:\n{record.original_text}")

    return "\n".join(parts)


async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
    """Show the full content of one or more memories selected by ID.

    Args:
        chat_id: ID of the chat; only records with this chat_id are returned.
        memory_ids: A single ID ("123") or several comma-separated IDs
            ("1,2,3"). Ints and lists of IDs are tolerated as well.

    Returns:
        str: The detail text of every matching record (newest start_time
        first), separated by a ruler line; or a human-readable message when
        the IDs are invalid, nothing matches, or the lookup fails.
    """
    try:
        try:
            id_list = _parse_memory_ids(memory_ids)
        except ValueError:
            return f"无效的记忆ID格式: {memory_ids}请使用数字ID多个ID用逗号分隔'123''123,456'"

        if not id_list:
            return "未提供有效的记忆ID"

        # Restrict to the current chat so IDs from other chats cannot be read.
        query = ChatHistory.select().where(
            (ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list))
        )
        records = list(query.order_by(ChatHistory.start_time.desc()))

        if not records:
            return f"未找到ID为{id_list}的记忆记录可能ID不存在或不属于当前聊天"

        # Bump the usage counter of every record that is about to be returned.
        # A failed counter update is logged but must not block the answer.
        for record in records:
            try:
                ChatHistory.update(count=ChatHistory.count + 1).where(ChatHistory.id == record.id).execute()
                record.count = (record.count or 0) + 1
            except Exception as update_error:
                logger.error(f"更新聊天记录概述计数失败: {update_error}")

        results = [_format_memory_detail(record) for record in records]

        # Build the separator first: method calls bind tighter than "+", so
        # `"\n\n" + "=" * 50 + "\n\n".join(results)` would join with "\n\n"
        # only and glue one ruler onto the front of the first record.
        separator = "\n\n" + "=" * 50 + "\n\n"
        return separator.join(results)

    except Exception as e:
        logger.error(f"获取记忆详情失败: {e}")
        return f"查询失败: {str(e)}"
def register_tool():
"""注册工具"""
# 注册工具1搜索记忆
register_memory_retrieval_tool(
name="query_chat_history",
description="根据时间或关键词在聊天记录中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
name="search_chat_history",
description="根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。",
parameters=[
{
"name": "keyword",
@ -200,9 +312,9 @@ def register_tool():
"required": False,
},
{
"name": "time_range",
"name": "participant",
"type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"description": "参与人昵称(可选),用于查询包含该参与人的记忆",
"required": False,
},
{
@ -212,5 +324,20 @@ def register_tool():
"required": False,
},
],
execute_func=query_chat_history,
execute_func=search_chat_history,
)
# 注册工具2获取记忆详情
register_memory_retrieval_tool(
name="get_chat_history_detail",
description="根据记忆ID展示某条或某几条记忆的具体内容。包括主题、时间、参与人、关键词、概括、关键信息点和原文内容等详细信息。需要先使用search_chat_history工具获取记忆ID。",
parameters=[
{
"name": "memory_ids",
"type": "string",
"description": "记忆ID可以是单个ID'123'或多个ID用逗号分隔'123,456,789'",
"required": True,
},
],
execute_func=get_chat_history_detail,
)