mirror of https://github.com/Mai-with-u/MaiBot.git
feat:优化黑话附加
parent
8237f1a4c1
commit
138bd8ec70
|
|
@ -249,3 +249,112 @@ async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_cont
|
|||
"""
|
||||
explainer = JargonExplainer(chat_id)
|
||||
return await explainer.explain_jargon(messages, chat_context)
|
||||
|
||||
|
||||
def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
    """Scan *chat_text* for known jargon terms and return the ones that appear.

    Only jargon records with a non-empty ``meaning`` are considered. CJK terms
    are matched as plain substrings (CJK has no word boundaries); other terms
    are matched on word boundaries. Matching is case-insensitive.

    Args:
        chat_text: The chat text to scan.
        chat_id: The chat ID, used for per-chat jargon visibility.

    Returns:
        List[str]: The jargon terms found, deduplicated in first-hit order.
    """
    if not chat_text or not chat_text.strip():
        return []

    # Candidate records: meaning must exist; optionally restrict to global ones.
    candidates = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
    if global_config.jargon.all_global:
        candidates = candidates.where(Jargon.is_global)
    candidates = candidates.order_by(Jargon.count.desc())

    # A dict is used as an insertion-ordered set of matched terms.
    hits: Dict[str, None] = {}

    for record in candidates:
        term = (record.content or "").strip()
        if not term:
            continue

        # Non-global records are only visible inside their own chats.
        if not global_config.jargon.all_global and not record.is_global:
            if not chat_id_list_contains(parse_chat_id_list(record.chat_id), chat_id):
                continue

        escaped = re.escape(term)
        # CJK text: raw substring match; otherwise require \b on both sides
        # to avoid partial-word hits.
        needle = escaped if re.search(r"[\u4e00-\u9fff]", term) else r"\b" + escaped + r"\b"

        if re.search(needle, chat_text, re.IGNORECASE):
            hits[term] = None

    logger.info(f"匹配到 {len(hits)} 个黑话")

    return list(hits.keys())
|
||||
|
||||
|
||||
async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
    """Run jargon lookups for a list of concepts and format the findings.

    Each concept is looked up exactly first; when nothing matches, a fuzzy
    search is attempted. Concepts with no match at all are only logged and
    produce no output.

    Args:
        concepts: Concept strings to look up.
        chat_id: Chat ID used to scope the jargon search.

    Returns:
        str: A "【概念检索结果】" block listing the findings, or "" when
        nothing was found.
    """
    if not concepts:
        return ""

    results = []
    exact_matches = []  # concepts with an exact hit; logged once at the end
    for concept in concepts:
        concept = concept.strip()
        if not concept:
            continue

        # Exact match first.
        jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)

        is_fuzzy_match = False

        # Fall back to fuzzy search when the exact lookup found nothing.
        if not jargon_results:
            jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
            is_fuzzy_match = True

        if jargon_results:
            if is_fuzzy_match:
                # Fuzzy hit: always emit the "not exactly matched" header,
                # followed by every result that has both content and meaning.
                output_parts = [f"未精确匹配到'{concept}'"]
                for result in jargon_results:
                    found_content = result.get("content", "").strip()
                    meaning = result.get("meaning", "").strip()
                    if found_content and meaning:
                        output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
                results.append(",".join(output_parts))
                logger.info(f"在jargon库中找到匹配(模糊搜索): {concept},找到{len(jargon_results)}条结果")
            else:
                output_parts = []
                for result in jargon_results:
                    meaning = result.get("meaning", "").strip()
                    if meaning:
                        output_parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
                # BUG FIX: output_parts can be empty when every matched record
                # has a blank meaning; indexing output_parts[0] then raised
                # IndexError. Only record a result when there is one.
                if output_parts:
                    results.append(";".join(output_parts) if len(output_parts) > 1 else output_parts[0])
                    exact_matches.append(concept)
        else:
            # No match at all: log only, emit no placeholder text.
            logger.info(f"在jargon库中未找到匹配: {concept}")

    # One combined log line for all exact matches.
    if exact_matches:
        logger.info(f"找到黑话: {', '.join(exact_matches)},共找到{len(exact_matches)}条结果")

    if results:
        return "【概念检索结果】\n" + "\n".join(results) + "\n"
    return ""
|
||||
|
|
@ -1,17 +1,16 @@
|
|||
import time
|
||||
import json
|
||||
import re
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, Tuple, Set
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.common.database.database_model import ThinkingBack, Jargon
|
||||
from json_repair import repair_json
|
||||
from src.common.database.database_model import ThinkingBack
|
||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||
from src.memory_system.memory_utils import parse_questions_json
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.jargon.jargon_utils import parse_chat_id_list, chat_id_list_contains
|
||||
from src.jargon.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
|
||||
|
||||
logger = get_logger("memory_retrieval")
|
||||
|
||||
|
|
@ -101,11 +100,6 @@ def init_memory_retrieval_prompt():
|
|||
Prompt(
|
||||
"""你的名字是{bot_name}。现在是{time_now}。
|
||||
你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天。
|
||||
|
||||
**重要限制:**
|
||||
- 思考要简短,直接切入要点
|
||||
- 最大查询轮数:{max_iterations}轮(当前第{current_iteration}轮,剩余{remaining_iterations}轮)
|
||||
|
||||
当前需要解答的问题:{question}
|
||||
已收集的信息:
|
||||
{collected_info}
|
||||
|
|
@ -118,7 +112,7 @@ def init_memory_retrieval_prompt():
|
|||
- **如果当前已收集的信息足够回答问题,且能找到明确答案,调用found_answer工具标记已找到答案**
|
||||
|
||||
**思考**
|
||||
- 你可以对查询思路给出简短的思考
|
||||
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
|
||||
- 如果信息不足,你必须给出使用什么工具进行查询
|
||||
- 如果信息足够,你必须调用found_answer工具
|
||||
""",
|
||||
|
|
@ -152,169 +146,6 @@ def init_memory_retrieval_prompt():
|
|||
)
|
||||
|
||||
|
||||
def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
    """Parse a ReAct agent LLM response into an action dict.

    The JSON may be wrapped in a ```json fenced block; otherwise the whole
    response is treated as JSON. ``repair_json`` patches minor syntax errors
    before parsing.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        Optional[Dict[str, Any]]: {"thought": str, "actions": [...]} where
        each action is {"action_type": str, "action_params": dict};
        None when the response cannot be parsed or fails validation.
    """
    try:
        # Prefer a ```json fenced block; fall back to the whole response.
        fenced = re.findall(r"```json\s*(.*?)\s*```", response, re.DOTALL)
        raw = fenced[0] if fenced else response.strip()

        # Repair minor JSON defects, then parse.
        parsed = json.loads(repair_json(raw))

        # Validate shape: must be an object with a non-empty "actions" array.
        if not isinstance(parsed, dict):
            logger.warning(f"解析的JSON不是对象格式: {parsed}")
            return None
        if "actions" not in parsed:
            logger.warning(f"响应中缺少actions字段: {parsed}")
            return None
        if not isinstance(parsed["actions"], list):
            logger.warning(f"actions字段不是数组格式: {parsed['actions']}")
            return None
        if len(parsed["actions"]) == 0:
            logger.warning("actions数组为空")
            return None

        return parsed

    except Exception as e:
        logger.error(f"解析ReAct响应失败: {e}, 响应内容: {response[:200]}...")
        return None
|
||||
|
||||
|
||||
async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
|
||||
"""对概念列表进行jargon检索
|
||||
|
||||
Args:
|
||||
concepts: 概念列表
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 检索结果字符串
|
||||
"""
|
||||
if not concepts:
|
||||
return ""
|
||||
|
||||
from src.jargon.jargon_miner import search_jargon
|
||||
|
||||
results = []
|
||||
exact_matches = [] # 收集所有精确匹配的概念
|
||||
for concept in concepts:
|
||||
concept = concept.strip()
|
||||
if not concept:
|
||||
continue
|
||||
|
||||
# 先尝试精确匹配
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||
|
||||
is_fuzzy_match = False
|
||||
|
||||
# 如果精确匹配未找到,尝试模糊搜索
|
||||
if not jargon_results:
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||
is_fuzzy_match = True
|
||||
|
||||
if jargon_results:
|
||||
# 找到结果
|
||||
if is_fuzzy_match:
|
||||
# 模糊匹配
|
||||
output_parts = [f"未精确匹配到'{concept}'"]
|
||||
for result in jargon_results:
|
||||
found_content = result.get("content", "").strip()
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if found_content and meaning:
|
||||
output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
|
||||
results.append(",".join(output_parts))
|
||||
logger.info(f"在jargon库中找到匹配(模糊搜索): {concept},找到{len(jargon_results)}条结果")
|
||||
else:
|
||||
# 精确匹配
|
||||
output_parts = []
|
||||
for result in jargon_results:
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if meaning:
|
||||
output_parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
|
||||
results.append(";".join(output_parts) if len(output_parts) > 1 else output_parts[0])
|
||||
exact_matches.append(concept) # 收集精确匹配的概念,稍后统一打印
|
||||
else:
|
||||
# 未找到,不返回占位信息,只记录日志
|
||||
logger.info(f"在jargon库中未找到匹配: {concept}")
|
||||
|
||||
# 合并所有精确匹配的日志
|
||||
if exact_matches:
|
||||
logger.info(f"找到黑话: {', '.join(exact_matches)},共找到{len(exact_matches)}条结果")
|
||||
|
||||
if results:
|
||||
return "【概念检索结果】\n" + "\n".join(results) + "\n"
|
||||
return ""
|
||||
|
||||
|
||||
def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||
"""直接在聊天文本中匹配已知的jargon,返回出现过的黑话列表"""
|
||||
# print(chat_text)
|
||||
if not chat_text or not chat_text.strip():
|
||||
return []
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
if global_config.jargon.all_global:
|
||||
query = query.where(Jargon.is_global)
|
||||
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
query_time = time.time()
|
||||
matched: Dict[str, None] = {}
|
||||
|
||||
for jargon in query:
|
||||
content = (jargon.content or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if not global_config.jargon.all_global and not jargon.is_global:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
if not chat_id_list_contains(chat_id_list, chat_id):
|
||||
continue
|
||||
|
||||
pattern = re.escape(content)
|
||||
if re.search(r"[\u4e00-\u9fff]", content):
|
||||
search_pattern = pattern
|
||||
else:
|
||||
search_pattern = r"\b" + pattern + r"\b"
|
||||
|
||||
if re.search(search_pattern, chat_text, re.IGNORECASE):
|
||||
matched[content] = None
|
||||
|
||||
# end_time = time.time()
|
||||
logger.info(
|
||||
# f"记忆检索黑话匹配: 查询耗时 {(query_time - start_time):.3f}s, "
|
||||
# f"匹配耗时 {(end_time - query_time):.3f}s, 总耗时 {(end_time - start_time):.3f}s, "
|
||||
f"匹配到 {len(matched)} 个黑话"
|
||||
)
|
||||
|
||||
return list(matched.keys())
|
||||
|
||||
|
||||
def _log_conversation_messages(
|
||||
|
|
@ -336,7 +167,7 @@ def _log_conversation_messages(
|
|||
|
||||
# 如果有head_prompt,先添加为第一条消息
|
||||
if head_prompt:
|
||||
msg_info = "========================================\n[消息 1] 角色: System 内容类型: 文本\n-----------------------"
|
||||
msg_info = "========================================\n[消息 1] 角色: System 内容类型: 文本\n-----------------------------"
|
||||
msg_info += f"\n{head_prompt}"
|
||||
log_lines.append(msg_info)
|
||||
start_idx = 2
|
||||
|
|
@ -363,7 +194,7 @@ def _log_conversation_messages(
|
|||
content_type = "未知"
|
||||
|
||||
# 构建单条消息的日志信息
|
||||
msg_info = f"\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n========================================"
|
||||
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
||||
|
||||
if full_content:
|
||||
msg_info += f"\n{full_content}"
|
||||
|
|
@ -513,7 +344,7 @@ async def _react_agent_solve_question(
|
|||
)
|
||||
|
||||
if global_config.debug.show_memory_prompt:
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估Prompt: {evaluation_prompt}")
|
||||
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
|
||||
|
||||
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
evaluation_prompt,
|
||||
|
|
@ -656,7 +487,7 @@ async def _react_agent_solve_question(
|
|||
request_type="memory.react",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
)
|
||||
|
||||
|
|
@ -706,25 +537,19 @@ async def _react_agent_solve_question(
|
|||
continue
|
||||
|
||||
# 处理工具调用
|
||||
tool_tasks = []
|
||||
found_answer_from_tool = None # 检测是否有found_answer工具调用
|
||||
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
# 首先检查是否有found_answer工具调用,如果有则立即返回,不再处理其他工具
|
||||
found_answer_from_tool = None
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||
)
|
||||
|
||||
# 检查是否是found_answer工具调用
|
||||
|
||||
if tool_name == "found_answer":
|
||||
found_answer_from_tool = tool_args.get("answer", "")
|
||||
if found_answer_from_tool:
|
||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_from_tool}})
|
||||
step["observations"] = ["检测到found_answer工具调用"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过found_answer工具找到关于问题{question}的答案: {found_answer_from_tool}")
|
||||
logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代 通过found_answer工具找到关于问题{question}的答案: {found_answer_from_tool}")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
|
|
@ -733,7 +558,20 @@ async def _react_agent_solve_question(
|
|||
)
|
||||
|
||||
return True, found_answer_from_tool, thinking_steps, False
|
||||
continue # found_answer工具不需要执行,直接跳过
|
||||
|
||||
# 如果没有found_answer工具调用,或者found_answer工具调用没有答案,继续处理其他工具
|
||||
tool_tasks = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
logger.debug(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||
)
|
||||
|
||||
# 跳过found_answer工具调用(已经在上面处理过了)
|
||||
if tool_name == "found_answer":
|
||||
continue
|
||||
|
||||
# 普通工具调用
|
||||
tool = tool_registry.get_tool(tool_name)
|
||||
|
|
@ -781,7 +619,7 @@ async def _react_agent_solve_question(
|
|||
if stripped_observation:
|
||||
# 检查工具输出中是否有新的jargon,如果有则追加到工具结果中
|
||||
if enable_jargon_detection:
|
||||
jargon_concepts = _match_jargon_from_text(stripped_observation, chat_id)
|
||||
jargon_concepts = match_jargon_from_text(stripped_observation, chat_id)
|
||||
if jargon_concepts:
|
||||
new_concepts = []
|
||||
for concept in jargon_concepts:
|
||||
|
|
@ -790,7 +628,7 @@ async def _react_agent_solve_question(
|
|||
new_concepts.append(normalized_concept)
|
||||
seen_jargon_concepts.add(normalized_concept)
|
||||
if new_concepts:
|
||||
jargon_info = await _retrieve_concepts_with_jargon(new_concepts, chat_id)
|
||||
jargon_info = await retrieve_concepts_with_jargon(new_concepts, chat_id)
|
||||
if jargon_info:
|
||||
# 将jargon查询结果追加到工具结果中
|
||||
observation_text += f"\n\n{jargon_info}"
|
||||
|
|
@ -828,8 +666,8 @@ async def _react_agent_solve_question(
|
|||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
|
||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str:
|
||||
"""获取最近一段时间内的查询历史
|
||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) -> str:
|
||||
"""获取最近一段时间内的查询历史(用于避免重复查询)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
|
@ -879,6 +717,49 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
|
|||
return ""
|
||||
|
||||
|
||||
def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) -> List[str]:
    """Fetch recently answered queries for this chat (returned to the replyer).

    Args:
        chat_id: Chat ID.
        time_window_seconds: Lookback window in seconds (default 10 minutes).

    Returns:
        List[str]: Formatted entries, each of the form "问题:xxx\n答案:xxx",
        newest first; [] when nothing qualifies or the query fails.
    """
    try:
        current_time = time.time()
        start_time = current_time - time_window_seconds

        # Records inside the window that are marked answered and carry a
        # non-empty answer, ordered newest-first.
        records = (
            ThinkingBack.select()
            .where(
                (ThinkingBack.chat_id == chat_id)
                & (ThinkingBack.update_time >= start_time)
                & (ThinkingBack.found_answer == 1)
                & (ThinkingBack.answer.is_null(False))
                & (ThinkingBack.answer != "")
            )
            .order_by(ThinkingBack.update_time.desc())
            .limit(3)  # cap at the 3 most recent records
        )

        if not records.exists():
            return []

        found_answers = []
        for record in records:
            if record.answer:
                found_answers.append(f"问题:{record.question}\n答案:{record.answer}")

        return found_answers

    except Exception as e:
        # Best-effort cache lookup: log and degrade to "no cached answers".
        logger.error(f"获取最近已找到答案的记录失败: {e}")
        return []
|
||||
|
||||
|
||||
def _store_thinking_back(
|
||||
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
|
|
@ -1016,8 +897,8 @@ async def build_memory_retrieval_prompt(
|
|||
bot_name = global_config.bot.nickname
|
||||
chat_id = chat_stream.stream_id
|
||||
|
||||
# 获取最近查询历史(最近1小时内的查询)
|
||||
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0)
|
||||
# 获取最近查询历史(最近10分钟内的查询,用于避免重复查询)
|
||||
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=600.0)
|
||||
if not recent_query_history:
|
||||
recent_query_history = "最近没有查询记录。"
|
||||
|
||||
|
|
@ -1047,7 +928,7 @@ async def build_memory_retrieval_prompt(
|
|||
return ""
|
||||
|
||||
# 解析概念列表和问题列表
|
||||
_, questions = _parse_questions_json(response)
|
||||
_, questions = parse_questions_json(response)
|
||||
if questions:
|
||||
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
|
||||
|
||||
|
|
@ -1056,7 +937,7 @@ async def build_memory_retrieval_prompt(
|
|||
|
||||
if enable_jargon_detection:
|
||||
# 使用匹配逻辑自动识别聊天中的黑话概念
|
||||
concepts = _match_jargon_from_text(message, chat_id)
|
||||
concepts = match_jargon_from_text(message, chat_id)
|
||||
if concepts:
|
||||
logger.info(f"黑话匹配命中 {len(concepts)} 个概念: {concepts}")
|
||||
else:
|
||||
|
|
@ -1067,7 +948,7 @@ async def build_memory_retrieval_prompt(
|
|||
# 对匹配到的概念进行jargon检索,作为初始信息
|
||||
initial_info = ""
|
||||
if enable_jargon_detection and concepts:
|
||||
concept_info = await _retrieve_concepts_with_jargon(concepts, chat_id)
|
||||
concept_info = await retrieve_concepts_with_jargon(concepts, chat_id)
|
||||
if concept_info:
|
||||
initial_info += concept_info
|
||||
logger.debug(f"概念检索完成,结果: {concept_info}")
|
||||
|
|
@ -1107,67 +988,47 @@ async def build_memory_retrieval_prompt(
|
|||
elif result is not None:
|
||||
question_results.append(result)
|
||||
|
||||
# 获取最近10分钟内已找到答案的缓存记录
|
||||
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
|
||||
|
||||
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
|
||||
all_results = []
|
||||
|
||||
# 先添加当前查询的结果
|
||||
current_questions = set()
|
||||
for result in question_results:
|
||||
# 提取问题(格式为 "问题:xxx\n答案:xxx")
|
||||
if result.startswith("问题:"):
|
||||
question_end = result.find("\n答案:")
|
||||
if question_end != -1:
|
||||
current_questions.add(result[4:question_end])
|
||||
all_results.append(result)
|
||||
|
||||
# 添加缓存答案(排除当前查询中已存在的问题)
|
||||
for cached_answer in cached_answers:
|
||||
if cached_answer.startswith("问题:"):
|
||||
question_end = cached_answer.find("\n答案:")
|
||||
if question_end != -1:
|
||||
cached_question = cached_answer[4:question_end]
|
||||
if cached_question not in current_questions:
|
||||
all_results.append(cached_answer)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
if question_results:
|
||||
retrieved_memory = "\n\n".join(question_results)
|
||||
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆")
|
||||
if all_results:
|
||||
retrieved_memory = "\n\n".join(all_results)
|
||||
current_count = len(question_results)
|
||||
cached_count = len(all_results) - current_count
|
||||
logger.info(
|
||||
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,"
|
||||
f"当前查询 {current_count} 条记忆,缓存 {cached_count} 条记忆,共 {len(all_results)} 条记忆"
|
||||
)
|
||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||
else:
|
||||
logger.debug("所有问题均未找到答案")
|
||||
logger.debug("所有问题均未找到答案,且无缓存答案")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记忆检索时发生异常: {str(e)}")
|
||||
return ""
|
||||
|
||||
|
||||
def _parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
    """Parse the question-generation JSON into (concepts, questions).

    The JSON may be wrapped in a ```json fenced block; otherwise the whole
    response is treated as JSON. ``repair_json`` patches minor syntax errors
    before parsing.

    Args:
        response: Raw LLM response text.

    Returns:
        Tuple[List[str], List[str]]: (concept list, question list); both
        empty on parse failure or an unexpected shape.
    """
    try:
        # Prefer a ```json fenced block; fall back to the whole response.
        fenced = re.findall(r"```json\s*(.*?)\s*```", response, re.DOTALL)
        raw = fenced[0] if fenced else response.strip()

        # Repair minor JSON defects, then parse.
        parsed = json.loads(repair_json(raw))

        # Only the new object format {"concepts": [...], "questions": [...]}
        # is supported.
        if not isinstance(parsed, dict):
            logger.warning(f"解析的JSON不是对象格式: {parsed}")
            return [], []

        concepts_raw = parsed.get("concepts", [])
        questions_raw = parsed.get("questions", [])

        # Coerce non-list fields to empty lists.
        if not isinstance(concepts_raw, list):
            concepts_raw = []
        if not isinstance(questions_raw, list):
            questions_raw = []

        # Keep only non-blank string entries.
        concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
        questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]

        return concepts, questions

    except Exception as e:
        logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
        return [], []
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ import json
|
|||
import re
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
|
@ -16,101 +17,56 @@ from src.common.logger import get_logger
|
|||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
def parse_md_json(json_text: str) -> tuple[list[dict], str]:
    """Extract JSON objects and leading reasoning text from Markdown content.

    Looks for ```json fenced blocks; everything before the first fence is
    returned as "reasoning" text. Top-level JSON arrays are flattened into
    their dict elements; non-dict items are dropped.

    BUG FIX: the return annotation previously said ``list[str]`` although
    the function has always returned a 2-tuple; corrected to
    ``tuple[list[dict], str]`` (no behavior change).

    Args:
        json_text: Markdown text that may contain ```json fenced blocks.

    Returns:
        tuple[list[dict], str]: (parsed JSON objects, reasoning text found
        before the first fence, "" when there is none).
    """
    json_objects: list[dict] = []
    reasoning_content = ""

    # Find all ```json fenced blocks.
    json_pattern = r"```json\s*(.*?)\s*```"
    matches = re.findall(json_pattern, json_text, re.DOTALL)

    # Text before the first fence is treated as the model's reasoning.
    if matches:
        first_json_pos = json_text.find("```json")
        if first_json_pos > 0:
            reasoning_content = json_text[:first_json_pos].strip()
            # Strip leading "//" comment markers from the reasoning lines.
            reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
            reasoning_content = reasoning_content.strip()

    for match in matches:
        try:
            # Remove // line comments and /* */ block comments.
            # NOTE(review): this also mangles "//" inside JSON string values
            # (e.g. URLs) — confirm inputs never contain them.
            json_str = re.sub(r"//.*?\n", "\n", match)
            json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
            if json_str := json_str.strip():
                json_obj = json.loads(json_str)
                if isinstance(json_obj, dict):
                    json_objects.append(json_obj)
                elif isinstance(json_obj, list):
                    # Flatten top-level arrays, keeping dict items only.
                    for item in json_obj:
                        if isinstance(item, dict):
                            json_objects.append(item)
        except Exception as e:
            logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
            continue

    return json_objects, reasoning_content
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度
|
||||
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||
"""解析问题JSON,返回概念列表和问题列表
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
response: LLM返回的响应
|
||||
|
||||
Returns:
|
||||
float: 相似度分数 (0-1)
|
||||
Tuple[List[str], List[str]]: (概念列表, 问题列表)
|
||||
"""
|
||||
try:
|
||||
# 预处理文本
|
||||
text1 = preprocess_text(text1)
|
||||
text2 = preprocess_text(text2)
|
||||
# 尝试提取JSON(可能包含在```json代码块中)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
# 使用SequenceMatcher计算相似度
|
||||
similarity = SequenceMatcher(None, text1, text2).ratio()
|
||||
if matches:
|
||||
json_str = matches[0]
|
||||
else:
|
||||
# 尝试直接解析整个响应
|
||||
json_str = response.strip()
|
||||
|
||||
# 如果其中一个文本包含另一个,提高相似度
|
||||
if text1 in text2 or text2 in text1:
|
||||
similarity = max(similarity, 0.8)
|
||||
# 修复可能的JSON错误
|
||||
repaired_json = repair_json(json_str)
|
||||
|
||||
return similarity
|
||||
# 解析JSON
|
||||
parsed = json.loads(repaired_json)
|
||||
|
||||
# 只支持新格式:包含concepts和questions的对象
|
||||
if not isinstance(parsed, dict):
|
||||
logger.warning(f"解析的JSON不是对象格式: {parsed}")
|
||||
return [], []
|
||||
|
||||
concepts_raw = parsed.get("concepts", [])
|
||||
questions_raw = parsed.get("questions", [])
|
||||
|
||||
# 确保是列表
|
||||
if not isinstance(concepts_raw, list):
|
||||
concepts_raw = []
|
||||
if not isinstance(questions_raw, list):
|
||||
questions_raw = []
|
||||
|
||||
# 确保所有元素都是字符串
|
||||
concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
|
||||
questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
|
||||
|
||||
return concepts, questions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算相似度时出错: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
    """
    Normalize text to improve match accuracy.

    Lowercases the input, strips punctuation/special characters, and
    collapses whitespace runs into single spaces.

    Args:
        text: Raw input text.

    Returns:
        str: The normalized text; on failure, whatever was processed so far.
    """
    try:
        # Lowercase, then drop everything that is not a word char or space.
        text = re.sub(r"[^\w\s]", "", text.lower())
        # Collapse whitespace runs and trim both ends.
        text = re.sub(r"\s+", " ", text).strip()
        return text
    except Exception as e:
        logger.error(f"预处理文本时出错: {e}")
        return text
|
||||
|
||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||
return [], []
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
|
|
@ -140,29 +96,3 @@ def parse_datetime_to_timestamp(value: str) -> float:
|
|||
except Exception as e:
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def parse_time_range(time_range: str) -> Tuple[float, float]:
    """
    Parse a time-range string into start and end timestamps.

    Args:
        time_range: Range string in the form
            "YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS".

    Returns:
        Tuple[float, float]: (start timestamp, end timestamp).

    Raises:
        ValueError: If the string is not two " - "-separated datetimes.
    """
    # Guard: the separator must be present before attempting the split.
    if " - " not in time_range:
        raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")

    parts = time_range.split(" - ", 1)
    if len(parts) != 2:
        raise ValueError(f"时间范围格式错误: {time_range}")

    start_str, end_str = (p.strip() for p in parts)

    return parse_datetime_to_timestamp(start_str), parse_datetime_to_timestamp(end_str)
|
||||
|
|
|
|||
Loading…
Reference in New Issue