mirror of https://github.com/Mai-with-u/MaiBot.git
fix:时记忆提取更精确
parent
ef7a3aee23
commit
3bf476c610
|
|
@ -1,15 +1,17 @@
|
|||
import json
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import List, Dict
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
import random
|
||||
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
|
|
@ -40,20 +42,20 @@ def get_keywords_from_json(json_str) -> List:
|
|||
def init_prompt():
|
||||
# --- Group Chat Prompt ---
|
||||
memory_activator_prompt = """
|
||||
你是一个记忆分析器,你需要根据以下信息来进行回忆
|
||||
以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆回忆的触发词
|
||||
你需要根据以下信息来挑选合适的记忆编号
|
||||
以下是一段聊天记录,请根据这些信息,和下方的记忆,挑选和群聊内容有关的记忆编号
|
||||
|
||||
聊天记录:
|
||||
{obs_info_text}
|
||||
你想要回复的消息:
|
||||
{target_message}
|
||||
|
||||
历史关键词(请避免重复提取这些关键词):
|
||||
{cached_keywords}
|
||||
记忆:
|
||||
{memory_info}
|
||||
|
||||
请输出一个json格式,包含以下字段:
|
||||
{{
|
||||
"keywords": ["关键词1", "关键词2", "关键词3",......]
|
||||
"memory_ids": "记忆1编号,记忆2编号,记忆3编号,......"
|
||||
}}
|
||||
不要输出其他多余内容,只输出json格式就好
|
||||
"""
|
||||
|
|
@ -67,9 +69,14 @@ class MemoryActivator:
|
|||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.activator",
|
||||
)
|
||||
# 用于记忆选择的 LLM 模型
|
||||
self.memory_selection_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.selection",
|
||||
)
|
||||
|
||||
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
激活记忆
|
||||
"""
|
||||
|
|
@ -83,24 +90,172 @@ class MemoryActivator:
|
|||
keywords = parse_keywords_string(msg.get("key_words", ""))
|
||||
if keywords:
|
||||
if len(keywords_list) < 30:
|
||||
# 最多容纳30个关键词
|
||||
# 最多容纳30个关键词
|
||||
keywords_list.update(keywords)
|
||||
print(keywords_list)
|
||||
logger.debug(f"提取关键词: {keywords_list}")
|
||||
else:
|
||||
break
|
||||
|
||||
if not keywords_list:
|
||||
logger.debug("没有提取到关键词,返回空记忆列表")
|
||||
return []
|
||||
|
||||
# 从海马体获取相关记忆
|
||||
related_memory = await hippocampus_manager.get_memory_from_topic(
|
||||
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"当前记忆关键词: {keywords_list} ")
|
||||
logger.info(f"当前记忆关键词: {keywords_list}")
|
||||
logger.info(f"获取到的记忆: {related_memory}")
|
||||
|
||||
if not related_memory:
|
||||
logger.debug("海马体没有返回相关记忆")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
used_ids = set()
|
||||
candidate_memories = []
|
||||
|
||||
# 为每个记忆分配随机ID并过滤相关记忆
|
||||
for memory in related_memory:
|
||||
keyword, content = memory
|
||||
found = False
|
||||
for kw in keywords_list:
|
||||
if kw in content:
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
# 随机分配一个不重复的2位数id
|
||||
while True:
|
||||
random_id = "{:02d}".format(random.randint(0, 99))
|
||||
if random_id not in used_ids:
|
||||
used_ids.add(random_id)
|
||||
break
|
||||
candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
|
||||
|
||||
if not candidate_memories:
|
||||
logger.info("没有找到相关的候选记忆")
|
||||
return []
|
||||
|
||||
# 如果只有少量记忆,直接返回
|
||||
if len(candidate_memories) <= 2:
|
||||
logger.info(f"候选记忆较少({len(candidate_memories)}个),直接返回")
|
||||
# 转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
|
||||
|
||||
# 使用 LLM 选择合适的记忆
|
||||
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
|
||||
|
||||
return selected_memories
|
||||
|
||||
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
使用 LLM 选择合适的记忆
|
||||
|
||||
Args:
|
||||
target_message: 目标消息
|
||||
chat_history_prompt: 聊天历史
|
||||
candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content)
|
||||
"""
|
||||
try:
|
||||
# 构建聊天历史字符串
|
||||
obs_info_text = build_readable_messages(
|
||||
chat_history_prompt,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
|
||||
# 构建记忆信息字符串
|
||||
memory_lines = []
|
||||
for memory in candidate_memories:
|
||||
memory_id = memory["memory_id"]
|
||||
keyword = memory["keyword"]
|
||||
content = memory["content"]
|
||||
|
||||
# 将 content 列表转换为字符串
|
||||
if isinstance(content, list):
|
||||
content_str = " | ".join(str(item) for item in content)
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
|
||||
|
||||
memory_info = "\n".join(memory_lines)
|
||||
|
||||
# 获取并格式化 prompt
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
|
||||
formatted_prompt = prompt_template.format(
|
||||
obs_info_text=obs_info_text,
|
||||
target_message=target_message,
|
||||
memory_info=memory_info
|
||||
)
|
||||
|
||||
|
||||
|
||||
# 调用 LLM
|
||||
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
|
||||
formatted_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=150
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆选择 prompt: {formatted_prompt}")
|
||||
logger.info(f"LLM 记忆选择响应: {response}")
|
||||
else:
|
||||
logger.debug(f"记忆选择 prompt: {formatted_prompt}")
|
||||
logger.debug(f"LLM 记忆选择响应: {response}")
|
||||
|
||||
# 解析响应获取选择的记忆编号
|
||||
try:
|
||||
fixed_json = repair_json(response)
|
||||
|
||||
# 解析为 Python 对象
|
||||
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||
|
||||
# 提取 memory_ids 字段
|
||||
memory_ids_str = result.get("memory_ids", "")
|
||||
|
||||
# 解析逗号分隔的编号
|
||||
if memory_ids_str:
|
||||
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
|
||||
# 过滤掉空字符串和无效编号
|
||||
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
|
||||
selected_memory_ids = valid_memory_ids
|
||||
else:
|
||||
selected_memory_ids = []
|
||||
except Exception as e:
|
||||
logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
|
||||
selected_memory_ids = []
|
||||
|
||||
# 根据编号筛选记忆
|
||||
selected_memories = []
|
||||
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
|
||||
|
||||
for memory_id in selected_memory_ids:
|
||||
if memory_id in memory_id_to_memory:
|
||||
selected_memories.append(memory_id_to_memory[memory_id])
|
||||
|
||||
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
|
||||
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
|
||||
|
||||
# 转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
|
||||
# 出错时返回前3个候选记忆作为备选,转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
|
||||
|
||||
return related_memory
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
|
|
|||
|
|
@ -356,15 +356,6 @@ class DefaultReplyer:
|
|||
Returns:
|
||||
str: 记忆信息字符串
|
||||
"""
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
chat_history,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
|
||||
if not global_config.memory.enable_memory:
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -399,7 +399,7 @@ class MessageReceiveConfig(ConfigBase):
|
|||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
expression_learning: list[list] = field(default_factory=lambda: [])
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
|
||||
|
|
@ -469,7 +469,7 @@ class ExpressionConfig(ConfigBase):
|
|||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
||||
"""
|
||||
if not self.expression_learning:
|
||||
if not self.learning_list:
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
||||
return True, True, 300
|
||||
|
||||
|
|
@ -497,7 +497,7 @@ class ExpressionConfig(ConfigBase):
|
|||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.expression_learning:
|
||||
for config_item in self.learning_list:
|
||||
if not config_item or len(config_item) < 4:
|
||||
continue
|
||||
|
||||
|
|
@ -534,7 +534,7 @@ class ExpressionConfig(ConfigBase):
|
|||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.expression_learning:
|
||||
for config_item in self.learning_list:
|
||||
if not config_item or len(config_item) < 4:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[inner]
|
||||
version = "6.4.0"
|
||||
version = "6.4.2"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
|
|
@ -34,10 +34,10 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
|
|||
|
||||
[expression]
|
||||
# 表达学习配置
|
||||
expression_learning = [ # 表达学习配置列表,支持按聊天流配置
|
||||
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:group", "enable", "enable", 1.5], # 特定群聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
learning_list = [ # 表达学习配置列表,支持按聊天流配置
|
||||
["", "enable", "enable", "1.0"], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:group", "enable", "enable", "1.5"], # 特定群聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", "0.5"], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
# 格式说明:
|
||||
# 第一位: chat_stream_id,空字符串表示全局配置
|
||||
# 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||
|
|
|
|||
Loading…
Reference in New Issue