MaiBot/src/chat/memory_system/memory_management_task.py

172 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import asyncio
import random
from typing import List
from src.manager.async_task_manager import AsyncTask
from src.chat.memory_system.Memory_chest import global_memory_chest
from src.common.logger import get_logger
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.config.config import global_config
logger = get_logger("memory_management")
class MemoryManagementTask(AsyncTask):
"""记忆管理定时任务
根据Memory_chest中的记忆数量与MAX_MEMORY_NUMBER的比例来决定执行频率
- 小于50%每600秒执行一次
- 大于等于50%每300秒执行一次
每次执行时随机选择一个title执行choose_merge_target和merge_memory
然后删除原始记忆
"""
def __init__(self):
super().__init__(
task_name="Memory Management Task",
wait_before_start=10, # 启动后等待10秒再开始
run_interval=300 # 默认300秒间隔会根据记忆数量动态调整
)
self.max_memory_number = global_config.memory.max_memory_number
async def start_task(self, abort_flag: asyncio.Event):
"""重写start_task方法支持动态调整执行间隔"""
if self.wait_before_start > 0:
# 等待指定时间后开始任务
await asyncio.sleep(self.wait_before_start)
while not abort_flag.is_set():
await self.run()
# 动态调整执行间隔
current_interval = self._calculate_interval()
logger.info(f"[记忆管理] 下次执行间隔: {current_interval}")
if current_interval > 0:
await asyncio.sleep(current_interval)
else:
break
def _calculate_interval(self) -> int:
"""根据当前记忆数量计算执行间隔"""
try:
current_count = self._get_memory_count()
percentage = current_count / self.max_memory_number
if percentage < 0.5:
# 小于50%每600秒执行一次
return 3600
elif percentage < 0.7:
# 大于等于50%每300秒执行一次
return 1800
elif percentage < 0.9:
# 大于等于70%每120秒执行一次
return 300
elif percentage < 1.2:
return 30
else:
return 10
except Exception as e:
logger.error(f"[记忆管理] 计算执行间隔时出错: {e}")
return 300 # 默认300秒
def _get_memory_count(self) -> int:
"""获取当前记忆数量"""
try:
count = MemoryChestModel.select().count()
logger.debug(f"[记忆管理] 当前记忆数量: {count}")
return count
except Exception as e:
logger.error(f"[记忆管理] 获取记忆数量时出错: {e}")
return 0
async def run(self):
"""执行记忆管理任务"""
try:
logger.info("[记忆管理] 开始执行记忆管理任务")
# 获取当前记忆数量
current_count = self._get_memory_count()
percentage = current_count / self.max_memory_number
logger.info(f"[记忆管理] 当前记忆数量: {current_count}/{self.max_memory_number} ({percentage:.1%})")
# 如果记忆数量为0跳过执行
if current_count < 10:
logger.info("[记忆管理] 没有太多记忆,跳过执行")
return
# 随机选择一个记忆标题
selected_title = self._get_random_memory_title()
if not selected_title:
logger.warning("[记忆管理] 无法获取随机记忆标题,跳过执行")
return
logger.info(f"[记忆管理] 随机选择的记忆标题: {selected_title}")
# 执行choose_merge_target获取相关记忆内容
related_contents_titles = await global_memory_chest.choose_merge_target(selected_title)
if not related_contents_titles:
logger.warning("[记忆管理] 未找到相关记忆内容,跳过合并")
return
logger.info(f"[记忆管理] 找到 {len(related_contents_titles)} 条相关记忆")
# 执行merge_memory合并记忆
merged_title, merged_content = await global_memory_chest.merge_memory(related_contents_titles)
if not merged_title or not merged_content:
logger.warning("[记忆管理] 记忆合并失败,跳过删除")
return
logger.info(f"[记忆管理] 记忆合并成功,新标题: {merged_title}")
# 删除原始记忆(包括选中的标题和相关的记忆)
deleted_count = self._delete_original_memories(related_contents_titles)
logger.info(f"[记忆管理] 已删除 {deleted_count} 条原始记忆")
logger.info("[记忆管理] 记忆管理任务完成")
except Exception as e:
logger.error(f"[记忆管理] 执行记忆管理任务时发生错误: {e}", exc_info=True)
def _get_random_memory_title(self) -> str:
"""随机获取一个记忆标题"""
try:
# 获取所有记忆标题
all_titles = global_memory_chest.get_all_titles()
if not all_titles:
return ""
# 随机选择一个标题
selected_title = random.choice(all_titles)
return selected_title
except Exception as e:
logger.error(f"[记忆管理] 获取随机记忆标题时发生错误: {e}")
return ""
def _delete_original_memories(self, related_contents: List[str]) -> int:
"""删除原始记忆"""
try:
deleted_count = 0
# 删除相关记忆(通过内容匹配)
for content in related_contents:
try:
# 通过内容查找并删除对应的记忆
memories_to_delete = MemoryChestModel.select().where(MemoryChestModel.content == content)
for memory in memories_to_delete:
MemoryChestModel.delete().where(MemoryChestModel.id == memory.id).execute()
deleted_count += 1
logger.debug(f"[记忆管理] 删除相关记忆: {memory.title}")
except Exception as e:
logger.error(f"[记忆管理] 删除相关记忆时出错: {e}")
continue
return deleted_count
except Exception as e:
logger.error(f"[记忆管理] 删除原始记忆时发生错误: {e}")
return 0