mirror of https://github.com/Mai-with-u/MaiBot.git
feat: 增加记忆提取能力 (add memory extraction capability)
parent 157b1ed540
commit 0baa73aaf5

@@ -26,15 +26,15 @@ class ExpressionReflector:
             bool: 是否执行了提问
         """
         try:
-            logger.info(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
+            logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")

             if not global_config.expression.reflect:
-                logger.info(f"[Expression Reflection] 表达反思功能未启用,跳过")
+                logger.debug(f"[Expression Reflection] 表达反思功能未启用,跳过")
                 return False

             operator_config = global_config.expression.reflect_operator_id
             if not operator_config:
-                logger.info(f"[Expression Reflection] Operator ID 未配置,跳过")
+                logger.debug(f"[Expression Reflection] Operator ID 未配置,跳过")
                 return False

             # 检查是否在允许列表中

@@ -16,20 +16,17 @@ logger = get_logger("jargon")
 def _init_explainer_prompts() -> None:
     """初始化黑话解释器相关的prompt"""
     # Prompt:概括黑话解释结果
-    summarize_prompt_str = """
-**上下文聊天内容**
+    summarize_prompt_str = """上下文聊天内容:
 {chat_context}

-**提取到的黑话及其含义**
+在上下文中提取到的黑话及其含义:
 {jargon_explanations}

 请根据上述信息,对黑话解释进行概括和整理。
 - 如果上下文中有黑话出现,请简要说明这些黑话在上下文中的使用情况
-- 将黑话解释整理成简洁、易读的格式
 - 如果某个黑话在上下文中没有出现,可以省略
+- 将所有黑话解释整理成简洁、易读的一段话
 - 输出格式要自然,适合作为回复参考信息

-请输出概括后的黑话解释(直接输出文本,不要使用JSON格式):
+请输出概括后的黑话解释(直接输出一段平文本,不要标题,无特殊格式或markdown格式,不要使用JSON格式):
 """
     Prompt(summarize_prompt_str, "jargon_explainer_summarize_prompt")

@@ -1,9 +1,8 @@
 import time
 import json
 import re
-import random
 import asyncio
-from typing import List, Dict, Any, Optional, Tuple
+from typing import List, Dict, Any, Optional, Tuple, Set
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager

@@ -11,9 +10,8 @@ from src.plugin_system.apis import llm_api
 from src.common.database.database_model import ThinkingBack, Jargon
 from json_repair import repair_json
 from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
-from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
 from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
-from src.jargon.jargon_utils import parse_chat_id_list, chat_id_list_contains, contains_bot_self_name
+from src.jargon.jargon_utils import parse_chat_id_list, chat_id_list_contains

 logger = get_logger("memory_retrieval")

@@ -104,11 +102,10 @@ def init_memory_retrieval_prompt():

 **重要限制:**
 - 最大查询轮数:{max_iterations}轮(当前第{current_iteration}轮,剩余{remaining_iterations}轮)
 - 必须尽快得出答案,避免不必要的查询
 - 思考要简短,直接切入要点
 - 必须严格使用检索到的信息回答问题,不要编造信息

-当前问题:{question}
+当前需要解答的问题:{question}
 已收集的信息:
 {collected_info}

@@ -118,18 +115,17 @@ def init_memory_retrieval_prompt():
 - 当前信息是否足够回答问题?
 - **如果信息足够且能找到明确答案**,在思考中直接给出答案,格式为:found_answer(answer="你的答案内容")
 - **如果需要尝试搜集更多信息,进一步调用工具,进入第二步行动环节
-- **如果已有信息不足或无法找到答案**,在思考中给出:not_enough_info(reason="信息不足或无法找到答案的原因")
+- **如果已有信息不足或无法找到答案,决定结束查询**,在思考中给出:not_enough_info(reason="结束查询的原因")

 **第二步:行动(Action)**
-- 如果涉及过往事件,可以使用聊天记录查询工具查询过往事件
-- 如果涉及概念,可以用jargon查询,或根据关键词检索聊天记录
+- 如果涉及过往事件,或者查询某个过去可能提到过的概念,或者某段时间发生的事件。可以使用聊天记录查询工具查询过往事件
 - 如果涉及人物,可以使用人物信息查询工具查询人物信息
-- 如果不确定查询类别,也可以使用lpmm知识库查询
-- 如果信息不足且需要继续查询,说明最需要查询什么,并输出为纯文本说明,然后调用相应工具查询(可并行调用多个工具)
+- 如果没有可靠信息,且查询时间充足,或者不确定查询类别,也可以使用lpmm知识库查询,作为辅助信息
+- 如果信息不足需要使用tool,说明需要查询什么,并输出为纯文本说明,然后调用相应工具查询(可并行调用多个工具)

 **重要规则:**
 - **只有在检索到明确、有关的信息并得出答案时,才使用found_answer**
-- **如果信息不足、无法确定、找不到相关信息,必须使用not_enough_info,不要使用found_answer**
+- **如果信息不足、无法确定、找不到相关信息导致的无法回答问题,决定结束查询,必须使用not_enough_info,不要使用found_answer**
 - 答案必须在思考中给出,格式为 found_answer(answer="...") 或 not_enough_info(reason="...")
 """,
     name="memory_retrieval_react_prompt_head",

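
(补充示意)上面的 prompt 约定模型在思考文本中用 found_answer(answer="...") 或 not_enough_info(reason="...") 给出结论,后文的 ReAct Agent 会从 LLM 输出中检测这两种标记。下面是一个假设性的解析示例,仅用于说明这种格式如何被提取,并非本仓库的实际实现:

```python
import re

def parse_react_verdict(text: str):
    """从模型输出中提取 found_answer / not_enough_info 标记(假设性示意,非仓库实际代码)"""
    m = re.search(r'found_answer\(answer="(.*?)"\)', text, re.S)
    if m:
        return True, m.group(1)   # 找到答案
    m = re.search(r'not_enough_info\(reason="(.*?)"\)', text, re.S)
    if m:
        return False, m.group(1)  # 信息不足,结束查询
    return None                   # 未给出结论,继续下一轮迭代
```
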
@@ -278,6 +274,7 @@ async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:


 def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
     """直接在聊天文本中匹配已知的jargon,返回出现过的黑话列表"""
+    print(chat_text)
     if not chat_text or not chat_text.strip():
         return []

@@ -297,8 +294,6 @@ def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
            if not content:
                continue

-            if contains_bot_self_name(content):
-                continue

            if not global_config.jargon.all_global and not jargon.is_global:
                chat_id_list = parse_chat_id_list(jargon.chat_id)

@@ -325,7 +320,12 @@ def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:


 async def _react_agent_solve_question(
-    question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0, initial_info: str = ""
+    question: str,
+    chat_id: str,
+    max_iterations: int = 5,
+    timeout: float = 30.0,
+    initial_info: str = "",
+    initial_jargon_concepts: Optional[List[str]] = None,
 ) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
     """使用ReAct架构的Agent来解决问题

@@ -335,12 +335,19 @@ async def _react_agent_solve_question(
         max_iterations: 最大迭代次数
         timeout: 超时时间(秒)
         initial_info: 初始信息(如概念检索结果),将作为collected_info的初始值
+        initial_jargon_concepts: 预先已解析过的黑话列表,避免重复解释

     Returns:
         Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
     """
     start_time = time.time()
     collected_info = initial_info if initial_info else ""
+    seen_jargon_concepts: Set[str] = set()
+    if initial_jargon_concepts:
+        for concept in initial_jargon_concepts:
+            concept = (concept or "").strip()
+            if concept:
+                seen_jargon_concepts.add(concept)
     thinking_steps = []
     is_timeout = False
     conversation_messages: List[Message] = []

@@ -577,7 +584,7 @@ async def _react_agent_solve_question(
                 step["observations"] = ["从LLM输出内容中检测到found_answer"]
                 thinking_steps.append(step)
                 logger.info(
-                    f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}..."
+                    f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}"
                 )
                 return True, found_answer_content, thinking_steps, False

@@ -588,7 +595,7 @@ async def _react_agent_solve_question(
                 step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
                 thinking_steps.append(step)
                 logger.info(
-                    f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}..."
+                    f"ReAct Agent 第 {iteration + 1} 次迭代 无法找到关于问题{question}的答案,原因: {not_enough_info_reason}"
                 )
                 return False, not_enough_info_reason, thinking_steps, False

@@ -617,7 +624,7 @@ async def _react_agent_solve_question(
             if response and response.strip():
                 # 如果响应不为空,记录思考过程,继续下一轮迭代
                 step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
-                logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response[:100]}...")
+                logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
                 # 继续下一轮迭代,让LLM有机会在思考中给出found_answer或继续查询
                 collected_info += f"思考: {response}"
                 thinking_steps.append(step)

@@ -681,14 +688,29 @@ async def _react_agent_solve_question(
                        logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")

                    observation_text = observation if isinstance(observation, str) else str(observation)
+                    stripped_observation = observation_text.strip()
                    step["observations"].append(observation_text)
                    collected_info += f"\n{observation_text}\n"
-                    if observation_text.strip():
+                    if stripped_observation:
                        tool_builder = MessageBuilder()
                        tool_builder.set_role(RoleType.Tool)
                        tool_builder.add_text_content(observation_text)
                        tool_builder.add_tool_call(tool_call_item.call_id)
                        conversation_messages.append(tool_builder.build())
+                    jargon_concepts = _match_jargon_from_text(stripped_observation, chat_id)
+                    if jargon_concepts:
+                        jargon_info = ""
+                        new_concepts = []
+                        for concept in jargon_concepts:
+                            normalized_concept = concept.strip()
+                            if normalized_concept and normalized_concept not in seen_jargon_concepts:
+                                new_concepts.append(normalized_concept)
+                                seen_jargon_concepts.add(normalized_concept)
+                        if new_concepts:
+                            jargon_info = await _retrieve_concepts_with_jargon(new_concepts, chat_id)
+                            if jargon_info:
+                                collected_info += f"\n{jargon_info}\n"
+                                logger.info(f"工具输出触发黑话解析: {new_concepts}")
                    # logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行结果: {observation_text}")

                    thinking_steps.append(step)

@@ -758,83 +780,6 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
         return ""


-def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> List[str]:
-    """获取最近一段时间内缓存的记忆(只返回找到答案的记录)
-
-    Args:
-        chat_id: 聊天ID
-        time_window_seconds: 时间窗口(秒),默认300秒(5分钟)
-
-    Returns:
-        List[str]: 格式化的记忆列表,每个元素格式为 "问题:xxx\n答案:xxx"
-    """
-    try:
-        current_time = time.time()
-        start_time = current_time - time_window_seconds
-
-        # 查询最近时间窗口内找到答案的记录,按更新时间倒序
-        records = (
-            ThinkingBack.select()
-            .where(
-                (ThinkingBack.chat_id == chat_id)
-                & (ThinkingBack.update_time >= start_time)
-                & (ThinkingBack.found_answer == 1)
-            )
-            .order_by(ThinkingBack.update_time.desc())
-            .limit(5)  # 最多返回5条最近的记录
-        )
-
-        if not records.exists():
-            return []
-
-        cached_memories = []
-        for record in records:
-            if record.answer:
-                cached_memories.append(f"问题:{record.question}\n答案:{record.answer}")
-
-        return cached_memories
-
-    except Exception as e:
-        logger.error(f"获取缓存记忆失败: {e}")
-        return []
-
-
-def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, str]]:
-    """从thinking_back数据库中查询是否有现成的答案
-
-    Args:
-        chat_id: 聊天ID
-        question: 问题
-
-    Returns:
-        Optional[Tuple[bool, str]]: 如果找到记录,返回(found_answer, answer),否则返回None
-            found_answer: 是否找到答案(True表示found_answer=1,False表示found_answer=0)
-            answer: 答案内容
-    """
-    try:
-        # 查询相同chat_id和问题的所有记录(包括found_answer为0和1的)
-        # 按更新时间倒序,获取最新的记录
-        records = (
-            ThinkingBack.select()
-            .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
-            .order_by(ThinkingBack.update_time.desc())
-            .limit(1)
-        )
-
-        if records.exists():
-            record = records.get()
-            found_answer = bool(record.found_answer)
-            answer = record.answer or ""
-            logger.info(f"在thinking_back中找到记录,问题: {question[:50]}...,found_answer: {found_answer}")
-            return found_answer, answer
-
-        return None
-
-    except Exception as e:
-        logger.error(f"查询thinking_back失败: {e}")
-        return None
-
-
 def _store_thinking_back(
     chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
 ) -> None:

@@ -886,14 +831,21 @@ def _store_thinking_back(
         logger.error(f"存储思考过程失败: {e}")


-async def _process_single_question(question: str, chat_id: str, context: str, initial_info: str = "") -> Optional[str]:
-    """处理单个问题的查询(包含缓存检查逻辑)
+async def _process_single_question(
+    question: str,
+    chat_id: str,
+    context: str,
+    initial_info: str = "",
+    initial_jargon_concepts: Optional[List[str]] = None,
+) -> Optional[str]:
+    """处理单个问题的查询

     Args:
         question: 要查询的问题
         chat_id: 聊天ID
         context: 上下文信息
         initial_info: 初始信息(如概念检索结果),将传递给ReAct Agent
+        initial_jargon_concepts: 已经处理过的黑话概念列表,用于ReAct阶段的去重

     Returns:
         Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None

@@ -904,74 +856,33 @@ async def _process_single_question(
     question_initial_info = initial_info or ""

-    # 预先进行一次LPMM知识库查询,作为后续ReAct Agent的辅助信息
-    if global_config.lpmm_knowledge.enable:
-        try:
-            lpmm_result = await query_lpmm_knowledge(question, limit=2)
-            if lpmm_result and lpmm_result.startswith("你从LPMM知识库中找到"):
-                if question_initial_info:
-                    question_initial_info += "\n"
-                question_initial_info += f"【LPMM知识库预查询】\n{lpmm_result}"
-                logger.info(f"LPMM预查询命中,问题: {question[:50]}...")
-            else:
-                logger.info(f"LPMM预查询未命中或未找到信息,问题: {question[:50]}...")
-        except Exception as e:
-            logger.error(f"LPMM预查询失败,问题: {question[:50]}... 错误: {e}")
-
-    # 先检查thinking_back数据库中是否有现成答案
-    cached_result = _query_thinking_back(chat_id, question)
-    should_requery = False
-
-    if cached_result:
-        cached_found_answer, cached_answer = cached_result
-
-        if cached_found_answer:  # found_answer == 1 (True)
-            # found_answer == 1:20%概率重新查询
-            if random.random() < 0.5:
-                should_requery = True
-                logger.info(f"found_answer=1,触发20%概率重新查询,问题: {question[:50]}...")
-
-            if not should_requery and cached_answer:
-                logger.info(f"从thinking_back缓存中获取答案,问题: {question[:50]}...")
-                return f"问题:{question}\n答案:{cached_answer}"
-            elif not cached_answer:
-                should_requery = True
-                logger.info(f"found_answer=1 但缓存答案为空,重新查询,问题: {question[:50]}...")
-        else:
-            # found_answer == 0:不使用缓存,直接重新查询
-            should_requery = True
-            logger.info(f"thinking_back存在但未找到答案,忽略缓存重新查询,问题: {question[:50]}...")
-
-    # 如果没有缓存答案或需要重新查询,使用ReAct Agent查询
-    if not cached_result or should_requery:
-        if should_requery:
-            logger.info(f"概率触发重新查询,使用ReAct Agent查询,问题: {question[:50]}...")
-        else:
-            logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...")
-
-        found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
-            question=question,
-            chat_id=chat_id,
-            max_iterations=global_config.memory.max_agent_iterations,
-            timeout=120.0,
-            initial_info=question_initial_info,
-        )
-
-        # 存储到数据库(超时时不存储)
-        if not is_timeout:
-            _store_thinking_back(
-                chat_id=chat_id,
-                question=question,
-                context=context,
-                found_answer=found_answer,
-                answer=answer,
-                thinking_steps=thinking_steps,
-            )
-        else:
-            logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
-
-        if found_answer and answer:
-            return f"问题:{question}\n答案:{answer}"
+    # 直接使用ReAct Agent查询(不再从thinking_back获取缓存)
+    logger.info(f"使用ReAct Agent查询,问题: {question[:50]}...")
+
+    found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
+        question=question,
+        chat_id=chat_id,
+        max_iterations=global_config.memory.max_agent_iterations,
+        timeout=120.0,
+        initial_info=question_initial_info,
+        initial_jargon_concepts=initial_jargon_concepts,
+    )
+
+    # 存储查询历史到数据库(超时时不存储)
+    if not is_timeout:
+        _store_thinking_back(
+            chat_id=chat_id,
+            question=question,
+            context=context,
+            found_answer=found_answer,
+            answer=answer,
+            thinking_steps=thinking_steps,
+        )
+    else:
+        logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
+
+    if found_answer and answer:
+        return f"问题:{question}\n答案:{answer}"

     return None

@@ -1036,19 +947,20 @@ async def build_memory_retrieval_prompt(

         # 解析概念列表和问题列表
         _, questions = _parse_questions_json(response)
-        logger.info(f"解析到 {len(questions)} 个问题: {questions}")
+        if questions:
+            logger.info(f"解析到 {len(questions)} 个问题: {questions}")

         # 使用匹配逻辑自动识别聊天中的黑话概念
         concepts = _match_jargon_from_text(message, chat_id)
         if concepts:
             logger.info(f"黑话匹配命中 {len(concepts)} 个概念: {concepts}")
         else:
-            logger.info("黑话匹配未命中任何概念")
+            logger.debug("黑话匹配未命中任何概念")

         # 对匹配到的概念进行jargon检索,作为初始信息
         initial_info = ""
         if concepts:
-            logger.info(f"开始对 {len(concepts)} 个概念进行jargon检索")
+            # logger.info(f"开始对 {len(concepts)} 个概念进行jargon检索")
             concept_info = await _retrieve_concepts_with_jargon(concepts, chat_id)
             if concept_info:
                 initial_info += concept_info

@@ -1056,25 +968,11 @@ async def build_memory_retrieval_prompt(
             else:
                 logger.info("概念检索未找到任何结果")

-        # 获取缓存的记忆(与question时使用相同的时间窗口和数量限制)
-        cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
-
         if not questions:
-            logger.debug("模型认为不需要检索记忆或解析失败")
-            # 即使没有当次查询,也返回缓存的记忆和概念检索结果
-            all_results = []
-            if initial_info:
-                all_results.append(initial_info.strip())
-            if cached_memories:
-                all_results.extend(cached_memories)
-
-            if all_results:
-                retrieved_memory = "\n\n".join(all_results)
-                end_time = time.time()
-                logger.info(f"无当次查询,返回缓存记忆和概念检索结果,耗时: {(end_time - start_time):.3f}秒")
-                return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
-            else:
-                return ""
+            logger.debug("模型认为不需要检索记忆或解析失败,不返回任何查询结果")
+            end_time = time.time()
+            logger.info(f"无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}秒")
+            return ""

         # 第二步:并行处理所有问题(使用配置的最大迭代次数/120秒超时)
         max_iterations = global_config.memory.max_agent_iterations

@@ -1082,7 +980,13 @@ async def build_memory_retrieval_prompt(

         # 并行处理所有问题,将概念检索结果作为初始信息传递
         question_tasks = [
-            _process_single_question(question=question, chat_id=chat_id, context=message, initial_info=initial_info)
+            _process_single_question(
+                question=question,
+                chat_id=chat_id,
+                context=message,
+                initial_info=initial_info,
+                initial_jargon_concepts=concepts,
+            )
             for question in questions
         ]

@@ -1090,37 +994,23 @@ async def build_memory_retrieval_prompt(
         results = await asyncio.gather(*question_tasks, return_exceptions=True)

         # 收集所有有效结果
-        all_results = []
-        current_questions = set()  # 用于去重,避免缓存和当次查询重复
+        question_results: List[str] = []
         for i, result in enumerate(results):
             if isinstance(result, Exception):
                 logger.error(f"处理问题 '{questions[i]}' 时发生异常: {result}")
             elif result is not None:
-                all_results.append(result)
-                # 提取问题用于去重
-                if result.startswith("问题:"):
-                    question = result.split("\n")[0].replace("问题:", "").strip()
-                    current_questions.add(question)
-
-        # 将缓存的记忆添加到结果中(排除当次查询已包含的问题,避免重复)
-        for cached_memory in cached_memories:
-            if cached_memory.startswith("问题:"):
-                question = cached_memory.split("\n")[0].replace("问题:", "").strip()
-                # 只有当次查询中没有相同问题时,才添加缓存记忆
-                if question not in current_questions:
-                    all_results.append(cached_memory)
-                    logger.debug(f"添加缓存记忆: {question[:50]}...")
+                question_results.append(result)

         end_time = time.time()

-        if all_results:
-            retrieved_memory = "\n\n".join(all_results)
+        if question_results:
+            retrieved_memory = "\n\n".join(question_results)
             logger.info(
-                f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
+                f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆"
             )
             return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
         else:
-            logger.debug("所有问题均未找到答案,且无缓存记忆")
+            logger.debug("所有问题均未找到答案")
             return ""

     except Exception as e:

@@ -1,161 +0,0 @@
# 记忆检索工具模块

这个模块提供了统一的工具注册和管理系统,用于记忆检索功能。

## 目录结构

```
retrieval_tools/
├── __init__.py              # 模块导出
├── tool_registry.py         # 工具注册系统
├── tool_utils.py            # 工具函数库(共用函数)
├── query_jargon.py          # 查询jargon工具
├── query_chat_history.py    # 查询聊天历史工具
├── query_lpmm_knowledge.py  # 查询LPMM知识库工具
└── README.md                # 本文件
```

## 模块说明

### `tool_registry.py`
包含工具注册系统的核心类:
- `MemoryRetrievalTool`: 工具基类
- `MemoryRetrievalToolRegistry`: 工具注册器
- `register_memory_retrieval_tool()`: 便捷注册函数
- `get_tool_registry()`: 获取注册器实例

### `tool_utils.py`
包含所有工具共用的工具函数:
- `parse_datetime_to_timestamp()`: 解析时间字符串为时间戳
- `parse_time_range()`: 解析时间范围字符串

### 工具文件
每个工具都有独立的文件:
- `query_jargon.py`: 根据关键词在jargon库中查询
- `query_chat_history.py`: 根据时间或关键词在chat_history中查询(支持查询时间点事件、时间范围事件、关键词搜索)

## 如何添加新工具

1. 创建新的工具文件,例如 `query_new_tool.py`:

```python
"""
新工具 - 工具实现
"""

from src.common.logger import get_logger
from .tool_registry import register_memory_retrieval_tool
from .tool_utils import parse_datetime_to_timestamp  # 如果需要使用工具函数

logger = get_logger("memory_retrieval_tools")


async def query_new_tool(param1: str, param2: str, chat_id: str) -> str:
    """新工具的实现

    Args:
        param1: 参数1
        param2: 参数2
        chat_id: 聊天ID

    Returns:
        str: 查询结果
    """
    try:
        # 实现逻辑
        return "结果"
    except Exception as e:
        logger.error(f"新工具执行失败: {e}")
        return f"查询失败: {str(e)}"


def register_tool():
    """注册工具"""
    register_memory_retrieval_tool(
        name="query_new_tool",
        description="新工具的描述",
        parameters=[
            {
                "name": "param1",
                "type": "string",
                "description": "参数1的描述",
                "required": True
            },
            {
                "name": "param2",
                "type": "string",
                "description": "参数2的描述",
                "required": True
            }
        ],
        execute_func=query_new_tool
    )
```

2. 在 `__init__.py` 中导入并注册新工具:

```python
from .query_new_tool import register_tool as register_query_new_tool

def init_all_tools():
    """初始化并注册所有记忆检索工具"""
    register_query_jargon()
    register_query_chat_history()
    register_query_new_tool()  # 添加新工具
```

3. 工具会自动:
- 出现在 ReAct Agent 的 prompt 中
- 在动作类型列表中可用
- 被 ReAct Agent 自动调用

## 使用示例

```python
from src.memory_system.retrieval_tools import init_all_tools, get_tool_registry

# 初始化所有工具
init_all_tools()

# 获取工具注册器
registry = get_tool_registry()

# 获取特定工具
tool = registry.get_tool("query_chat_history")

# 执行工具(查询时间点事件)
result = await tool.execute(time_point="2025-01-15 14:30:00", chat_id="chat123")

# 或者查询关键词
result = await tool.execute(keyword="小丑AI", chat_id="chat123")

# 或者查询时间范围
result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:00", chat_id="chat123")
```

## 现有工具说明

### query_jargon
根据关键词在jargon库中查询黑话/俚语/缩写的含义
- 参数:`keyword` (必填) - 关键词

### query_chat_history
根据时间或关键词在chat_history中查询相关聊天记录。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息
- 参数:
  - `keyword` (可选) - 关键词,用于搜索消息内容
  - `time_point` (可选) - 时间点,格式:YYYY-MM-DD HH:MM:SS,用于查询某个时间点附近发生了什么(与time_range二选一)
  - `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(与time_point二选一)

### query_lpmm_knowledge
从LPMM知识库中检索与关键词相关的知识内容
- 参数:
  - `query` (必填) - 查询的关键词或问题描述
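
作为补充,下面是一个假设性示例,按照上文"使用示例"中 `get_tool_registry()` / `tool.execute()` 的相同调用方式来调用 query_lpmm_knowledge(参数名 `query` 以上面的说明为准,查询内容仅为示例):

```python
# 假设性示例:通过注册器调用 query_lpmm_knowledge,用法与上文 query_chat_history 示例相同
tool = registry.get_tool("query_lpmm_knowledge")
result = await tool.execute(query="小丑AI")
```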

## 注意事项

- 所有工具函数必须是异步函数(`async def`)
- 如果工具函数签名需要 `chat_id` 参数,系统会自动添加(通过函数签名检测,见下方示意)
- 工具参数定义中的 `required` 字段用于生成 prompt 描述
- 工具执行失败时应返回错误信息字符串,而不是抛出异常
- 共用函数放在 `tool_utils.py` 中,避免代码重复
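
下面是一个假设性的示意(并非本仓库的实际实现),用来说明"通过函数签名检测自动补充 `chat_id`"大致可以如何实现;其中的函数名 `_maybe_inject_chat_id` 为虚构:

```python
import inspect

# 假设性示意:检查工具函数签名,若声明了 chat_id 且调用方未提供,则自动注入
# (仅用于说明思路,非本仓库实际代码)
def _maybe_inject_chat_id(func, kwargs: dict, chat_id: str) -> dict:
    params = inspect.signature(func).parameters
    if "chat_id" in params and "chat_id" not in kwargs:
        kwargs = {**kwargs, "chat_id": chat_id}
    return kwargs
```
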
@@ -11,7 +11,6 @@ from .tool_registry import (
 )

 # 导入所有工具的注册函数
 from .query_jargon import register_tool as register_query_jargon
 from .query_chat_history import register_tool as register_query_chat_history
-from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
 from .query_person_info import register_tool as register_query_person_info

@@ -20,7 +19,6 @@ from src.config.config import global_config

 def init_all_tools():
     """初始化并注册所有记忆检索工具"""
     register_query_jargon()
     register_query_chat_history()
     register_query_person_info()

@@ -1,76 +0,0 @@
"""
根据关键词在jargon库中查询 - 工具实现
"""

from src.common.logger import get_logger
from src.jargon.jargon_miner import search_jargon
from .tool_registry import register_memory_retrieval_tool

logger = get_logger("memory_retrieval_tools")


async def query_jargon(keyword: str, chat_id: str) -> str:
    """根据关键词在jargon库中查询

    Args:
        keyword: 关键词(黑话/俚语/缩写)
        chat_id: 聊天ID

    Returns:
        str: 查询结果
    """
    try:
        content = str(keyword).strip()
        if not content:
            return "关键词为空"

        # 先尝试精确匹配
        results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)

        is_fuzzy_match = False

        # 如果精确匹配未找到,尝试模糊搜索
        if not results:
            results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
            is_fuzzy_match = True

        if results:
            # 如果是模糊匹配,显示找到的实际jargon内容
            if is_fuzzy_match:
                # 处理多个结果
                output_parts = [f"未精确匹配到'{content}'"]
                for result in results:
                    found_content = result.get("content", "").strip()
                    meaning = result.get("meaning", "").strip()
                    if found_content and meaning:
                        output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
                output = ",".join(output_parts)
                logger.info(f"在jargon库中找到匹配(当前会话或全局,模糊搜索): {content},找到{len(results)}条结果")
            else:
                # 精确匹配,可能有多条(相同content但不同chat_id的情况)
                output_parts = []
                for result in results:
                    meaning = result.get("meaning", "").strip()
                    if meaning:
                        output_parts.append(f"'{content}' 为黑话或者网络简写,含义为:{meaning}")
                output = ";".join(output_parts) if len(output_parts) > 1 else output_parts[0]
                logger.info(f"在jargon库中找到匹配(当前会话或全局,精确匹配): {content},找到{len(results)}条结果")
            return output

        # 未命中
        logger.info(f"在jargon库中未找到匹配(当前会话或全局,精确匹配和模糊搜索都未找到): {content}")
        return f"未在jargon库中找到'{content}'的解释"

    except Exception as e:
        logger.error(f"查询jargon失败: {e}")
        return f"查询失败: {str(e)}"


def register_tool():
    """注册工具"""
    register_memory_retrieval_tool(
        name="query_jargon",
        description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
        parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
        execute_func=query_jargon,
    )

@@ -12,6 +12,25 @@ from .tool_registry import register_memory_retrieval_tool
 logger = get_logger("memory_retrieval_tools")


+def _calculate_similarity(query: str, target: str) -> float:
+    """计算查询词在目标字符串中的相似度比例
+
+    Args:
+        query: 查询词
+        target: 目标字符串
+
+    Returns:
+        float: 相似度比例(0.0-1.0),查询词长度 / 目标字符串长度
+    """
+    if not query or not target:
+        return 0.0
+    query_len = len(query)
+    target_len = len(target)
+    if target_len == 0:
+        return 0.0
+    return query_len / target_len
+
+
 def _format_group_nick_names(group_nick_name_field) -> str:
     """格式化群昵称信息

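
(补充说明)上面新增的 `_calculate_similarity` 计算的是"查询词长度 / 目标字符串长度";结合下一个改动中 50% 的阈值,可以用一个小例子说明过滤效果(示例数据为虚构):

```python
# _calculate_similarity 返回长度比例 len(query) / len(target)
_calculate_similarity("小明", "小明同学")  # 2 / 4 = 0.5,达到 0.5 阈值,模糊匹配会被保留
_calculate_similarity("明", "小明同学")    # 1 / 4 = 0.25,低于阈值,模糊匹配会被过滤
```
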
@@ -81,11 +100,29 @@ async def query_person_info(person_name: str) -> str:
         if not records:
             return f"未找到模糊匹配'{person_name}'的用户信息"

+        # 根据相似度过滤结果:查询词在目标字符串中至少占50%
+        SIMILARITY_THRESHOLD = 0.5
+        filtered_records = []
+        for record in records:
+            if not record.person_name:
+                continue
+            # 精确匹配总是保留(相似度100%)
+            if record.person_name.strip() == person_name:
+                filtered_records.append(record)
+            else:
+                # 模糊匹配需要检查相似度
+                similarity = _calculate_similarity(person_name, record.person_name.strip())
+                if similarity >= SIMILARITY_THRESHOLD:
+                    filtered_records.append(record)
+
+        if not filtered_records:
+            return f"未找到相似度≥50%的匹配'{person_name}'的用户信息"
+
         # 区分精确匹配和模糊匹配的结果
         exact_matches = []
         fuzzy_matches = []

-        for record in records:
+        for record in filtered_records:
             # 检查是否是精确匹配
             if record.person_name and record.person_name.strip() == person_name:
                 exact_matches.append(record)

@@ -248,7 +285,7 @@ async def query_person_info(person_name: str) -> str:
         response_text = "\n\n---\n\n".join(results)

         # 添加统计信息
-        total_count = len(records)
+        total_count = len(filtered_records)
         exact_count = len(exact_matches)
         fuzzy_count = len(fuzzy_matches)