MaiBot/src/hippo_memorizer/memory_forget_task.py

363 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
记忆遗忘任务
每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆
"""
import time
import random
from typing import List
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.manager.async_task_manager import AsyncTask
logger = get_logger("memory_forget_task")
class MemoryForgetTask(AsyncTask):
"""记忆遗忘任务每5分钟执行一次"""
def __init__(self):
# 每5分钟执行一次300秒
super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)
async def run(self):
"""执行遗忘检查"""
try:
current_time = time.time()
# logger.info("[记忆遗忘] 开始遗忘检查...")
# 执行4个阶段的遗忘检查
# await self._forget_stage_1(current_time)
# await self._forget_stage_2(current_time)
# await self._forget_stage_3(current_time)
# await self._forget_stage_4(current_time)
# logger.info("[记忆遗忘] 遗忘检查完成")
except Exception as e:
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
async def _forget_stage_1(self, current_time: float):
"""
第一次遗忘检查:
搜集所有记忆还未被遗忘检查过forget_times=0且已经是30分钟之外的记忆
取count最高25%和最低25%删除然后标记被遗忘检查次数为1
"""
try:
# 30分钟 = 1800秒
time_threshold = current_time - 1800
# 查询符合条件的记忆forget_times=0 且 end_time < time_threshold
candidates = list(
ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高25%和最低25%
total_count = len(candidates)
delete_count = int(total_count * 0.25) # 25%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段1] 删除数量为0跳过")
return
# 选择要删除的记录处理count相同的情况随机选择
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重避免重复删除使用id去重
seen_ids = set()
unique_to_delete = []
for record in to_delete:
if record.id not in seen_ids:
seen_ids.add(record.id)
unique_to_delete.append(record)
to_delete = unique_to_delete
# 删除记录并更新forget_times
deleted_count = 0
for record in to_delete:
try:
record.delete_instance()
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")
# 更新剩余记录的forget_times为1
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
# 批量更新
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
async def _forget_stage_2(self, current_time: float):
"""
第二次遗忘检查:
搜集所有记忆遗忘检查为1且已经是8小时之外的记忆
取count最高7%和最低7%删除然后标记被遗忘检查次数为2
"""
try:
# 8小时 = 28800秒
time_threshold = current_time - 28800
# 查询符合条件的记忆forget_times=1 且 end_time < time_threshold
candidates = list(
ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高7%和最低7%
total_count = len(candidates)
delete_count = int(total_count * 0.07) # 7%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段2] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
try:
record.delete_instance()
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")
# 更新剩余记录的forget_times为2
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
async def _forget_stage_3(self, current_time: float):
"""
第三次遗忘检查:
搜集所有记忆遗忘检查为2且已经是48小时之外的记忆
取count最高5%和最低5%删除然后标记被遗忘检查次数为3
"""
try:
# 48小时 = 172800秒
time_threshold = current_time - 172800
# 查询符合条件的记忆forget_times=2 且 end_time < time_threshold
candidates = list(
ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高5%和最低5%
total_count = len(candidates)
delete_count = int(total_count * 0.05) # 5%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段3] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
try:
record.delete_instance()
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")
# 更新剩余记录的forget_times为3
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
async def _forget_stage_4(self, current_time: float):
"""
第四次遗忘检查:
搜集所有记忆遗忘检查为3且已经是7天之外的记忆
取count最高2%和最低2%删除然后标记被遗忘检查次数为4
"""
try:
# 7天 = 604800秒
time_threshold = current_time - 604800
# 查询符合条件的记忆forget_times=3 且 end_time < time_threshold
candidates = list(
ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高2%和最低2%
total_count = len(candidates)
delete_count = int(total_count * 0.02) # 2%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段4] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
try:
record.delete_instance()
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")
# 更新剩余记录的forget_times为4
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
def _handle_same_count_random(
self, candidates: List[ChatHistory], delete_count: int, mode: str
) -> List[ChatHistory]:
"""
处理count相同的情况随机选择要删除的记录
Args:
candidates: 候选记录列表已按count排序
delete_count: 要删除的数量
mode: "high" 表示选择最高count的记录"low" 表示选择最低count的记录
Returns:
要删除的记录列表
"""
if not candidates or delete_count == 0:
return []
to_delete = []
if mode == "high":
# 从最高count开始选择
start_idx = 0
while start_idx < len(candidates) and len(to_delete) < delete_count:
# 找到所有count相同的记录
current_count = candidates[start_idx].count
same_count_records = []
idx = start_idx
while idx < len(candidates) and candidates[idx].count == current_count:
same_count_records.append(candidates[idx])
idx += 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete)
if len(same_count_records) <= needed:
to_delete.extend(same_count_records)
else:
# 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx
else: # mode == "low"
# 从最低count开始选择
start_idx = len(candidates) - 1
while start_idx >= 0 and len(to_delete) < delete_count:
# 找到所有count相同的记录
current_count = candidates[start_idx].count
same_count_records = []
idx = start_idx
while idx >= 0 and candidates[idx].count == current_count:
same_count_records.append(candidates[idx])
idx -= 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete)
if len(same_count_records) <= needed:
to_delete.extend(same_count_records)
else:
# 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx
return to_delete