fix: make memory retrieval more precise

pull/1182/head
SengokuCola 2025-08-14 00:02:50 +08:00
parent ef7a3aee23
commit 3bf476c610
4 changed files with 177 additions and 31 deletions

@@ -1,15 +1,17 @@
import json
from json_repair import repair_json
from typing import List, Dict
from typing import List, Tuple
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.utils.utils import parse_keywords_string
from src.chat.utils.chat_message_builder import build_readable_messages
import random
logger = get_logger("memory_activator")
@@ -40,20 +42,20 @@ def get_keywords_from_json(json_str) -> List:
def init_prompt():
# --- Group Chat Prompt ---
memory_activator_prompt = """
You are a memory analyzer; you need to recall memories based on the following information.
Below is a chat log; based on it, summarize a few keywords to serve as triggers for memory recall.
You need to pick suitable memory IDs based on the following information.
Below is a chat log; based on it and the memories listed below, pick the IDs of the memories relevant to the group chat.
Chat log:
{obs_info_text}
The message you want to reply to:
{target_message}
Previously used keywords (avoid extracting these again):
{cached_keywords}
Memories:
{memory_info}
Output a JSON object containing the following fields:
{{
"keywords": ["keyword1", "keyword2", "keyword3", ...]
"memory_ids": "ID of memory 1,ID of memory 2,ID of memory 3,..."
}}
Do not output anything else; output only the JSON.
"""
@@ -67,9 +69,14 @@ class MemoryActivator:
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
# LLM model used for memory selection
self.memory_selection_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.selection",
)
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
"""
Activate memories
"""
@@ -83,24 +90,172 @@ class MemoryActivator:
keywords = parse_keywords_string(msg.get("key_words", ""))
if keywords:
if len(keywords_list) < 30:
# Keep at most 30 keywords
keywords_list.update(keywords)
print(keywords_list)
logger.debug(f"Extracted keywords: {keywords_list}")
else:
break
if not keywords_list:
logger.debug("No keywords extracted; returning an empty memory list")
return []
# Fetch related memories from the hippocampus
related_memory = await hippocampus_manager.get_memory_from_topic(
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
)
logger.info(f"Current memory keywords: {keywords_list} ")
logger.info(f"Current memory keywords: {keywords_list}")
logger.info(f"Fetched memories: {related_memory}")
if not related_memory:
logger.debug("The hippocampus returned no related memories")
return []
used_ids = set()
candidate_memories = []
# Assign a random ID to each memory and filter for relevant ones
for memory in related_memory:
keyword, content = memory
found = False
for kw in keywords_list:
if kw in content:
found = True
break
if found:
# Randomly assign a unique two-digit ID
while True:
random_id = "{:02d}".format(random.randint(0, 99))
if random_id not in used_ids:
used_ids.add(random_id)
break
candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
if not candidate_memories:
logger.info("No relevant candidate memories found")
return []
# If there are only a few memories, return them directly
if len(candidate_memories) <= 2:
logger.info(f"Few candidate memories ({len(candidate_memories)}); returning them directly")
# Convert to (keyword, content) format
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
# Use the LLM to select suitable memories
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
return selected_memories
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
"""
Select suitable memories with the LLM
Args:
target_message: the target message
chat_history_prompt: the chat history
candidate_memories: list of candidate memories; each contains memory_id, keyword, and content
Returns:
List[Tuple[str, str]]: the selected memories in (keyword, content) format
"""
try:
# Build the chat-history string
obs_info_text = build_readable_messages(
chat_history_prompt,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
)
# Build the memory-info string
memory_lines = []
for memory in candidate_memories:
memory_id = memory["memory_id"]
keyword = memory["keyword"]
content = memory["content"]
# Convert the content list to a string
if isinstance(content, list):
content_str = " | ".join(str(item) for item in content)
else:
content_str = str(content)
memory_lines.append(f"Memory ID {memory_id}: [keyword: {keyword}] {content_str}")
memory_info = "\n".join(memory_lines)
# Fetch and format the prompt
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
formatted_prompt = prompt_template.format(
obs_info_text=obs_info_text,
target_message=target_message,
memory_info=memory_info
)
# Call the LLM
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
formatted_prompt,
temperature=0.3,
max_tokens=150
)
if global_config.debug.show_prompt:
logger.info(f"Memory-selection prompt: {formatted_prompt}")
logger.info(f"LLM memory-selection response: {response}")
else:
logger.debug(f"Memory-selection prompt: {formatted_prompt}")
logger.debug(f"LLM memory-selection response: {response}")
# Parse the response to get the selected memory IDs
try:
fixed_json = repair_json(response)
# Parse into a Python object
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
# Extract the memory_ids field
memory_ids_str = result.get("memory_ids", "")
# Parse the comma-separated IDs
if memory_ids_str:
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
# Filter out empty strings and invalid IDs
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
selected_memory_ids = valid_memory_ids
else:
selected_memory_ids = []
except Exception as e:
logger.error(f"Failed to parse the memory-selection response: {e}", exc_info=True)
selected_memory_ids = []
# Filter memories by ID
selected_memories = []
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
for memory_id in selected_memory_ids:
if memory_id in memory_id_to_memory:
selected_memories.append(memory_id_to_memory[memory_id])
logger.info(f"Memory IDs selected by the LLM: {selected_memory_ids}")
logger.info(f"Number of memories finally selected: {len(selected_memories)}")
# Convert to (keyword, content) format
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
except Exception as e:
logger.error(f"Error while selecting memories with the LLM: {e}", exc_info=True)
# On error, fall back to the first 3 candidate memories, converted to (keyword, content) format
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
return related_memory
init_prompt()
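The new prompt asks the model for a JSON object whose memory_ids field is a comma-separated string of IDs, and _select_memories_with_llm repairs and splits that string. As a quick illustration of that contract, here is a minimal standalone sketch; the raw_response value is a made-up example, not output from the real model.

import json
from json_repair import repair_json

# Hypothetical, slightly malformed model reply following the schema requested by the prompt
raw_response = '{"memory_ids": "03, 17"'  # missing closing brace on purpose

fixed_json = repair_json(raw_response)  # repairs the string into valid JSON
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json

# Same comma-split used by the new parsing code above
memory_ids = [mid.strip() for mid in str(result.get("memory_ids", "")).split(",") if mid.strip()]
print(memory_ids)  # ['03', '17']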

@@ -356,15 +356,6 @@ class DefaultReplyer:
Returns:
str: the memory-info string
"""
chat_talking_prompt_short = build_readable_messages(
chat_history,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
)
if not global_config.memory.enable_memory:
return ""

@@ -399,7 +399,7 @@ class MessageReceiveConfig(ConfigBase):
class ExpressionConfig(ConfigBase):
"""Expression configuration class"""
expression_learning: list[list] = field(default_factory=lambda: [])
learning_list: list[list] = field(default_factory=lambda: [])
"""
List of expression-learning configurations; supports per-chat-stream configuration.
Format: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
@@ -469,7 +469,7 @@ class ExpressionConfig(ConfigBase):
Returns:
tuple: (use expression, learn expression, learning interval)
"""
if not self.expression_learning:
if not self.learning_list:
# If nothing is configured, use the defaults: expression enabled, learning enabled, 300-second interval
return True, True, 300
@@ -497,7 +497,7 @@ class ExpressionConfig(ConfigBase):
Returns:
tuple: (use expression, learn expression, learning interval); returns None if not configured
"""
for config_item in self.expression_learning:
for config_item in self.learning_list:
if not config_item or len(config_item) < 4:
continue
@@ -534,7 +534,7 @@ class ExpressionConfig(ConfigBase):
Returns:
tuple: (use expression, learn expression, learning interval); returns None if not configured
"""
for config_item in self.expression_learning:
for config_item in self.learning_list:
if not config_item or len(config_item) < 4:
continue
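For readers following the rename from expression_learning to learning_list, here is a minimal standalone sketch of how a per-chat-stream lookup over such a list could resolve a stream's flags: exact chat_stream_id match first, then the global "" entry, then the documented defaults. The helper name resolve_expression_flags is hypothetical, and returning the raw learning intensity is an assumption; how intensity maps to the 300-second default interval is not shown in this hunk.

from typing import List, Optional, Tuple

# Example entries shaped like the new learning_list config (values taken from the template below)
LEARNING_LIST: List[list] = [
    ["", "enable", "enable", "1.0"],                  # global entry
    ["qq:1919810:group", "enable", "enable", "1.5"],  # entry for one specific group chat
]

def resolve_expression_flags(chat_stream_id: str) -> Tuple[bool, bool, float]:
    """Hypothetical helper: returns (use_expression, enable_learning, learning_intensity)."""
    def find(stream_id: str) -> Optional[list]:
        for item in LEARNING_LIST:
            if item and len(item) >= 4 and item[0] == stream_id:  # same length guard as the config code
                return item
        return None

    item = find(chat_stream_id) or find("")  # exact match first, then the global "" entry
    if item is None:
        return True, True, 1.0  # defaults when nothing is configured
    return item[1] == "enable", item[2] == "enable", float(item[3])

print(resolve_expression_flags("qq:1919810:group"))  # exact match: (True, True, 1.5)
print(resolve_expression_flags("qq:42:private"))     # falls back to the global entry: (True, True, 1.0)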

@@ -1,5 +1,5 @@
[inner]
version = "6.4.0"
version = "6.4.2"
#---- The following is for developers; if you have only deployed 麦麦 (MaiMai), you do not need to read it ----
# If you want to modify the config file, increment the value of version
@@ -34,10 +34,10 @@ compress_identity = true # whether to compress identity; compression trims the identity information
[expression]
# Expression-learning configuration
expression_learning = [ # list of expression-learning configurations; supports per-chat-stream configuration
["", "enable", "enable", 1.0], # global config: use expressions, enable learning, learning intensity 1.0
["qq:1919810:group", "enable", "enable", 1.5], # specific group chat: use expressions, enable learning, learning intensity 1.5
["qq:114514:private", "enable", "disable", 0.5], # specific private chat: use expressions, disable learning, learning intensity 0.5
learning_list = [ # list of expression-learning configurations; supports per-chat-stream configuration
["", "enable", "enable", "1.0"], # global config: use expressions, enable learning, learning intensity 1.0
["qq:1919810:group", "enable", "enable", "1.5"], # specific group chat: use expressions, enable learning, learning intensity 1.5
["qq:114514:private", "enable", "disable", "0.5"], # specific private chat: use expressions, disable learning, learning intensity 0.5
# Format notes:
# 1st element: chat_stream_id; an empty string means the global configuration
# 2nd element: whether to use learned expressions ("enable"/"disable")
# 3rd element: whether to learn expressions ("enable"/"disable")
# 4th element: learning intensity