mirror of https://github.com/Mai-with-u/MaiBot.git
feat:添加ReAct记忆提取系统
parent
d761d42dd7
commit
7a3f260cc3
|
|
@ -9,7 +9,7 @@ from src.plugin_system.base.base_tool import BaseTool, ToolParamType
|
||||||
# 导入依赖的系统组件
|
# 导入依赖的系统组件
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from src.plugins.built_in.relation.relation import BuildRelationAction
|
# from src.plugins.built_in.relation.relation import BuildRelationAction
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
|
|
||||||
logger = get_logger("relation_actions")
|
logger = get_logger("relation_actions")
|
||||||
|
|
|
||||||
|
|
@ -37,10 +37,12 @@ from src.plugin_system.apis import llm_api
|
||||||
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||||
|
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||||
|
|
||||||
init_lpmm_prompt()
|
init_lpmm_prompt()
|
||||||
init_replyer_prompt()
|
init_replyer_prompt()
|
||||||
init_rewrite_prompt()
|
init_rewrite_prompt()
|
||||||
|
init_memory_retrieval_prompt()
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("replyer")
|
logger = get_logger("replyer")
|
||||||
|
|
@ -289,15 +291,7 @@ class DefaultReplyer:
|
||||||
|
|
||||||
async def build_question_block(self) -> str:
|
async def build_question_block(self) -> str:
|
||||||
"""构建问题块"""
|
"""构建问题块"""
|
||||||
# if not global_config.question.enable_question:
|
# 问题跟踪功能已移除,返回空字符串
|
||||||
# return ""
|
|
||||||
questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_stream.stream_id)
|
|
||||||
questions_str = ""
|
|
||||||
for question in questions:
|
|
||||||
questions_str += f"- {question.question}\n"
|
|
||||||
if questions_str:
|
|
||||||
return f"你在聊天中,有以下问题想要得到解答:\n{questions_str}"
|
|
||||||
else:
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -807,7 +801,7 @@ class DefaultReplyer:
|
||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 并行执行五个构建任务
|
# 并行执行九个构建任务
|
||||||
task_results = await asyncio.gather(
|
task_results = await asyncio.gather(
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||||
|
|
@ -821,6 +815,12 @@ class DefaultReplyer:
|
||||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||||
self._time_and_run_task(self.build_question_block(), "question_block"),
|
self._time_and_run_task(self.build_question_block(), "question_block"),
|
||||||
|
self._time_and_run_task(
|
||||||
|
build_memory_retrieval_prompt(
|
||||||
|
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
|
||||||
|
),
|
||||||
|
"memory_retrieval",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 任务名称中英文映射
|
# 任务名称中英文映射
|
||||||
|
|
@ -835,6 +835,7 @@ class DefaultReplyer:
|
||||||
"personality_prompt": "人格信息",
|
"personality_prompt": "人格信息",
|
||||||
"mood_state_prompt": "情绪状态",
|
"mood_state_prompt": "情绪状态",
|
||||||
"question_block": "问题",
|
"question_block": "问题",
|
||||||
|
"memory_retrieval": "记忆检索",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 处理结果
|
# 处理结果
|
||||||
|
|
@ -865,6 +866,7 @@ class DefaultReplyer:
|
||||||
actions_info: str = results_dict["actions_info"]
|
actions_info: str = results_dict["actions_info"]
|
||||||
personality_prompt: str = results_dict["personality_prompt"]
|
personality_prompt: str = results_dict["personality_prompt"]
|
||||||
question_block: str = results_dict["question_block"]
|
question_block: str = results_dict["question_block"]
|
||||||
|
memory_retrieval: str = results_dict["memory_retrieval"]
|
||||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||||
|
|
||||||
|
|
@ -922,6 +924,7 @@ class DefaultReplyer:
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
question_block=question_block,
|
question_block=question_block,
|
||||||
|
memory_retrieval=memory_retrieval,
|
||||||
chat_prompt=chat_prompt_block,
|
chat_prompt=chat_prompt_block,
|
||||||
), selected_expressions
|
), selected_expressions
|
||||||
|
|
||||||
|
|
@ -1150,7 +1153,6 @@ class DefaultReplyer:
|
||||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||||
"""
|
"""
|
||||||
加权且不放回地随机抽取k个元素。
|
加权且不放回地随机抽取k个元素。
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ def init_replyer_prompt():
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
{expression_habits_block}{memory_block}{question_block}
|
{expression_habits_block}{memory_block}{question_block}{memory_retrieval}
|
||||||
|
|
||||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片:
|
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片:
|
||||||
{time_block}
|
{time_block}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
from typing import List, Dict, Optional
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
from peewee import fn
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Jargon
|
from src.common.database.database_model import Jargon
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config, global_config
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_anonymous_messages,
|
build_anonymous_messages,
|
||||||
|
|
@ -21,28 +22,27 @@ logger = get_logger("jargon")
|
||||||
|
|
||||||
def _init_prompt() -> None:
|
def _init_prompt() -> None:
|
||||||
prompt_str = """
|
prompt_str = """
|
||||||
**聊天内容**
|
**聊天内容,其中的SELF是你自己的发言**
|
||||||
{chat_str}
|
{chat_str}
|
||||||
|
|
||||||
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
||||||
- 必须为对话中真实出现过的短词或短语
|
- 必须为对话中真实出现过的短词或短语
|
||||||
- 必须是你无法理解含义的词语,或者出现频率较高的词语
|
- 必须是你无法理解含义的词语,没有明确含义的词语
|
||||||
- 请不要选择有明确含义,或者含义清晰的词语
|
- 请不要选择有明确含义,或者含义清晰的词语
|
||||||
- 必须是这几种类别之一:英文或中文缩写、中文拼音短语、字母数字混合
|
- 必须是这几种类别之一:英文或中文缩写、中文拼音短语
|
||||||
- 排除:人名、@、明显的表情/图片占位、纯标点、常规功能词(如的、了、呢、啊等)
|
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
|
||||||
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
||||||
- 合并重复项,去重
|
- 合并重复项,去重
|
||||||
|
|
||||||
分类规则,type必须根据规则填写:
|
分类规则,type必须根据规则填写:
|
||||||
- p(拼音缩写):由字母或字母和汉字构成的,用汉语拼音简写词,或汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
- p(拼音缩写):由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
||||||
- c(中文缩写):中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
|
||||||
- e(英文缩写):英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
|
- e(英文缩写):英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
|
||||||
- x(谐音梗):谐音梗,用谐音词概括一个词汇或含义,例如:好似,难崩
|
- c(中文缩写):中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
||||||
|
|
||||||
以 JSON 数组输出,元素为对象(严格按以下结构):
|
以 JSON 数组输出,元素为对象(严格按以下结构):
|
||||||
[
|
[
|
||||||
{{"content": "词条", "raw_content": "包含该词条的完整对话原文", "type": "p"}},
|
{{"content": "词条", "raw_content": "包含该词条的完整对话上下文原文", "type": "p"}},
|
||||||
{{"content": "词条2", "raw_content": "包含该词条的完整对话原文", "type": "c"}}
|
{{"content": "词条2", "raw_content": "包含该词条的完整对话上下文原文", "type": "c"}}
|
||||||
]
|
]
|
||||||
|
|
||||||
现在请输出:
|
现在请输出:
|
||||||
|
|
@ -57,7 +57,7 @@ def _init_inference_prompts() -> None:
|
||||||
**词条内容**
|
**词条内容**
|
||||||
{content}
|
{content}
|
||||||
|
|
||||||
**词条出现的上下文(raw_content)**
|
**词条出现的上下文(raw_content)其中的SELF是你自己的发言**
|
||||||
{raw_content_list}
|
{raw_content_list}
|
||||||
|
|
||||||
请根据以上词条内容和上下文,推断这个词条的含义。
|
请根据以上词条内容和上下文,推断这个词条的含义。
|
||||||
|
|
@ -66,8 +66,8 @@ def _init_inference_prompts() -> None:
|
||||||
|
|
||||||
以 JSON 格式输出:
|
以 JSON 格式输出:
|
||||||
{{
|
{{
|
||||||
"meaning": "含义说明",
|
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)",
|
||||||
"translation": "翻译或解释"
|
"translation": "原文(用一个词语写明这个词的实际含义)"
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
Prompt(prompt1_str, "jargon_inference_with_context_prompt")
|
Prompt(prompt1_str, "jargon_inference_with_context_prompt")
|
||||||
|
|
@ -83,8 +83,8 @@ def _init_inference_prompts() -> None:
|
||||||
|
|
||||||
以 JSON 格式输出:
|
以 JSON 格式输出:
|
||||||
{{
|
{{
|
||||||
"meaning": "含义说明",
|
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)",
|
||||||
"translation": "翻译或解释"
|
"translation": "原文(用一个词语写明这个词的实际含义)"
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
Prompt(prompt2_str, "jargon_inference_content_only_prompt")
|
Prompt(prompt2_str, "jargon_inference_content_only_prompt")
|
||||||
|
|
@ -117,7 +117,7 @@ _init_inference_prompts()
|
||||||
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||||
"""
|
"""
|
||||||
判断是否需要进行含义推断
|
判断是否需要进行含义推断
|
||||||
在 count 达到 5, 10, 20, 40, 60, 100 时进行推断
|
在 count 达到 3,6, 10, 20, 40, 60, 100 时进行推断
|
||||||
并且count必须大于last_inference_count,避免重启后重复判定
|
并且count必须大于last_inference_count,避免重启后重复判定
|
||||||
如果is_complete为True,不再进行推断
|
如果is_complete为True,不再进行推断
|
||||||
"""
|
"""
|
||||||
|
|
@ -128,8 +128,8 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||||
count = jargon_obj.count or 0
|
count = jargon_obj.count or 0
|
||||||
last_inference = jargon_obj.last_inference_count or 0
|
last_inference = jargon_obj.last_inference_count or 0
|
||||||
|
|
||||||
# 阈值列表:5, 10, 20, 40, 60, 100
|
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
||||||
thresholds = [5, 10, 20, 40, 60, 100]
|
thresholds = [3,6, 10, 20, 40, 60, 100]
|
||||||
|
|
||||||
if count < thresholds[0]:
|
if count < thresholds[0]:
|
||||||
return False
|
return False
|
||||||
|
|
@ -166,6 +166,11 @@ class JargonMiner:
|
||||||
request_type="jargon.extract",
|
request_type="jargon.extract",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 初始化stream_name作为类属性,避免重复提取
|
||||||
|
chat_manager = get_chat_manager()
|
||||||
|
stream_name = chat_manager.get_stream_name(self.chat_id)
|
||||||
|
self.stream_name = stream_name if stream_name else self.chat_id
|
||||||
|
|
||||||
async def _infer_meaning_by_id(self, jargon_id: int) -> None:
|
async def _infer_meaning_by_id(self, jargon_id: int) -> None:
|
||||||
"""通过ID加载对象并推断"""
|
"""通过ID加载对象并推断"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -255,6 +260,8 @@ class JargonMiner:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if global_config.debug.show_jargon_prompt:
|
||||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||||
# logger.info(f"jargon {content} 推断2结果: {inference2}")
|
# logger.info(f"jargon {content} 推断2结果: {inference2}")
|
||||||
|
|
@ -269,6 +276,7 @@ class JargonMiner:
|
||||||
inference2=json.dumps(inference2, ensure_ascii=False),
|
inference2=json.dumps(inference2, ensure_ascii=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if global_config.debug.show_jargon_prompt:
|
||||||
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
||||||
|
|
||||||
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
|
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
|
||||||
|
|
@ -317,6 +325,20 @@ class JargonMiner:
|
||||||
jargon_obj.save()
|
jargon_obj.save()
|
||||||
logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
|
logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
|
||||||
|
|
||||||
|
# 固定输出推断结果,格式化为可读形式
|
||||||
|
if is_jargon:
|
||||||
|
# 是黑话,输出格式:[聊天名]xxx (translation)的含义是 xxxxxxxxxxx
|
||||||
|
translation = jargon_obj.translation or "未知"
|
||||||
|
meaning = jargon_obj.meaning or "无详细说明"
|
||||||
|
is_global = jargon_obj.is_global
|
||||||
|
if is_global:
|
||||||
|
logger.info(f"[通用黑话]{content} ({translation})的含义是 {meaning}")
|
||||||
|
else:
|
||||||
|
logger.info(f"[{self.stream_name}]{content} ({translation})的含义是 {meaning}")
|
||||||
|
else:
|
||||||
|
# 不是黑话,输出格式:[聊天名]xxx 不是黑话
|
||||||
|
logger.info(f"[{self.stream_name}]{content} 不是黑话")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"jargon推断失败: {e}")
|
logger.error(f"jargon推断失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -371,6 +393,7 @@ class JargonMiner:
|
||||||
if not response:
|
if not response:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if global_config.debug.show_jargon_prompt:
|
||||||
logger.info(f"jargon提取提示词: {prompt}")
|
logger.info(f"jargon提取提示词: {prompt}")
|
||||||
logger.info(f"jargon提取结果: {response}")
|
logger.info(f"jargon提取结果: {response}")
|
||||||
|
|
||||||
|
|
@ -404,6 +427,8 @@ class JargonMiner:
|
||||||
raw_content_list = []
|
raw_content_list = []
|
||||||
if isinstance(raw_content_value, list):
|
if isinstance(raw_content_value, list):
|
||||||
raw_content_list = [str(rc).strip() for rc in raw_content_value if str(rc).strip()]
|
raw_content_list = [str(rc).strip() for rc in raw_content_value if str(rc).strip()]
|
||||||
|
# 去重
|
||||||
|
raw_content_list = list(dict.fromkeys(raw_content_list))
|
||||||
elif isinstance(raw_content_value, str):
|
elif isinstance(raw_content_value, str):
|
||||||
raw_content_str = raw_content_value.strip()
|
raw_content_str = raw_content_value.strip()
|
||||||
if raw_content_str:
|
if raw_content_str:
|
||||||
|
|
@ -585,10 +610,20 @@ class JargonMiner:
|
||||||
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if saved or updated or merged:
|
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||||
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,合并为global {merged} 条,chat_id={self.chat_id}")
|
if uniq_entries:
|
||||||
|
# 收集所有提取的jargon内容
|
||||||
|
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||||
|
jargon_str = ",".join(jargon_list)
|
||||||
|
|
||||||
|
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||||
|
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||||
|
|
||||||
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
|
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
|
||||||
self.last_learning_time = extraction_end_time
|
self.last_learning_time = extraction_end_time
|
||||||
|
|
||||||
|
if saved or updated or merged:
|
||||||
|
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,合并为global {merged} 条,chat_id={self.chat_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"JargonMiner 运行失败: {e}")
|
logger.error(f"JargonMiner 运行失败: {e}")
|
||||||
|
|
||||||
|
|
@ -611,3 +646,88 @@ async def extract_and_store_jargon(chat_id: str) -> None:
|
||||||
await miner.run_once()
|
await miner.run_once()
|
||||||
|
|
||||||
|
|
||||||
|
def search_jargon(
|
||||||
|
keyword: str,
|
||||||
|
chat_id: Optional[str] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
fuzzy: bool = True
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
搜索jargon,支持大小写不敏感和模糊搜索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keyword: 搜索关键词
|
||||||
|
chat_id: 可选的聊天ID,如果提供则优先搜索该聊天或global的jargon
|
||||||
|
limit: 返回结果数量限制,默认10
|
||||||
|
case_sensitive: 是否大小写敏感,默认False(不敏感)
|
||||||
|
fuzzy: 是否模糊搜索,默认True(使用LIKE匹配)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, str]]: 包含content, translation, meaning的字典列表
|
||||||
|
"""
|
||||||
|
if not keyword or not keyword.strip():
|
||||||
|
return []
|
||||||
|
|
||||||
|
keyword = keyword.strip()
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
query = Jargon.select(
|
||||||
|
Jargon.content,
|
||||||
|
Jargon.translation,
|
||||||
|
Jargon.meaning
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建搜索条件
|
||||||
|
if case_sensitive:
|
||||||
|
# 大小写敏感
|
||||||
|
if fuzzy:
|
||||||
|
# 模糊搜索
|
||||||
|
search_condition = Jargon.content.contains(keyword)
|
||||||
|
else:
|
||||||
|
# 精确匹配
|
||||||
|
search_condition = (Jargon.content == keyword)
|
||||||
|
else:
|
||||||
|
# 大小写不敏感
|
||||||
|
if fuzzy:
|
||||||
|
# 模糊搜索(使用LOWER函数)
|
||||||
|
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
||||||
|
else:
|
||||||
|
# 精确匹配(使用LOWER函数)
|
||||||
|
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
|
||||||
|
|
||||||
|
query = query.where(search_condition)
|
||||||
|
|
||||||
|
# 如果提供了chat_id,优先搜索该聊天或global的jargon
|
||||||
|
if chat_id:
|
||||||
|
query = query.where(
|
||||||
|
(Jargon.chat_id == chat_id) | Jargon.is_global
|
||||||
|
)
|
||||||
|
|
||||||
|
# 只返回有translation或meaning的记录
|
||||||
|
query = query.where(
|
||||||
|
(
|
||||||
|
(Jargon.translation.is_null(False)) & (Jargon.translation != "")
|
||||||
|
) | (
|
||||||
|
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 按count降序排序,优先返回出现频率高的
|
||||||
|
query = query.order_by(Jargon.count.desc())
|
||||||
|
|
||||||
|
# 限制结果数量
|
||||||
|
query = query.limit(limit)
|
||||||
|
|
||||||
|
# 执行查询并返回结果
|
||||||
|
results = []
|
||||||
|
for jargon in query:
|
||||||
|
results.append({
|
||||||
|
"content": jargon.content or "",
|
||||||
|
"translation": jargon.translation or "",
|
||||||
|
"meaning": jargon.meaning or ""
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,585 @@
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config, model_config
|
||||||
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
from src.plugin_system.apis import llm_api
|
||||||
|
from src.common.database.database_model import ThinkingBack
|
||||||
|
from json_repair import repair_json
|
||||||
|
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||||
|
|
||||||
|
logger = get_logger("memory_retrieval")
|
||||||
|
|
||||||
|
|
||||||
|
def init_memory_retrieval_prompt():
|
||||||
|
"""初始化记忆检索相关的 prompt 模板和工具"""
|
||||||
|
# 首先注册所有工具
|
||||||
|
init_all_tools()
|
||||||
|
|
||||||
|
# 第一步:问题生成prompt
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
你是一个专门检测是否需要回忆的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||||
|
群里正在进行的聊天内容:
|
||||||
|
{chat_history}
|
||||||
|
|
||||||
|
{recent_query_history}
|
||||||
|
|
||||||
|
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||||
|
请仔细分析聊天内容,考虑以下几点:
|
||||||
|
1. 对话中是否提到了过去发生的事情、人物、事件或信息
|
||||||
|
2. 是否有需要回忆的内容(比如"之前说过"、"上次"、"以前"等)
|
||||||
|
3. 是否有需要查找历史信息的问题
|
||||||
|
|
||||||
|
重要提示:
|
||||||
|
- 如果"最近已查询的问题和结果"中已经包含了类似的问题,请避免重复生成相同或相似的问题
|
||||||
|
- 如果之前已经查询过某个问题但未找到答案,可以尝试用不同的方式提问或更具体的问题
|
||||||
|
- 如果之前已经查询过某个问题并找到了答案,可以直接参考已有结果,不需要重复查询
|
||||||
|
|
||||||
|
如果你认为需要从记忆中检索信息来回答,请根据上下文提出一个或多个具体的问题。
|
||||||
|
问题格式示例:
|
||||||
|
- "xxx在前几天干了什么"
|
||||||
|
- "xxx是什么"
|
||||||
|
- "xxxx和xxx的关系是什么"
|
||||||
|
- "xxx在某个时间点发生了什么"
|
||||||
|
|
||||||
|
请输出JSON格式的问题数组。如果不需要检索记忆,则输出空数组[]。
|
||||||
|
|
||||||
|
输出格式示例:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
"张三在前几天干了什么",
|
||||||
|
"自然选择是什么",
|
||||||
|
"李四和王五的关系是什么"
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
请只输出JSON数组,不要输出其他内容:
|
||||||
|
""",
|
||||||
|
name="memory_retrieval_question_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 第二步:ReAct Agent prompt(工具描述会在运行时动态生成)
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
你是一个记忆检索助手,需要通过思考(Think)、行动(Action)、观察(Observation)的循环来回答问题。
|
||||||
|
|
||||||
|
当前问题:{question}
|
||||||
|
已收集的信息:
|
||||||
|
{collected_info}
|
||||||
|
|
||||||
|
你可以使用以下工具来查询信息:
|
||||||
|
{tools_description}
|
||||||
|
|
||||||
|
请按照以下格式输出你的思考过程:
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"thought": "你的思考过程,分析当前情况,决定下一步行动",
|
||||||
|
"action": "要执行的动作,格式为:工具名(参数)",
|
||||||
|
"action_type": {action_types_list},
|
||||||
|
"action_params": {{参数名: 参数值}} 或 null
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
|
你可以选择以下动作:
|
||||||
|
1. 如果已经收集到足够的信息可以回答问题,请设置action_type为"final_answer",并在thought中说明答案。
|
||||||
|
2. 如果经过多次查询后,确认无法找到相关信息或答案,请设置action_type为"no_answer",并在thought中说明原因。
|
||||||
|
|
||||||
|
请只输出JSON,不要输出其他内容:
|
||||||
|
""",
|
||||||
|
name="memory_retrieval_react_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""解析ReAct Agent的响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: LLM返回的响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 解析后的动作信息,如果解析失败返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 尝试提取JSON(可能包含在```json代码块中)
|
||||||
|
json_pattern = r"```json\s*(.*?)\s*```"
|
||||||
|
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||||
|
|
||||||
|
if matches:
|
||||||
|
json_str = matches[0]
|
||||||
|
else:
|
||||||
|
# 尝试直接解析整个响应
|
||||||
|
json_str = response.strip()
|
||||||
|
|
||||||
|
# 修复可能的JSON错误
|
||||||
|
repaired_json = repair_json(json_str)
|
||||||
|
|
||||||
|
# 解析JSON
|
||||||
|
action_info = json.loads(repaired_json)
|
||||||
|
|
||||||
|
if not isinstance(action_info, dict):
|
||||||
|
logger.warning(f"解析的JSON不是对象格式: {action_info}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return action_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析ReAct响应失败: {e}, 响应内容: {response[:200]}...")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _react_agent_solve_question(
|
||||||
|
question: str,
|
||||||
|
chat_id: str,
|
||||||
|
max_iterations: int = 5,
|
||||||
|
timeout: float = 30.0
|
||||||
|
) -> Tuple[bool, str, List[Dict[str, Any]]]:
|
||||||
|
"""使用ReAct架构的Agent来解决问题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question: 要回答的问题
|
||||||
|
chat_id: 聊天ID
|
||||||
|
max_iterations: 最大迭代次数
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str, List[Dict[str, Any]]]: (是否找到答案, 答案内容, 思考步骤列表)
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
collected_info = ""
|
||||||
|
thinking_steps = []
|
||||||
|
|
||||||
|
for iteration in range(max_iterations):
|
||||||
|
# 检查超时
|
||||||
|
if time.time() - start_time > timeout:
|
||||||
|
logger.warning(f"ReAct Agent超时,已迭代{iteration}次")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}")
|
||||||
|
logger.info(f"ReAct Agent 已收集信息: {collected_info if collected_info else '暂无信息'}")
|
||||||
|
|
||||||
|
# 获取工具注册器
|
||||||
|
tool_registry = get_tool_registry()
|
||||||
|
|
||||||
|
# 构建prompt(动态生成工具描述)
|
||||||
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
|
"memory_retrieval_react_prompt",
|
||||||
|
question=question,
|
||||||
|
collected_info=collected_info if collected_info else "暂无信息",
|
||||||
|
tools_description=tool_registry.get_tools_description(),
|
||||||
|
action_types_list=tool_registry.get_action_types_list(),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 Prompt: {prompt}")
|
||||||
|
|
||||||
|
# 调用LLM
|
||||||
|
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||||
|
prompt,
|
||||||
|
model_config=model_config.model_task_config.tool_use,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM响应: {response}")
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM推理: {reasoning_content}")
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM模型: {model_name}")
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 解析响应
|
||||||
|
action_info = _parse_react_response(response)
|
||||||
|
if not action_info:
|
||||||
|
logger.warning(f"无法解析ReAct响应,迭代{iteration + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
thought = action_info.get("thought", "")
|
||||||
|
action_type = action_info.get("action_type", "")
|
||||||
|
action_params = action_info.get("action_params", {})
|
||||||
|
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考: {thought}")
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作类型: {action_type}")
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作参数: {action_params}")
|
||||||
|
|
||||||
|
# 记录思考步骤
|
||||||
|
step = {
|
||||||
|
"iteration": iteration + 1,
|
||||||
|
"thought": thought,
|
||||||
|
"action_type": action_type,
|
||||||
|
"action_params": action_params,
|
||||||
|
"observation": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# 执行动作
|
||||||
|
if action_type == "final_answer":
|
||||||
|
# Agent认为已经找到答案
|
||||||
|
answer = thought # 使用thought作为答案
|
||||||
|
step["observation"] = "找到答案"
|
||||||
|
thinking_steps.append(step)
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到最终答案: {answer}")
|
||||||
|
return True, answer, thinking_steps
|
||||||
|
|
||||||
|
elif action_type == "no_answer":
|
||||||
|
# Agent确认无法找到答案
|
||||||
|
answer = thought # 使用thought说明无法找到答案的原因
|
||||||
|
step["observation"] = "确认无法找到答案"
|
||||||
|
thinking_steps.append(step)
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 确认无法找到答案: {answer}")
|
||||||
|
return False, answer, thinking_steps
|
||||||
|
|
||||||
|
# 使用工具注册器执行工具
|
||||||
|
tool_registry = get_tool_registry()
|
||||||
|
tool = tool_registry.get_tool(action_type)
|
||||||
|
|
||||||
|
if tool:
|
||||||
|
try:
|
||||||
|
# 准备工具参数(需要添加chat_id如果工具需要)
|
||||||
|
tool_params = action_params.copy()
|
||||||
|
|
||||||
|
# 如果工具函数签名需要chat_id,添加它
|
||||||
|
import inspect
|
||||||
|
sig = inspect.signature(tool.execute_func)
|
||||||
|
if "chat_id" in sig.parameters:
|
||||||
|
tool_params["chat_id"] = chat_id
|
||||||
|
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 执行工具: {action_type}({tool_params})")
|
||||||
|
|
||||||
|
# 执行工具
|
||||||
|
observation = await tool.execute(**tool_params)
|
||||||
|
step["observation"] = observation
|
||||||
|
|
||||||
|
# 构建收集信息的描述
|
||||||
|
param_str = ", ".join([f"{k}={v}" for k, v in action_params.items()])
|
||||||
|
collected_info += f"\n查询{action_type}({param_str})的结果:{observation}\n"
|
||||||
|
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具执行结果: {observation}")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"工具执行失败: {str(e)}"
|
||||||
|
step["observation"] = error_msg
|
||||||
|
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 {error_msg}")
|
||||||
|
else:
|
||||||
|
step["observation"] = f"未知的工具类型: {action_type}"
|
||||||
|
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 未知的工具类型: {action_type}")
|
||||||
|
|
||||||
|
thinking_steps.append(step)
|
||||||
|
|
||||||
|
# 如果观察结果为空或无效,继续下一轮
|
||||||
|
if step["observation"] and "无有效信息" not in step["observation"] and "未找到" not in step["observation"]:
|
||||||
|
# 有有效信息,继续思考
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 达到最大迭代次数或超时,但Agent没有明确返回final_answer
|
||||||
|
# 这种情况下,即使收集到了一些信息,也不认为找到了答案
|
||||||
|
# 只有Agent明确返回final_answer时,才认为找到了答案
|
||||||
|
if collected_info:
|
||||||
|
logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}...")
|
||||||
|
return False, collected_info, thinking_steps
|
||||||
|
else:
|
||||||
|
return False, "未找到相关信息", thinking_steps
|
||||||
|
|
||||||
|
|
||||||
|
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0) -> str:
|
||||||
|
"""获取最近一段时间内的查询历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
time_window_seconds: 时间窗口(秒),默认1小时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化的查询历史字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
current_time = time.time()
|
||||||
|
start_time = current_time - time_window_seconds
|
||||||
|
|
||||||
|
# 查询最近时间窗口内的记录,按更新时间倒序
|
||||||
|
records = (
|
||||||
|
ThinkingBack.select()
|
||||||
|
.where(
|
||||||
|
(ThinkingBack.chat_id == chat_id) &
|
||||||
|
(ThinkingBack.update_time >= start_time)
|
||||||
|
)
|
||||||
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
|
.limit(8) # 最多返回10条最近的记录
|
||||||
|
)
|
||||||
|
|
||||||
|
if not records.exists():
|
||||||
|
return ""
|
||||||
|
|
||||||
|
history_lines = []
|
||||||
|
history_lines.append("最近已查询的问题和结果:")
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
|
||||||
|
answer_preview = ""
|
||||||
|
if record.answer:
|
||||||
|
# 截取答案前100字符
|
||||||
|
answer_preview = record.answer[:100]
|
||||||
|
if len(record.answer) > 100:
|
||||||
|
answer_preview += "..."
|
||||||
|
|
||||||
|
history_lines.append(f"- 问题:{record.question}")
|
||||||
|
history_lines.append(f" 状态:{status}")
|
||||||
|
if answer_preview:
|
||||||
|
history_lines.append(f" 答案:{answer_preview}")
|
||||||
|
history_lines.append("") # 空行分隔
|
||||||
|
|
||||||
|
return "\n".join(history_lines)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取查询历史失败: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, str]]:
|
||||||
|
"""从thinking_back数据库中查询是否有现成的答案
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
question: 问题
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Tuple[bool, str]]: 如果找到答案,返回(True, answer),否则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查询相同chat_id和问题,且found_answer为True的记录
|
||||||
|
# 按更新时间倒序,获取最新的答案
|
||||||
|
records = (
|
||||||
|
ThinkingBack.select()
|
||||||
|
.where(
|
||||||
|
(ThinkingBack.chat_id == chat_id) &
|
||||||
|
(ThinkingBack.question == question) &
|
||||||
|
(ThinkingBack.found_answer == 1)
|
||||||
|
)
|
||||||
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if records.exists():
|
||||||
|
record = records.get()
|
||||||
|
logger.info(f"在thinking_back中找到现成答案,问题: {question[:50]}...")
|
||||||
|
return True, record.answer or ""
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询thinking_back失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _store_thinking_back(
|
||||||
|
chat_id: str,
|
||||||
|
question: str,
|
||||||
|
context: str,
|
||||||
|
found_answer: bool,
|
||||||
|
answer: str,
|
||||||
|
thinking_steps: List[Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
|
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
question: 问题
|
||||||
|
context: 上下文信息
|
||||||
|
found_answer: 是否找到答案
|
||||||
|
answer: 答案内容
|
||||||
|
thinking_steps: 思考步骤列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# 先查询是否已存在相同chat_id和问题的记录
|
||||||
|
existing = (
|
||||||
|
ThinkingBack.select()
|
||||||
|
.where(
|
||||||
|
(ThinkingBack.chat_id == chat_id) &
|
||||||
|
(ThinkingBack.question == question)
|
||||||
|
)
|
||||||
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing.exists():
|
||||||
|
# 更新现有记录
|
||||||
|
record = existing.get()
|
||||||
|
record.context = context
|
||||||
|
record.found_answer = found_answer
|
||||||
|
record.answer = answer
|
||||||
|
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
|
||||||
|
record.update_time = now
|
||||||
|
record.save()
|
||||||
|
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
|
||||||
|
else:
|
||||||
|
# 创建新记录
|
||||||
|
ThinkingBack.create(
|
||||||
|
chat_id=chat_id,
|
||||||
|
question=question,
|
||||||
|
context=context,
|
||||||
|
found_answer=found_answer,
|
||||||
|
answer=answer,
|
||||||
|
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
||||||
|
create_time=now,
|
||||||
|
update_time=now
|
||||||
|
)
|
||||||
|
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"存储思考过程失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def build_memory_retrieval_prompt(
|
||||||
|
message: str,
|
||||||
|
sender: str,
|
||||||
|
target: str,
|
||||||
|
chat_stream,
|
||||||
|
tool_executor,
|
||||||
|
) -> str:
|
||||||
|
"""构建记忆检索提示
|
||||||
|
使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 聊天历史记录
|
||||||
|
sender: 发送者名称
|
||||||
|
target: 目标消息内容
|
||||||
|
chat_stream: 聊天流对象
|
||||||
|
tool_executor: 工具执行器(保留参数以兼容接口)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 记忆检索结果字符串
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
|
try:
|
||||||
|
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
|
bot_name = global_config.bot.nickname
|
||||||
|
chat_id = chat_stream.stream_id
|
||||||
|
|
||||||
|
# 获取最近查询历史(最近1小时内的查询)
|
||||||
|
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=600.0)
|
||||||
|
if not recent_query_history:
|
||||||
|
recent_query_history = "最近没有查询记录。"
|
||||||
|
|
||||||
|
# 第一步:生成问题
|
||||||
|
question_prompt = await global_prompt_manager.format_prompt(
|
||||||
|
"memory_retrieval_question_prompt",
|
||||||
|
bot_name=bot_name,
|
||||||
|
time_now=time_now,
|
||||||
|
chat_history=message,
|
||||||
|
recent_query_history=recent_query_history,
|
||||||
|
sender=sender,
|
||||||
|
target_message=target,
|
||||||
|
)
|
||||||
|
|
||||||
|
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||||
|
question_prompt,
|
||||||
|
model_config=model_config.model_task_config.tool_use,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
|
||||||
|
logger.info(f"记忆检索问题生成响应: {response}")
|
||||||
|
logger.info(f"记忆检索问题生成推理: {reasoning_content}")
|
||||||
|
logger.info(f"记忆检索问题生成模型: {model_name}")
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.error(f"LLM生成问题失败: {response}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 解析问题列表
|
||||||
|
questions = _parse_questions_json(response)
|
||||||
|
|
||||||
|
if not questions:
|
||||||
|
logger.debug("模型认为不需要检索记忆或解析失败")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
|
||||||
|
|
||||||
|
# 第二步:对每个问题查询答案
|
||||||
|
all_results = []
|
||||||
|
for question in questions:
|
||||||
|
logger.info(f"开始处理问题: {question}")
|
||||||
|
|
||||||
|
# 先检查thinking_back数据库中是否有现成答案
|
||||||
|
cached_result = _query_thinking_back(chat_id, question)
|
||||||
|
if cached_result:
|
||||||
|
found_answer, answer = cached_result
|
||||||
|
if found_answer and answer:
|
||||||
|
logger.info(f"从thinking_back缓存中获取答案,问题: {question[:50]}...")
|
||||||
|
all_results.append(f"问题:{question}\n答案:{answer}")
|
||||||
|
continue # 跳过ReAct Agent查询
|
||||||
|
|
||||||
|
# 如果没有缓存答案,使用ReAct Agent查询
|
||||||
|
logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...")
|
||||||
|
found_answer, answer, thinking_steps = await _react_agent_solve_question(
|
||||||
|
question=question,
|
||||||
|
chat_id=chat_id,
|
||||||
|
max_iterations=5,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 存储到数据库
|
||||||
|
_store_thinking_back(
|
||||||
|
chat_id=chat_id,
|
||||||
|
question=question,
|
||||||
|
context=message, # 只存储前500字符作为上下文
|
||||||
|
found_answer=found_answer,
|
||||||
|
answer=answer,
|
||||||
|
thinking_steps=thinking_steps
|
||||||
|
)
|
||||||
|
|
||||||
|
if found_answer and answer:
|
||||||
|
all_results.append(f"问题:{question}\n答案:{answer}")
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
if all_results:
|
||||||
|
retrieved_memory = "\n\n".join(all_results)
|
||||||
|
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒")
|
||||||
|
return f"你回忆起了以下信息:\n{retrieved_memory}\n请在回复时参考这些回忆的信息。\n"
|
||||||
|
else:
|
||||||
|
logger.debug("所有问题均未找到答案")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记忆检索时发生异常: {str(e)}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_questions_json(response: str) -> List[str]:
|
||||||
|
"""解析问题JSON
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: LLM返回的响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 问题列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 尝试提取JSON(可能包含在```json代码块中)
|
||||||
|
json_pattern = r"```json\s*(.*?)\s*```"
|
||||||
|
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||||
|
|
||||||
|
if matches:
|
||||||
|
json_str = matches[0]
|
||||||
|
else:
|
||||||
|
# 尝试直接解析整个响应
|
||||||
|
json_str = response.strip()
|
||||||
|
|
||||||
|
# 修复可能的JSON错误
|
||||||
|
repaired_json = repair_json(json_str)
|
||||||
|
|
||||||
|
# 解析JSON
|
||||||
|
questions = json.loads(repaired_json)
|
||||||
|
|
||||||
|
if not isinstance(questions, list):
|
||||||
|
logger.warning(f"解析的JSON不是数组格式: {questions}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 确保所有元素都是字符串
|
||||||
|
questions = [q for q in questions if isinstance(q, str) and q.strip()]
|
||||||
|
|
||||||
|
return questions
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||||
|
return []
|
||||||
|
|
@ -11,6 +11,7 @@ from typing import List, Tuple, Optional
|
||||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("memory_utils")
|
logger = get_logger("memory_utils")
|
||||||
|
|
@ -355,3 +356,37 @@ def find_most_similar_memory_by_chat_id(target_title: str, target_chat_id: str,
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"查找最相似记忆时出错: {e}")
|
logger.error(f"查找最相似记忆时出错: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def compute_merge_similarity_threshold() -> float:
|
||||||
|
"""
|
||||||
|
根据当前记忆数量占比动态计算合并相似度阈值。
|
||||||
|
|
||||||
|
规则:占比越高,阈值越低。
|
||||||
|
- < 60%: 0.80(更严格,避免早期误合并)
|
||||||
|
- < 80%: 0.70
|
||||||
|
- < 100%: 0.60
|
||||||
|
- < 120%: 0.50
|
||||||
|
- >= 120%: 0.45(最宽松,加速收敛)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
current_count = MemoryChestModel.select().count()
|
||||||
|
max_count = max(1, int(global_config.memory.max_memory_number))
|
||||||
|
percentage = current_count / max_count
|
||||||
|
|
||||||
|
if percentage < 0.6:
|
||||||
|
return 0.70
|
||||||
|
elif percentage < 0.8:
|
||||||
|
return 0.60
|
||||||
|
elif percentage < 1.0:
|
||||||
|
return 0.50
|
||||||
|
elif percentage < 1.5:
|
||||||
|
return 0.40
|
||||||
|
elif percentage < 2:
|
||||||
|
return 0.30
|
||||||
|
else:
|
||||||
|
return 0.25
|
||||||
|
except Exception:
|
||||||
|
# 发生异常时使用保守阈值
|
||||||
|
return 0.70
|
||||||
|
|
@ -0,0 +1,155 @@
|
||||||
|
# 记忆检索工具模块
|
||||||
|
|
||||||
|
这个模块提供了统一的工具注册和管理系统,用于记忆检索功能。
|
||||||
|
|
||||||
|
## 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
retrieval_tools/
|
||||||
|
├── __init__.py # 模块导出
|
||||||
|
├── tool_registry.py # 工具注册系统
|
||||||
|
├── tool_utils.py # 工具函数库(共用函数)
|
||||||
|
├── query_jargon.py # 查询jargon工具
|
||||||
|
├── query_chat_history.py # 查询聊天历史工具
|
||||||
|
└── README.md # 本文件
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模块说明
|
||||||
|
|
||||||
|
### `tool_registry.py`
|
||||||
|
包含工具注册系统的核心类:
|
||||||
|
- `MemoryRetrievalTool`: 工具基类
|
||||||
|
- `MemoryRetrievalToolRegistry`: 工具注册器
|
||||||
|
- `register_memory_retrieval_tool()`: 便捷注册函数
|
||||||
|
- `get_tool_registry()`: 获取注册器实例
|
||||||
|
|
||||||
|
### `tool_utils.py`
|
||||||
|
包含所有工具共用的工具函数:
|
||||||
|
- `parse_datetime_to_timestamp()`: 解析时间字符串为时间戳
|
||||||
|
- `parse_time_range()`: 解析时间范围字符串
|
||||||
|
|
||||||
|
### 工具文件
|
||||||
|
每个工具都有独立的文件:
|
||||||
|
- `query_jargon.py`: 根据关键词在jargon库中查询
|
||||||
|
- `query_chat_history.py`: 根据时间或关键词在chat_history中查询(支持查询时间点事件、时间范围事件、关键词搜索)
|
||||||
|
|
||||||
|
## 如何添加新工具
|
||||||
|
|
||||||
|
1. 创建新的工具文件,例如 `query_new_tool.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
新工具 - 工具实现
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from .tool_registry import register_memory_retrieval_tool
|
||||||
|
from .tool_utils import parse_datetime_to_timestamp # 如果需要使用工具函数
|
||||||
|
|
||||||
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
|
||||||
|
async def query_new_tool(param1: str, param2: str, chat_id: str) -> str:
|
||||||
|
"""新工具的实现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: 参数1
|
||||||
|
param2: 参数2
|
||||||
|
chat_id: 聊天ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 查询结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 实现逻辑
|
||||||
|
return "结果"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"新工具执行失败: {e}")
|
||||||
|
return f"查询失败: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def register_tool():
|
||||||
|
"""注册工具"""
|
||||||
|
register_memory_retrieval_tool(
|
||||||
|
name="query_new_tool",
|
||||||
|
description="新工具的描述",
|
||||||
|
parameters=[
|
||||||
|
{
|
||||||
|
"name": "param1",
|
||||||
|
"type": "string",
|
||||||
|
"description": "参数1的描述",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "param2",
|
||||||
|
"type": "string",
|
||||||
|
"description": "参数2的描述",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
],
|
||||||
|
execute_func=query_new_tool
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 在 `__init__.py` 中导入并注册新工具:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .query_new_tool import register_tool as register_query_new_tool
|
||||||
|
|
||||||
|
def init_all_tools():
|
||||||
|
"""初始化并注册所有记忆检索工具"""
|
||||||
|
register_query_jargon()
|
||||||
|
register_query_chat_history()
|
||||||
|
register_query_new_tool() # 添加新工具
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 工具会自动:
|
||||||
|
- 出现在 ReAct Agent 的 prompt 中
|
||||||
|
- 在动作类型列表中可用
|
||||||
|
- 被 ReAct Agent 自动调用
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.memory_system.retrieval_tools import init_all_tools, get_tool_registry
|
||||||
|
|
||||||
|
# 初始化所有工具
|
||||||
|
init_all_tools()
|
||||||
|
|
||||||
|
# 获取工具注册器
|
||||||
|
registry = get_tool_registry()
|
||||||
|
|
||||||
|
# 获取特定工具
|
||||||
|
tool = registry.get_tool("query_chat_history")
|
||||||
|
|
||||||
|
# 执行工具(查询时间点事件)
|
||||||
|
result = await tool.execute(time_point="2025-01-15 14:30:00", chat_id="chat123")
|
||||||
|
|
||||||
|
# 或者查询关键词
|
||||||
|
result = await tool.execute(keyword="小丑AI", chat_id="chat123")
|
||||||
|
|
||||||
|
# 或者查询时间范围
|
||||||
|
result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:00", chat_id="chat123")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 现有工具说明
|
||||||
|
|
||||||
|
### query_jargon
|
||||||
|
根据关键词在jargon库中查询黑话/俚语/缩写的含义
|
||||||
|
- 参数:`keyword` (必填) - 关键词
|
||||||
|
|
||||||
|
### query_chat_history
|
||||||
|
根据时间或关键词在chat_history中查询相关聊天记录。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息
|
||||||
|
- 参数:
|
||||||
|
- `keyword` (可选) - 关键词,用于搜索消息内容
|
||||||
|
- `time_point` (可选) - 时间点,格式:YYYY-MM-DD HH:MM:SS,用于查询某个时间点附近发生了什么(与time_range二选一)
|
||||||
|
- `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(与time_point二选一)
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
- 所有工具函数必须是异步函数(`async def`)
|
||||||
|
- 如果工具函数签名需要 `chat_id` 参数,系统会自动添加(通过函数签名检测)
|
||||||
|
- 工具参数定义中的 `required` 字段用于生成 prompt 描述
|
||||||
|
- 工具执行失败时应返回错误信息字符串,而不是抛出异常
|
||||||
|
- 共用函数放在 `tool_utils.py` 中,避免代码重复
|
||||||
|
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
"""
|
||||||
|
记忆检索工具模块
|
||||||
|
提供统一的工具注册和管理系统
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .tool_registry import (
|
||||||
|
MemoryRetrievalTool,
|
||||||
|
MemoryRetrievalToolRegistry,
|
||||||
|
register_memory_retrieval_tool,
|
||||||
|
get_tool_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入所有工具的注册函数
|
||||||
|
from .query_jargon import register_tool as register_query_jargon
|
||||||
|
from .query_chat_history import register_tool as register_query_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
def init_all_tools():
|
||||||
|
"""初始化并注册所有记忆检索工具"""
|
||||||
|
register_query_jargon()
|
||||||
|
register_query_chat_history()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MemoryRetrievalTool",
|
||||||
|
"MemoryRetrievalToolRegistry",
|
||||||
|
"register_memory_retrieval_tool",
|
||||||
|
"get_tool_registry",
|
||||||
|
"init_all_tools",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,221 @@
|
||||||
|
"""
|
||||||
|
根据时间或关键词在chat_history中查询 - 工具实现
|
||||||
|
从ChatHistory表的聊天记录概述库中查询
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import model_config
|
||||||
|
from src.common.database.database_model import ChatHistory
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from .tool_registry import register_memory_retrieval_tool
|
||||||
|
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
|
||||||
|
|
||||||
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
|
||||||
|
async def query_chat_history(
|
||||||
|
chat_id: str,
|
||||||
|
keyword: Optional[str] = None,
|
||||||
|
time_point: Optional[str] = None,
|
||||||
|
time_range: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
|
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
keyword: 关键词(可选)
|
||||||
|
time_point: 时间点,格式:YYYY-MM-DD HH:MM:SS(可选)
|
||||||
|
time_range: 时间范围,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 查询结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查参数
|
||||||
|
if not keyword and not time_point and not time_range:
|
||||||
|
return "未指定查询参数(需要提供keyword、time_point或time_range之一)"
|
||||||
|
|
||||||
|
# 构建查询条件
|
||||||
|
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||||
|
|
||||||
|
# 时间过滤条件
|
||||||
|
time_conditions = []
|
||||||
|
if time_point:
|
||||||
|
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||||
|
target_timestamp = parse_datetime_to_timestamp(time_point)
|
||||||
|
time_conditions.append(
|
||||||
|
(ChatHistory.start_time <= target_timestamp) &
|
||||||
|
(ChatHistory.end_time >= target_timestamp)
|
||||||
|
)
|
||||||
|
elif time_range:
|
||||||
|
# 时间范围:查询与时间范围有交集的记录
|
||||||
|
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||||
|
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||||
|
time_conditions.append(
|
||||||
|
(ChatHistory.start_time < end_timestamp) &
|
||||||
|
(ChatHistory.end_time > start_timestamp)
|
||||||
|
)
|
||||||
|
|
||||||
|
if time_conditions:
|
||||||
|
# 合并所有时间条件(OR关系)
|
||||||
|
time_filter = time_conditions[0]
|
||||||
|
for condition in time_conditions[1:]:
|
||||||
|
time_filter = time_filter | condition
|
||||||
|
query = query.where(time_filter)
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
|
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
return "未找到相关聊天记录概述"
|
||||||
|
|
||||||
|
# 如果有关键词,进一步过滤
|
||||||
|
if keyword:
|
||||||
|
keyword_lower = keyword.lower()
|
||||||
|
filtered_records = []
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
# 在theme、keywords、summary、original_text中搜索
|
||||||
|
theme = (record.theme or "").lower()
|
||||||
|
summary = (record.summary or "").lower()
|
||||||
|
original_text = (record.original_text or "").lower()
|
||||||
|
|
||||||
|
# 解析keywords JSON
|
||||||
|
keywords_list = []
|
||||||
|
if record.keywords:
|
||||||
|
try:
|
||||||
|
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
|
if isinstance(keywords_data, list):
|
||||||
|
keywords_list = [str(k).lower() for k in keywords_data]
|
||||||
|
except (json.JSONDecodeError, TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 检查是否包含关键词
|
||||||
|
if (keyword_lower in theme or
|
||||||
|
keyword_lower in summary or
|
||||||
|
keyword_lower in original_text or
|
||||||
|
any(keyword_lower in k for k in keywords_list)):
|
||||||
|
filtered_records.append(record)
|
||||||
|
|
||||||
|
if not filtered_records:
|
||||||
|
return f"未找到包含关键词'{keyword}'的聊天记录概述"
|
||||||
|
|
||||||
|
records = filtered_records
|
||||||
|
|
||||||
|
# 构建结果文本
|
||||||
|
results = []
|
||||||
|
for record in records[:10]: # 最多返回10条记录
|
||||||
|
result_parts = []
|
||||||
|
|
||||||
|
# 添加主题
|
||||||
|
if record.theme:
|
||||||
|
result_parts.append(f"主题:{record.theme}")
|
||||||
|
|
||||||
|
# 添加时间范围
|
||||||
|
from datetime import datetime
|
||||||
|
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||||
|
|
||||||
|
# 添加概括(优先使用summary,如果没有则使用original_text的前200字符)
|
||||||
|
if record.summary:
|
||||||
|
result_parts.append(f"概括:{record.summary}")
|
||||||
|
elif record.original_text:
|
||||||
|
text_preview = record.original_text[:200]
|
||||||
|
if len(record.original_text) > 200:
|
||||||
|
text_preview += "..."
|
||||||
|
result_parts.append(f"内容:{text_preview}")
|
||||||
|
|
||||||
|
results.append("\n".join(result_parts))
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return "未找到相关聊天记录概述"
|
||||||
|
|
||||||
|
# 如果只有一条记录,直接返回
|
||||||
|
if len(results) == 1:
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
# 多条记录,使用LLM总结
|
||||||
|
try:
|
||||||
|
llm_request = LLMRequest(
|
||||||
|
model_set=model_config.model_task_config.utils_small,
|
||||||
|
request_type="chat_history_analysis"
|
||||||
|
)
|
||||||
|
|
||||||
|
query_desc = []
|
||||||
|
if keyword:
|
||||||
|
query_desc.append(f"关键词:{keyword}")
|
||||||
|
if time_point:
|
||||||
|
query_desc.append(f"时间点:{time_point}")
|
||||||
|
if time_range:
|
||||||
|
query_desc.append(f"时间范围:{time_range}")
|
||||||
|
|
||||||
|
query_info = ",".join(query_desc) if query_desc else "聊天记录概述"
|
||||||
|
|
||||||
|
combined_results = "\n\n---\n\n".join(results)
|
||||||
|
|
||||||
|
analysis_prompt = f"""请根据以下聊天记录概述,总结与查询条件相关的信息。请输出一段平文本,不要有特殊格式。
|
||||||
|
查询条件:{query_info}
|
||||||
|
|
||||||
|
聊天记录概述:
|
||||||
|
{combined_results}
|
||||||
|
|
||||||
|
请仔细分析聊天记录概述,提取与查询条件相关的信息并给出总结。如果概述中没有相关信息,输出"无有效信息"即可,不要输出其他内容。
|
||||||
|
|
||||||
|
总结:"""
|
||||||
|
|
||||||
|
response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(
|
||||||
|
prompt=analysis_prompt,
|
||||||
|
temperature=0.3,
|
||||||
|
max_tokens=512
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"查询聊天历史概述提示词: {analysis_prompt}")
|
||||||
|
logger.info(f"查询聊天历史概述响应: {response}")
|
||||||
|
logger.info(f"查询聊天历史概述推理: {reasoning}")
|
||||||
|
logger.info(f"查询聊天历史概述模型: {model_name}")
|
||||||
|
|
||||||
|
if "无有效信息" in response:
|
||||||
|
return "无有效信息"
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as llm_error:
|
||||||
|
logger.error(f"LLM分析聊天记录概述失败: {llm_error}")
|
||||||
|
# 如果LLM分析失败,返回前3条记录的摘要
|
||||||
|
return "\n\n---\n\n".join(results[:3])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询聊天历史概述失败: {e}")
|
||||||
|
return f"查询失败: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def register_tool():
|
||||||
|
"""注册工具"""
|
||||||
|
register_memory_retrieval_tool(
|
||||||
|
name="query_chat_history",
|
||||||
|
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述",
|
||||||
|
parameters=[
|
||||||
|
{
|
||||||
|
"name": "keyword",
|
||||||
|
"type": "string",
|
||||||
|
"description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)",
|
||||||
|
"required": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "time_point",
|
||||||
|
"type": "string",
|
||||||
|
"description": "时间点,格式:YYYY-MM-DD HH:MM:SS(可选,与time_range二选一)。用于查询包含该时间点的聊天记录概述",
|
||||||
|
"required": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "time_range",
|
||||||
|
"type": "string",
|
||||||
|
"description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(可选,与time_point二选一)。用于查询与时间范围有交集的聊天记录概述",
|
||||||
|
"required": False
|
||||||
|
}
|
||||||
|
],
|
||||||
|
execute_func=query_chat_history
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
"""
|
||||||
|
根据关键词在jargon库中查询 - 工具实现
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.jargon.jargon_miner import search_jargon
|
||||||
|
from .tool_registry import register_memory_retrieval_tool
|
||||||
|
|
||||||
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
|
||||||
|
async def query_jargon(
|
||||||
|
keyword: str,
|
||||||
|
chat_id: str,
|
||||||
|
fuzzy: bool = False,
|
||||||
|
search_all: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""根据关键词在jargon库中查询
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keyword: 关键词(黑话/俚语/缩写)
|
||||||
|
chat_id: 聊天ID
|
||||||
|
fuzzy: 是否使用模糊搜索,默认False(精确匹配)
|
||||||
|
search_all: 是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 查询结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
content = str(keyword).strip()
|
||||||
|
if not content:
|
||||||
|
return "关键词为空"
|
||||||
|
|
||||||
|
# 根据参数执行搜索
|
||||||
|
search_chat_id = None if search_all else chat_id
|
||||||
|
results = search_jargon(
|
||||||
|
keyword=content,
|
||||||
|
chat_id=search_chat_id,
|
||||||
|
limit=1,
|
||||||
|
case_sensitive=False,
|
||||||
|
fuzzy=fuzzy
|
||||||
|
)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
result = results[0]
|
||||||
|
translation = result.get("translation", "").strip()
|
||||||
|
meaning = result.get("meaning", "").strip()
|
||||||
|
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
||||||
|
search_scope = "全库" if search_all else "当前会话或全局"
|
||||||
|
output = f"“{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}”"
|
||||||
|
logger.info(f"在jargon库中找到匹配({search_scope},{search_type}): {content}")
|
||||||
|
return output
|
||||||
|
|
||||||
|
# 未命中
|
||||||
|
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
||||||
|
search_scope = "全库" if search_all else "当前会话或全局"
|
||||||
|
logger.info(f"在jargon库中未找到匹配({search_scope},{search_type}): {content}")
|
||||||
|
return f"未在jargon库中找到'{content}'的解释"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询jargon失败: {e}")
|
||||||
|
return f"查询失败: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def register_tool():
|
||||||
|
"""注册工具"""
|
||||||
|
register_memory_retrieval_tool(
|
||||||
|
name="query_jargon",
|
||||||
|
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon,可以设置为搜索全库。",
|
||||||
|
parameters=[
|
||||||
|
{
|
||||||
|
"name": "keyword",
|
||||||
|
"type": "string",
|
||||||
|
"description": "关键词(黑话/俚语/缩写),支持模糊搜索",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "fuzzy",
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否使用模糊搜索(部分匹配),默认False(精确匹配)。当精确匹配找不到时,可以尝试使用模糊搜索。",
|
||||||
|
"required": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "search_all",
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global的jargon)。当在当前会话中找不到时,可以尝试搜索全库。",
|
||||||
|
"required": False
|
||||||
|
}
|
||||||
|
],
|
||||||
|
execute_func=query_jargon
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""
|
||||||
|
工具注册系统
|
||||||
|
提供统一的工具注册和管理接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any, Optional, Callable, Awaitable
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRetrievalTool:
|
||||||
|
"""记忆检索工具基类"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: List[Dict[str, Any]],
|
||||||
|
execute_func: Callable[..., Awaitable[str]]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
description: 工具描述
|
||||||
|
parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}]
|
||||||
|
execute_func: 执行函数,必须是异步函数
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.parameters = parameters
|
||||||
|
self.execute_func = execute_func
|
||||||
|
|
||||||
|
def get_tool_description(self) -> str:
|
||||||
|
"""获取工具的文本描述,用于prompt"""
|
||||||
|
param_descriptions = []
|
||||||
|
for param in self.parameters:
|
||||||
|
param_name = param.get("name", "")
|
||||||
|
param_type = param.get("type", "string")
|
||||||
|
param_desc = param.get("description", "")
|
||||||
|
required = param.get("required", True)
|
||||||
|
required_str = "必填" if required else "可选"
|
||||||
|
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
|
||||||
|
|
||||||
|
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
|
||||||
|
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> str:
|
||||||
|
"""执行工具"""
|
||||||
|
return await self.execute_func(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRetrievalToolRegistry:
|
||||||
|
"""工具注册器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.tools: Dict[str, MemoryRetrievalTool] = {}
|
||||||
|
|
||||||
|
def register_tool(self, tool: MemoryRetrievalTool) -> None:
|
||||||
|
"""注册工具"""
|
||||||
|
self.tools[tool.name] = tool
|
||||||
|
logger.info(f"注册记忆检索工具: {tool.name}")
|
||||||
|
|
||||||
|
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
|
||||||
|
"""获取工具"""
|
||||||
|
return self.tools.get(name)
|
||||||
|
|
||||||
|
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
|
||||||
|
"""获取所有工具"""
|
||||||
|
return self.tools.copy()
|
||||||
|
|
||||||
|
def get_tools_description(self) -> str:
|
||||||
|
"""获取所有工具的描述,用于prompt"""
|
||||||
|
descriptions = []
|
||||||
|
for i, tool in enumerate(self.tools.values(), 1):
|
||||||
|
descriptions.append(f"{i}. {tool.get_tool_description()}")
|
||||||
|
return "\n".join(descriptions)
|
||||||
|
|
||||||
|
def get_action_types_list(self) -> str:
|
||||||
|
"""获取所有动作类型的列表,用于prompt"""
|
||||||
|
action_types = [tool.name for tool in self.tools.values()]
|
||||||
|
action_types.append("final_answer")
|
||||||
|
action_types.append("no_answer")
|
||||||
|
return " 或 ".join([f'"{at}"' for at in action_types])
|
||||||
|
|
||||||
|
|
||||||
|
# 全局工具注册器实例
|
||||||
|
_tool_registry = MemoryRetrievalToolRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
def register_memory_retrieval_tool(
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: List[Dict[str, Any]],
|
||||||
|
execute_func: Callable[..., Awaitable[str]]
|
||||||
|
) -> None:
|
||||||
|
"""注册记忆检索工具的便捷函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
description: 工具描述
|
||||||
|
parameters: 参数定义列表
|
||||||
|
execute_func: 执行函数
|
||||||
|
"""
|
||||||
|
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
|
||||||
|
_tool_registry.register_tool(tool)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
||||||
|
"""获取工具注册器实例"""
|
||||||
|
return _tool_registry
|
||||||
|
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
"""
|
||||||
|
工具函数库
|
||||||
|
包含所有工具共用的工具函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def parse_datetime_to_timestamp(value: str) -> float:
|
||||||
|
"""
|
||||||
|
接受多种常见格式并转换为时间戳(秒)
|
||||||
|
支持示例:
|
||||||
|
- 2025-09-29
|
||||||
|
- 2025-09-29 00:00:00
|
||||||
|
- 2025/09/29 00:00
|
||||||
|
- 2025-09-29T00:00:00
|
||||||
|
"""
|
||||||
|
value = value.strip()
|
||||||
|
fmts = [
|
||||||
|
"%Y-%m-%d %H:%M:%S",
|
||||||
|
"%Y-%m-%d %H:%M",
|
||||||
|
"%Y/%m/%d %H:%M:%S",
|
||||||
|
"%Y/%m/%d %H:%M",
|
||||||
|
"%Y-%m-%d",
|
||||||
|
"%Y/%m/%d",
|
||||||
|
"%Y-%m-%dT%H:%M:%S",
|
||||||
|
"%Y-%m-%dT%H:%M",
|
||||||
|
]
|
||||||
|
last_err = None
|
||||||
|
for fmt in fmts:
|
||||||
|
try:
|
||||||
|
dt = datetime.strptime(value, fmt)
|
||||||
|
return dt.timestamp()
|
||||||
|
except Exception as e:
|
||||||
|
last_err = e
|
||||||
|
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_time_range(time_range: str) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
解析时间范围字符串,返回开始和结束时间戳
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[float, float]: (开始时间戳, 结束时间戳)
|
||||||
|
"""
|
||||||
|
if " - " not in time_range:
|
||||||
|
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
|
||||||
|
|
||||||
|
parts = time_range.split(" - ", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"时间范围格式错误: {time_range}")
|
||||||
|
|
||||||
|
start_str = parts[0].strip()
|
||||||
|
end_str = parts[1].strip()
|
||||||
|
|
||||||
|
start_timestamp = parse_datetime_to_timestamp(start_str)
|
||||||
|
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||||
|
|
||||||
|
return start_timestamp, end_timestamp
|
||||||
|
|
||||||
Loading…
Reference in New Issue