feat: optimize jargon attachment

pull/1397/head
SengokuCola 2025-12-02 15:43:46 +08:00
parent 8237f1a4c1
commit 138bd8ec70
3 changed files with 262 additions and 362 deletions

View File

@@ -249,3 +249,112 @@ async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_cont
""" """
explainer = JargonExplainer(chat_id) explainer = JargonExplainer(chat_id)
return await explainer.explain_jargon(messages, chat_context) return await explainer.explain_jargon(messages, chat_context)
+
+
+def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
+    """Match known jargon directly against the chat text and return the jargon terms that appear.
+
+    Args:
+        chat_text: The chat text to match against
+        chat_id: Chat ID
+
+    Returns:
+        List[str]: The matched jargon terms
+    """
+    if not chat_text or not chat_text.strip():
+        return []
+
+    query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
+    if global_config.jargon.all_global:
+        query = query.where(Jargon.is_global)
+    query = query.order_by(Jargon.count.desc())
+
+    matched: Dict[str, None] = {}
+    for jargon in query:
+        content = (jargon.content or "").strip()
+        if not content:
+            continue
+        if not global_config.jargon.all_global and not jargon.is_global:
+            chat_id_list = parse_chat_id_list(jargon.chat_id)
+            if not chat_id_list_contains(chat_id_list, chat_id):
+                continue
+        pattern = re.escape(content)
+        # CJK terms are matched directly; Latin terms get word boundaries
+        # so short terms do not match inside longer words
+        if re.search(r"[\u4e00-\u9fff]", content):
+            search_pattern = pattern
+        else:
+            search_pattern = r"\b" + pattern + r"\b"
+        if re.search(search_pattern, chat_text, re.IGNORECASE):
+            matched[content] = None
+
+    logger.info(f"Matched {len(matched)} jargon terms")
+    return list(matched.keys())
+
+
+async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
+    """Look up a list of concepts in the jargon store.
+
+    Args:
+        concepts: The concepts to look up
+        chat_id: Chat ID
+
+    Returns:
+        str: The formatted lookup results
+    """
+    if not concepts:
+        return ""
+
+    results = []
+    exact_matches = []  # collects every exactly matched concept
+    for concept in concepts:
+        concept = concept.strip()
+        if not concept:
+            continue
+        # Try an exact match first
+        jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
+        is_fuzzy_match = False
+        # If the exact match finds nothing, fall back to a fuzzy search
+        if not jargon_results:
+            jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
+            is_fuzzy_match = True
+        if jargon_results:
+            if is_fuzzy_match:
+                # Fuzzy match
+                output_parts = [f"No exact match for '{concept}'."]
+                for result in jargon_results:
+                    found_content = result.get("content", "").strip()
+                    meaning = result.get("meaning", "").strip()
+                    if found_content and meaning:
+                        output_parts.append(f"Found '{found_content}', meaning: {meaning}")
+                results.append("".join(output_parts))
+                logger.info(f"Fuzzy match in jargon store for: {concept}, {len(jargon_results)} results")
+            else:
+                # Exact match
+                output_parts = []
+                for result in jargon_results:
+                    meaning = result.get("meaning", "").strip()
+                    if meaning:
+                        output_parts.append(f"'{concept}' is jargon or internet shorthand, meaning: {meaning}")
+                if output_parts:
+                    results.append("".join(output_parts))
+                    exact_matches.append(concept)  # collected here, logged together below
+        else:
+            # Nothing found: log only, emit no placeholder text
+            logger.info(f"No match in jargon store for: {concept}")
+
+    # Log all exact matches in a single line
+    if exact_matches:
+        logger.info(f"Found jargon: {', '.join(exact_matches)}, {len(exact_matches)} results in total")
+
+    if results:
+        return "[Concept lookup results]\n" + "\n".join(results) + "\n"
+    return ""

View File

@@ -1,17 +1,16 @@
 import time
 import json
-import re
 import asyncio
 from typing import List, Dict, Any, Optional, Tuple, Set
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.plugin_system.apis import llm_api
-from src.common.database.database_model import ThinkingBack, Jargon
+from src.common.database.database_model import ThinkingBack
-from json_repair import repair_json
 from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
+from src.memory_system.memory_utils import parse_questions_json
 from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
-from src.jargon.jargon_utils import parse_chat_id_list, chat_id_list_contains
+from src.jargon.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon

 logger = get_logger("memory_retrieval")
@@ -101,11 +100,6 @@ def init_memory_retrieval_prompt():
     Prompt(
         """Your name is {bot_name}. It is now {time_now}.
 You are taking part in a chat and need to gather information to answer a question that helps you join the conversation.
-**Important limits**:
-- Keep your thinking brief and get straight to the point
-- Maximum query rounds: {max_iterations}; currently round {current_iteration}, {remaining_iterations} remaining
 The question to answer right now: {question}
 Information collected so far:
 {collected_info}
@@ -118,7 +112,7 @@ def init_memory_retrieval_prompt():
 - **If the information already collected is enough to answer the question and a clear answer can be found, call the found_answer tool to mark it as found**

 **Thinking**:
-- You may give brief thoughts on your query approach
+- You may give brief thoughts on your query approach; keep them short and get straight to the point
 - If the information is insufficient, you must state which tool you will use to query
 - If the information is sufficient, you must call the found_answer tool
 """,
@@ -152,169 +146,6 @@ def init_memory_retrieval_prompt():
     )

-
-def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
-    """Parse the ReAct agent's response.
-
-    Args:
-        response: The response returned by the LLM
-
-    Returns:
-        Dict[str, Any]: The parsed action info, or None if parsing failed
-        Format: {"thought": str, "actions": List[Dict[str, Any]]}
-        Each action has the format: {"action_type": str, "action_params": dict}
-    """
-    try:
-        # Try to extract JSON (it may be wrapped in a ```json code block)
-        json_pattern = r"```json\s*(.*?)\s*```"
-        matches = re.findall(json_pattern, response, re.DOTALL)
-        if matches:
-            json_str = matches[0]
-        else:
-            # Try to parse the whole response directly
-            json_str = response.strip()
-
-        # Repair any malformed JSON
-        repaired_json = repair_json(json_str)
-
-        # Parse the JSON
-        action_info = json.loads(repaired_json)
-        if not isinstance(action_info, dict):
-            logger.warning(f"Parsed JSON is not an object: {action_info}")
-            return None
-
-        # Ensure the actions field exists and is a list
-        if "actions" not in action_info:
-            logger.warning(f"Response is missing the actions field: {action_info}")
-            return None
-        if not isinstance(action_info["actions"], list):
-            logger.warning(f"The actions field is not an array: {action_info['actions']}")
-            return None
-
-        # Ensure actions is not empty
-        if len(action_info["actions"]) == 0:
-            logger.warning("The actions array is empty")
-            return None
-
-        return action_info
-    except Exception as e:
-        logger.error(f"Failed to parse ReAct response: {e}, response content: {response[:200]}...")
-        return None
-
-
-async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
-    """Look up a list of concepts in the jargon store.
-
-    Args:
-        concepts: The concepts to look up
-        chat_id: Chat ID
-
-    Returns:
-        str: The formatted lookup results
-    """
-    if not concepts:
-        return ""
-
-    from src.jargon.jargon_miner import search_jargon
-
-    results = []
-    exact_matches = []  # collects every exactly matched concept
-    for concept in concepts:
-        concept = concept.strip()
-        if not concept:
-            continue
-        # Try an exact match first
-        jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
-        is_fuzzy_match = False
-        # If the exact match finds nothing, fall back to a fuzzy search
-        if not jargon_results:
-            jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
-            is_fuzzy_match = True
-        if jargon_results:
-            if is_fuzzy_match:
-                # Fuzzy match
-                output_parts = [f"No exact match for '{concept}'."]
-                for result in jargon_results:
-                    found_content = result.get("content", "").strip()
-                    meaning = result.get("meaning", "").strip()
-                    if found_content and meaning:
-                        output_parts.append(f"Found '{found_content}', meaning: {meaning}")
-                results.append("".join(output_parts))
-                logger.info(f"Fuzzy match in jargon store for: {concept}, {len(jargon_results)} results")
-            else:
-                # Exact match
-                output_parts = []
-                for result in jargon_results:
-                    meaning = result.get("meaning", "").strip()
-                    if meaning:
-                        output_parts.append(f"'{concept}' is jargon or internet shorthand, meaning: {meaning}")
-                results.append("".join(output_parts) if len(output_parts) > 1 else output_parts[0])
-                exact_matches.append(concept)  # collected here, logged together below
-        else:
-            # Nothing found: log only, emit no placeholder text
-            logger.info(f"No match in jargon store for: {concept}")
-
-    # Log all exact matches in a single line
-    if exact_matches:
-        logger.info(f"Found jargon: {', '.join(exact_matches)}, {len(exact_matches)} results in total")
-
-    if results:
-        return "[Concept lookup results]\n" + "\n".join(results) + "\n"
-    return ""
-
-
-def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
-    """Match known jargon directly against the chat text and return the jargon terms that appear."""
-    # print(chat_text)
-    if not chat_text or not chat_text.strip():
-        return []
-
-    start_time = time.time()
-    query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
-    if global_config.jargon.all_global:
-        query = query.where(Jargon.is_global)
-    query = query.order_by(Jargon.count.desc())
-    query_time = time.time()
-
-    matched: Dict[str, None] = {}
-    for jargon in query:
-        content = (jargon.content or "").strip()
-        if not content:
-            continue
-        if not global_config.jargon.all_global and not jargon.is_global:
-            chat_id_list = parse_chat_id_list(jargon.chat_id)
-            if not chat_id_list_contains(chat_id_list, chat_id):
-                continue
-        pattern = re.escape(content)
-        if re.search(r"[\u4e00-\u9fff]", content):
-            search_pattern = pattern
-        else:
-            search_pattern = r"\b" + pattern + r"\b"
-        if re.search(search_pattern, chat_text, re.IGNORECASE):
-            matched[content] = None
-
-    # end_time = time.time()
-    logger.info(
-        # f"Memory-retrieval jargon matching: query took {(query_time - start_time):.3f}s, "
-        # f"matching took {(end_time - query_time):.3f}s, total {(end_time - start_time):.3f}s, "
-        f"Matched {len(matched)} jargon terms"
-    )
-    return list(matched.keys())
-

 def _log_conversation_messages(
@@ -336,7 +167,7 @@ def _log_conversation_messages(
     # If there is a head_prompt, add it first as the first message
     if head_prompt:
-        msg_info = "========================================\n[Message 1] Role: System Content type: text\n-----------------------"
+        msg_info = "========================================\n[Message 1] Role: System Content type: text\n-----------------------------"
         msg_info += f"\n{head_prompt}"
         log_lines.append(msg_info)
         start_idx = 2
@@ -363,7 +194,7 @@
             content_type = "unknown"

         # Build the log entry for a single message
-        msg_info = f"\n[Message {idx}] Role: {role_name} Content type: {content_type}\n========================================"
+        msg_info = f"\n========================================\n[Message {idx}] Role: {role_name} Content type: {content_type}\n-----------------------------"
         if full_content:
             msg_info += f"\n{full_content}"
@@ -513,7 +344,7 @@
             )

         if global_config.debug.show_memory_prompt:
-            logger.info(f"ReAct Agent iteration {iteration + 1}, final evaluation prompt: {evaluation_prompt}")
+            logger.info(f"ReAct Agent final evaluation prompt: {evaluation_prompt}")

         eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
             evaluation_prompt,
@@ -656,7 +487,7 @@
                 request_type="memory.react",
             )

-            logger.info(
+            logger.debug(
                 f"ReAct Agent iteration {iteration + 1}, model: {model_name}, tool calls: {len(tool_calls) if tool_calls else 0}, tool-call response: {response}"
             )
@@ -706,25 +537,19 @@
                 continue

            # Handle tool calls
-            tool_tasks = []
-            found_answer_from_tool = None  # detect whether there is a found_answer tool call
-            for i, tool_call in enumerate(tool_calls):
+            # First check for a found_answer tool call; if present, return immediately and skip the other tools
+            found_answer_from_tool = None
+            for tool_call in tool_calls:
                 tool_name = tool_call.func_name
                 tool_args = tool_call.args or {}
-                logger.info(
-                    f"ReAct Agent iteration {iteration + 1}, tool call {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
-                )
-
-                # Check whether this is a found_answer tool call
                 if tool_name == "found_answer":
                     found_answer_from_tool = tool_args.get("answer", "")
                     if found_answer_from_tool:
                         step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_from_tool}})
                         step["observations"] = ["Detected a found_answer tool call"]
                         thinking_steps.append(step)
-                        logger.info(f"ReAct Agent iteration {iteration + 1} found an answer to question {question} via the found_answer tool: {found_answer_from_tool}")
+                        logger.debug(f"ReAct Agent iteration {iteration + 1} found an answer to question {question} via the found_answer tool: {found_answer_from_tool}")

                         _log_conversation_messages(
                             conversation_messages,
@@ -733,7 +558,20 @@
                         )

                         return True, found_answer_from_tool, thinking_steps, False
+                    continue  # the found_answer tool needs no execution; skip it
+
+            # If there was no found_answer tool call (or it carried no answer), process the other tools
+            tool_tasks = []
+            for i, tool_call in enumerate(tool_calls):
+                tool_name = tool_call.func_name
+                tool_args = tool_call.args or {}
+                logger.debug(
+                    f"ReAct Agent iteration {iteration + 1}, tool call {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
+                )
+
+                # Skip found_answer tool calls; they were already handled above
+                if tool_name == "found_answer":
+                    continue

                # Ordinary tool call
                 tool = tool_registry.get_tool(tool_name)
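
Review note on the two hunks above: the handler now scans the tool calls twice, returning as soon as a found_answer call carries a non-empty answer and only then dispatching the remaining tools. A compact, self-contained sketch of that control flow (the ToolCall stand-in and split_tool_calls helper are hypothetical, not part of the codebase):

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

@dataclass
class ToolCall:
    func_name: str
    args: Dict[str, Any] = field(default_factory=dict)

def split_tool_calls(tool_calls: List[ToolCall]) -> Tuple[Optional[str], List[ToolCall]]:
    # Pass 1: if any found_answer call carries a non-empty answer, short-circuit.
    for tc in tool_calls:
        if tc.func_name == "found_answer" and tc.args.get("answer"):
            return tc.args["answer"], []
    # Pass 2: everything except found_answer is left to be executed normally.
    return None, [tc for tc in tool_calls if tc.func_name != "found_answer"]

answer, rest = split_tool_calls(
    [ToolCall("search_chat_history", {"q": "gg"}), ToolCall("found_answer", {"answer": "done"})]
)
print(answer, [t.func_name for t in rest])  # prints: done []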
@@ -781,7 +619,7 @@
                 if stripped_observation:
                     # Check whether the tool output contains new jargon; if so, append it to the tool result
                     if enable_jargon_detection:
-                        jargon_concepts = _match_jargon_from_text(stripped_observation, chat_id)
+                        jargon_concepts = match_jargon_from_text(stripped_observation, chat_id)
                         if jargon_concepts:
                             new_concepts = []
                             for concept in jargon_concepts:
@@ -790,7 +628,7 @@
                                     new_concepts.append(normalized_concept)
                                     seen_jargon_concepts.add(normalized_concept)
                             if new_concepts:
-                                jargon_info = await _retrieve_concepts_with_jargon(new_concepts, chat_id)
+                                jargon_info = await retrieve_concepts_with_jargon(new_concepts, chat_id)
                                 if jargon_info:
                                     # Append the jargon lookup results to the tool result
                                     observation_text += f"\n\n{jargon_info}"
@@ -828,8 +666,8 @@
     return False, "", thinking_steps, is_timeout


-def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str:
-    """Fetch the query history from a recent time window.
+def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) -> str:
+    """Fetch the query history from a recent time window (used to avoid duplicate queries).

     Args:
         chat_id: Chat ID
@@ -879,6 +717,49 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
     return ""


+def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) -> List[str]:
+    """Fetch records whose answers were found within a recent time window (for returning to the replyer).
+
+    Args:
+        chat_id: Chat ID
+        time_window_seconds: Time window, defaults to 10 minutes
+
+    Returns:
+        List[str]: Formatted answers, each of the form "Question: xxx\nAnswer: xxx"
+    """
+    try:
+        current_time = time.time()
+        start_time = current_time - time_window_seconds
+
+        # Query answered records within the time window, newest first
+        records = (
+            ThinkingBack.select()
+            .where(
+                (ThinkingBack.chat_id == chat_id)
+                & (ThinkingBack.update_time >= start_time)
+                & (ThinkingBack.found_answer == 1)
+                & (ThinkingBack.answer.is_null(False))
+                & (ThinkingBack.answer != "")
+            )
+            .order_by(ThinkingBack.update_time.desc())
+            .limit(3)  # return at most the 3 most recent records
+        )
+
+        if not records.exists():
+            return []
+
+        found_answers = []
+        for record in records:
+            if record.answer:
+                found_answers.append(f"Question: {record.question}\nAnswer: {record.answer}")
+        return found_answers
+    except Exception as e:
+        logger.error(f"Failed to fetch recently answered records: {e}")
+        return []
+
+
 def _store_thinking_back(
     chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
 ) -> None:
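
Review note: _get_recent_found_answers combines four filters (chat, time window, found flag, non-empty answer) with newest-first ordering and a limit of 3. A self-contained sketch of the same peewee query pattern against a throwaway in-memory SQLite database (ThinkingBackDemo is a hypothetical minimal model; the real ThinkingBack has more fields):

import time
from peewee import SqliteDatabase, Model, CharField, TextField, FloatField, IntegerField

db = SqliteDatabase(":memory:")

class ThinkingBackDemo(Model):
    chat_id = CharField()
    question = TextField()
    answer = TextField(null=True)
    found_answer = IntegerField(default=0)
    update_time = FloatField()

    class Meta:
        database = db

db.create_tables([ThinkingBackDemo])
now = time.time()
ThinkingBackDemo.create(chat_id="c1", question="q1", answer="a1", found_answer=1, update_time=now - 60)
ThinkingBackDemo.create(chat_id="c1", question="q2", answer="", found_answer=1, update_time=now - 60)       # filtered: empty answer
ThinkingBackDemo.create(chat_id="c1", question="q3", answer="a3", found_answer=1, update_time=now - 3600)   # filtered: outside window

records = (
    ThinkingBackDemo.select()
    .where(
        (ThinkingBackDemo.chat_id == "c1")
        & (ThinkingBackDemo.update_time >= now - 600.0)
        & (ThinkingBackDemo.found_answer == 1)
        & (ThinkingBackDemo.answer.is_null(False))
        & (ThinkingBackDemo.answer != "")
    )
    .order_by(ThinkingBackDemo.update_time.desc())
    .limit(3)
)
print([r.question for r in records])  # prints: ['q1']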
@@ -1016,8 +897,8 @@ async def build_memory_retrieval_prompt(
     bot_name = global_config.bot.nickname
     chat_id = chat_stream.stream_id

-    # Fetch recent query history (queries from the last hour)
-    recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0)
+    # Fetch recent query history (queries from the last 10 minutes, used to avoid duplicate queries)
+    recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=600.0)
     if not recent_query_history:
         recent_query_history = "No recent query records."
@@ -1047,7 +928,7 @@
         return ""

     # Parse the concept list and question list
-    _, questions = _parse_questions_json(response)
+    _, questions = parse_questions_json(response)

     if questions:
         logger.info(f"Parsed {len(questions)} questions: {questions}")
@@ -1056,7 +937,7 @@
     if enable_jargon_detection:
         # Use the matching logic to automatically detect jargon concepts in the chat
-        concepts = _match_jargon_from_text(message, chat_id)
+        concepts = match_jargon_from_text(message, chat_id)
         if concepts:
             logger.info(f"Jargon matching hit {len(concepts)} concepts: {concepts}")
         else:
@@ -1067,7 +948,7 @@
     # Run jargon lookups on the matched concepts as initial information
     initial_info = ""
     if enable_jargon_detection and concepts:
-        concept_info = await _retrieve_concepts_with_jargon(concepts, chat_id)
+        concept_info = await retrieve_concepts_with_jargon(concepts, chat_id)
         if concept_info:
             initial_info += concept_info
             logger.debug(f"Concept lookup finished, result: {concept_info}")
@@ -1107,67 +988,47 @@ async def build_memory_retrieval_prompt(
             elif result is not None:
                 question_results.append(result)

+        # Fetch cached records answered within the last 10 minutes
+        cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
+
+        # Merge current results with cached answers (dedup: if a currently queried question already exists in the cache, prefer the current result)
+        all_results = []
+
+        # Add the current query results first
+        current_questions = set()
+        for result in question_results:
+            # Extract the question (entries have the form "Question: xxx\nAnswer: xxx")
+            if result.startswith("Question: "):
+                question_end = result.find("\nAnswer: ")
+                if question_end != -1:
+                    current_questions.add(result[len("Question: "):question_end])
+            all_results.append(result)
+
+        # Add cached answers (excluding questions already present in the current query)
+        for cached_answer in cached_answers:
+            if cached_answer.startswith("Question: "):
+                question_end = cached_answer.find("\nAnswer: ")
+                if question_end != -1:
+                    cached_question = cached_answer[len("Question: "):question_end]
+                    if cached_question not in current_questions:
+                        all_results.append(cached_answer)
+
         end_time = time.time()
-        if question_results:
-            retrieved_memory = "\n\n".join(question_results)
-            logger.info(f"Memory retrieval succeeded in {(end_time - start_time):.3f}s with {len(question_results)} memories")
+        if all_results:
+            retrieved_memory = "\n\n".join(all_results)
+            current_count = len(question_results)
+            cached_count = len(all_results) - current_count
+            logger.info(
+                f"Memory retrieval succeeded in {(end_time - start_time):.3f}s, "
+                f"{current_count} memories from the current query, {cached_count} from cache, {len(all_results)} in total"
+            )
             return f"You recalled the following information:\n{retrieved_memory}\nIf it is relevant to your reply, you may draw on these recollections.\n"
         else:
-            logger.debug("No answers were found for any question")
+            logger.debug("No answers were found for any question, and there are no cached answers")
             return ""
     except Exception as e:
         logger.error(f"Exception during memory retrieval: {str(e)}")
         return ""
-
-
-def _parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
-    """Parse the question JSON, returning the concept list and the question list.
-
-    Args:
-        response: The response returned by the LLM
-
-    Returns:
-        Tuple[List[str], List[str]]: (concept list, question list)
-    """
-    try:
-        # Try to extract JSON (it may be wrapped in a ```json code block)
-        json_pattern = r"```json\s*(.*?)\s*```"
-        matches = re.findall(json_pattern, response, re.DOTALL)
-        if matches:
-            json_str = matches[0]
-        else:
-            # Try to parse the whole response directly
-            json_str = response.strip()
-
-        # Repair any malformed JSON
-        repaired_json = repair_json(json_str)
-
-        # Parse the JSON
-        parsed = json.loads(repaired_json)
-
-        # Only the new format is supported: an object with concepts and questions
-        if not isinstance(parsed, dict):
-            logger.warning(f"Parsed JSON is not an object: {parsed}")
-            return [], []
-
-        concepts_raw = parsed.get("concepts", [])
-        questions_raw = parsed.get("questions", [])
-
-        # Make sure both are lists
-        if not isinstance(concepts_raw, list):
-            concepts_raw = []
-        if not isinstance(questions_raw, list):
-            questions_raw = []
-
-        # Make sure every element is a string
-        concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
-        questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
-
-        return concepts, questions
-    except Exception as e:
-        logger.error(f"Failed to parse question JSON: {e}, response content: {response[:200]}...")
-        return [], []
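
Review note on the cache merge added to build_memory_retrieval_prompt: entries are plain strings of the form "Question: ...\nAnswer: ...", current results are kept wholesale, and a cached entry is appended only when its question was not answered in the current pass. A small self-contained sketch of that dedup rule (merge_results, question_of, and the sample data are hypothetical):

from typing import List

PREFIX = "Question: "
SEP = "\nAnswer: "

def question_of(entry: str) -> str:
    end = entry.find(SEP)
    return entry[len(PREFIX):end] if entry.startswith(PREFIX) and end != -1 else ""

def merge_results(current: List[str], cached: List[str]) -> List[str]:
    seen = {question_of(e) for e in current}
    # Cached answers are appended only if the same question was not answered
    # in the current pass, so fresh results take priority over stale ones.
    return current + [e for e in cached if question_of(e) not in seen]

current = ["Question: what does gg mean\nAnswer: good game"]
cached = [
    "Question: what does gg mean\nAnswer: a stale duplicate",
    "Question: what is afk\nAnswer: away from keyboard",
]
print(merge_results(current, cached))  # keeps the fresh gg answer, adds the afk one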

View File

@@ -8,7 +8,8 @@ import json
 import re
 from datetime import datetime
 from typing import Tuple
-from difflib import SequenceMatcher
+from typing import List
+from json_repair import repair_json

 from src.common.logger import get_logger
@@ -16,101 +17,56 @@ from src.common.logger import get_logger
 logger = get_logger("memory_utils")

-def parse_md_json(json_text: str) -> list[str]:
-    """Extract JSON objects and reasoning content from Markdown-formatted text"""
-    json_objects = []
-    reasoning_content = ""
-
-    # Use a regex to find JSON content wrapped in ```json fences
-    json_pattern = r"```json\s*(.*?)\s*```"
-    matches = re.findall(json_pattern, json_text, re.DOTALL)
-
-    # Take the content before the JSON as the reasoning text
-    if matches:
-        # Find the position of the first ```json
-        first_json_pos = json_text.find("```json")
-        if first_json_pos > 0:
-            reasoning_content = json_text[:first_json_pos].strip()
-            # Clean comment markers out of the reasoning content
-            reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
-            reasoning_content = reasoning_content.strip()
-
-    for match in matches:
-        try:
-            # Clean up possible comments and formatting issues
-            json_str = re.sub(r"//.*?\n", "\n", match)  # strip single-line comments
-            json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)  # strip multi-line comments
-            if json_str := json_str.strip():
-                json_obj = json.loads(json_str)
-                if isinstance(json_obj, dict):
-                    json_objects.append(json_obj)
-                elif isinstance(json_obj, list):
-                    for item in json_obj:
-                        if isinstance(item, dict):
-                            json_objects.append(item)
-        except Exception as e:
-            logger.warning(f"Failed to parse JSON block: {e}, block content: {match[:100]}...")
-            continue
-
-    return json_objects, reasoning_content
-
-
-def calculate_similarity(text1: str, text2: str) -> float:
-    """
-    Compute the similarity of two texts
+def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
+    """Parse the question JSON, returning the concept list and the question list.

     Args:
-        text1: The first text
-        text2: The second text
+        response: The response returned by the LLM

     Returns:
-        float: Similarity score (0-1)
+        Tuple[List[str], List[str]]: (concept list, question list)
     """
     try:
-        # Preprocess the texts
-        text1 = preprocess_text(text1)
-        text2 = preprocess_text(text2)
-
-        # Compute similarity with SequenceMatcher
-        similarity = SequenceMatcher(None, text1, text2).ratio()
-
-        # If one text contains the other, raise the similarity
-        if text1 in text2 or text2 in text1:
-            similarity = max(similarity, 0.8)
-
-        return similarity
+        # Try to extract JSON (it may be wrapped in a ```json code block)
+        json_pattern = r"```json\s*(.*?)\s*```"
+        matches = re.findall(json_pattern, response, re.DOTALL)
+        if matches:
+            json_str = matches[0]
+        else:
+            # Try to parse the whole response directly
+            json_str = response.strip()
+
+        # Repair any malformed JSON
+        repaired_json = repair_json(json_str)
+
+        # Parse the JSON
+        parsed = json.loads(repaired_json)
+
+        # Only the new format is supported: an object with concepts and questions
+        if not isinstance(parsed, dict):
+            logger.warning(f"Parsed JSON is not an object: {parsed}")
+            return [], []
+
+        concepts_raw = parsed.get("concepts", [])
+        questions_raw = parsed.get("questions", [])
+
+        # Make sure both are lists
+        if not isinstance(concepts_raw, list):
+            concepts_raw = []
+        if not isinstance(questions_raw, list):
+            questions_raw = []
+
+        # Make sure every element is a string
+        concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
+        questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
+
+        return concepts, questions
     except Exception as e:
-        logger.error(f"Error while computing similarity: {e}")
-        return 0.0
-
-
-def preprocess_text(text: str) -> str:
-    """
-    Preprocess text to improve matching accuracy
-
-    Args:
-        text: The raw text
-
-    Returns:
-        str: The preprocessed text
-    """
-    try:
-        # Lowercase
-        text = text.lower()
-        # Strip punctuation and special characters
-        text = re.sub(r"[^\w\s]", "", text)
-        # Collapse extra whitespace
-        text = re.sub(r"\s+", " ", text).strip()
-        return text
-    except Exception as e:
-        logger.error(f"Error while preprocessing text: {e}")
-        return text
+        logger.error(f"Failed to parse question JSON: {e}, response content: {response[:200]}...")
+        return [], []


 def parse_datetime_to_timestamp(value: str) -> float:
     """
@@ -140,29 +96,3 @@ def parse_datetime_to_timestamp(value: str) -> float:
         except Exception as e:
             last_err = e
     raise ValueError(f"Failed to parse time: {value} ({last_err})")
-
-
-def parse_time_range(time_range: str) -> Tuple[float, float]:
-    """
-    Parse a time-range string and return start and end timestamps
-
-    Args:
-        time_range: Time-range string in the form "YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
-
-    Returns:
-        Tuple[float, float]: (start timestamp, end timestamp)
-    """
-    if " - " not in time_range:
-        raise ValueError(f"Invalid time-range format, expected 'start - end': {time_range}")
-
-    parts = time_range.split(" - ", 1)
-    if len(parts) != 2:
-        raise ValueError(f"Invalid time-range format: {time_range}")
-
-    start_str = parts[0].strip()
-    end_str = parts[1].strip()
-
-    start_timestamp = parse_datetime_to_timestamp(start_str)
-    end_timestamp = parse_datetime_to_timestamp(end_str)
-
-    return start_timestamp, end_timestamp
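
Usage note for the relocated parse_questions_json: repair_json sits in the pipeline so that slightly malformed LLM output still parses. A sketch of the expected round trip (hypothetical LLM output; requires this repo plus the json-repair package):

from src.memory_system.memory_utils import parse_questions_json

# The trailing comma would make strict json.loads fail; repair_json is expected to fix it.
response = '''```json
{"concepts": ["gg", "afk"], "questions": ["What happened yesterday?",]}
```'''

concepts, questions = parse_questions_json(response)
print(concepts)   # expected: ['gg', 'afk']
print(questions)  # expected: ['What happened yesterday?']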