mirror of https://github.com/Mai-with-u/MaiBot.git
363 lines
14 KiB
Python
363 lines
14 KiB
Python
"""
|
||
记忆遗忘任务
|
||
每5分钟进行一次遗忘检查,根据不同的遗忘阶段删除记忆
|
||
"""
|
||
|
||
import time
|
||
import random
|
||
from typing import List
|
||
|
||
from src.common.logger import get_logger
|
||
from src.common.database.database_model import ChatHistory
|
||
from src.manager.async_task_manager import AsyncTask
|
||
|
||
logger = get_logger("memory_forget_task")
|
||
|
||
|
||
class MemoryForgetTask(AsyncTask):
|
||
"""记忆遗忘任务,每5分钟执行一次"""
|
||
|
||
def __init__(self):
|
||
# 每5分钟执行一次(300秒)
|
||
super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)
|
||
|
||
async def run(self):
|
||
"""执行遗忘检查"""
|
||
try:
|
||
current_time = time.time()
|
||
# logger.info("[记忆遗忘] 开始遗忘检查...")
|
||
|
||
# 执行4个阶段的遗忘检查
|
||
# await self._forget_stage_1(current_time)
|
||
# await self._forget_stage_2(current_time)
|
||
# await self._forget_stage_3(current_time)
|
||
# await self._forget_stage_4(current_time)
|
||
|
||
# logger.info("[记忆遗忘] 遗忘检查完成")
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
|
||
|
||
async def _forget_stage_1(self, current_time: float):
|
||
"""
|
||
第一次遗忘检查:
|
||
搜集所有:记忆还未被遗忘检查过(forget_times=0),且已经是30分钟之外的记忆
|
||
取count最高25%和最低25%,删除,然后标记被遗忘检查次数为1
|
||
"""
|
||
try:
|
||
# 30分钟 = 1800秒
|
||
time_threshold = current_time - 1800
|
||
|
||
# 查询符合条件的记忆:forget_times=0 且 end_time < time_threshold
|
||
candidates = list(
|
||
ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
|
||
)
|
||
|
||
if not candidates:
|
||
logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
|
||
return
|
||
|
||
logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")
|
||
|
||
# 按count排序
|
||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||
|
||
# 计算要删除的数量(最高25%和最低25%)
|
||
total_count = len(candidates)
|
||
delete_count = int(total_count * 0.25) # 25%
|
||
|
||
if delete_count == 0:
|
||
logger.debug("[记忆遗忘-阶段1] 删除数量为0,跳过")
|
||
return
|
||
|
||
# 选择要删除的记录(处理count相同的情况:随机选择)
|
||
to_delete = []
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||
|
||
# 去重(避免重复删除),使用id去重
|
||
seen_ids = set()
|
||
unique_to_delete = []
|
||
for record in to_delete:
|
||
if record.id not in seen_ids:
|
||
seen_ids.add(record.id)
|
||
unique_to_delete.append(record)
|
||
to_delete = unique_to_delete
|
||
|
||
# 删除记录并更新forget_times
|
||
deleted_count = 0
|
||
for record in to_delete:
|
||
try:
|
||
record.delete_instance()
|
||
deleted_count += 1
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")
|
||
|
||
# 更新剩余记录的forget_times为1
|
||
to_delete_ids = {r.id for r in to_delete}
|
||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||
if remaining:
|
||
# 批量更新
|
||
ids_to_update = [r.id for r in remaining]
|
||
ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||
|
||
logger.info(
|
||
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
|
||
|
||
async def _forget_stage_2(self, current_time: float):
|
||
"""
|
||
第二次遗忘检查:
|
||
搜集所有:记忆遗忘检查为1,且已经是8小时之外的记忆
|
||
取count最高7%和最低7%,删除,然后标记被遗忘检查次数为2
|
||
"""
|
||
try:
|
||
# 8小时 = 28800秒
|
||
time_threshold = current_time - 28800
|
||
|
||
# 查询符合条件的记忆:forget_times=1 且 end_time < time_threshold
|
||
candidates = list(
|
||
ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
|
||
)
|
||
|
||
if not candidates:
|
||
logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
|
||
return
|
||
|
||
logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")
|
||
|
||
# 按count排序
|
||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||
|
||
# 计算要删除的数量(最高7%和最低7%)
|
||
total_count = len(candidates)
|
||
delete_count = int(total_count * 0.07) # 7%
|
||
|
||
if delete_count == 0:
|
||
logger.debug("[记忆遗忘-阶段2] 删除数量为0,跳过")
|
||
return
|
||
|
||
# 选择要删除的记录
|
||
to_delete = []
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||
|
||
# 去重
|
||
to_delete = list(set(to_delete))
|
||
|
||
# 删除记录
|
||
deleted_count = 0
|
||
for record in to_delete:
|
||
try:
|
||
record.delete_instance()
|
||
deleted_count += 1
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")
|
||
|
||
# 更新剩余记录的forget_times为2
|
||
to_delete_ids = {r.id for r in to_delete}
|
||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||
if remaining:
|
||
ids_to_update = [r.id for r in remaining]
|
||
ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||
|
||
logger.info(
|
||
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
|
||
|
||
async def _forget_stage_3(self, current_time: float):
|
||
"""
|
||
第三次遗忘检查:
|
||
搜集所有:记忆遗忘检查为2,且已经是48小时之外的记忆
|
||
取count最高5%和最低5%,删除,然后标记被遗忘检查次数为3
|
||
"""
|
||
try:
|
||
# 48小时 = 172800秒
|
||
time_threshold = current_time - 172800
|
||
|
||
# 查询符合条件的记忆:forget_times=2 且 end_time < time_threshold
|
||
candidates = list(
|
||
ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
|
||
)
|
||
|
||
if not candidates:
|
||
logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
|
||
return
|
||
|
||
logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")
|
||
|
||
# 按count排序
|
||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||
|
||
# 计算要删除的数量(最高5%和最低5%)
|
||
total_count = len(candidates)
|
||
delete_count = int(total_count * 0.05) # 5%
|
||
|
||
if delete_count == 0:
|
||
logger.debug("[记忆遗忘-阶段3] 删除数量为0,跳过")
|
||
return
|
||
|
||
# 选择要删除的记录
|
||
to_delete = []
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||
|
||
# 去重
|
||
to_delete = list(set(to_delete))
|
||
|
||
# 删除记录
|
||
deleted_count = 0
|
||
for record in to_delete:
|
||
try:
|
||
record.delete_instance()
|
||
deleted_count += 1
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")
|
||
|
||
# 更新剩余记录的forget_times为3
|
||
to_delete_ids = {r.id for r in to_delete}
|
||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||
if remaining:
|
||
ids_to_update = [r.id for r in remaining]
|
||
ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||
|
||
logger.info(
|
||
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
|
||
|
||
async def _forget_stage_4(self, current_time: float):
|
||
"""
|
||
第四次遗忘检查:
|
||
搜集所有:记忆遗忘检查为3,且已经是7天之外的记忆
|
||
取count最高2%和最低2%,删除,然后标记被遗忘检查次数为4
|
||
"""
|
||
try:
|
||
# 7天 = 604800秒
|
||
time_threshold = current_time - 604800
|
||
|
||
# 查询符合条件的记忆:forget_times=3 且 end_time < time_threshold
|
||
candidates = list(
|
||
ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
|
||
)
|
||
|
||
if not candidates:
|
||
logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
|
||
return
|
||
|
||
logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")
|
||
|
||
# 按count排序
|
||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||
|
||
# 计算要删除的数量(最高2%和最低2%)
|
||
total_count = len(candidates)
|
||
delete_count = int(total_count * 0.02) # 2%
|
||
|
||
if delete_count == 0:
|
||
logger.debug("[记忆遗忘-阶段4] 删除数量为0,跳过")
|
||
return
|
||
|
||
# 选择要删除的记录
|
||
to_delete = []
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||
|
||
# 去重
|
||
to_delete = list(set(to_delete))
|
||
|
||
# 删除记录
|
||
deleted_count = 0
|
||
for record in to_delete:
|
||
try:
|
||
record.delete_instance()
|
||
deleted_count += 1
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")
|
||
|
||
# 更新剩余记录的forget_times为4
|
||
to_delete_ids = {r.id for r in to_delete}
|
||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||
if remaining:
|
||
ids_to_update = [r.id for r in remaining]
|
||
ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||
|
||
logger.info(
|
||
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
|
||
|
||
def _handle_same_count_random(
|
||
self, candidates: List[ChatHistory], delete_count: int, mode: str
|
||
) -> List[ChatHistory]:
|
||
"""
|
||
处理count相同的情况,随机选择要删除的记录
|
||
|
||
Args:
|
||
candidates: 候选记录列表(已按count排序)
|
||
delete_count: 要删除的数量
|
||
mode: "high" 表示选择最高count的记录,"low" 表示选择最低count的记录
|
||
|
||
Returns:
|
||
要删除的记录列表
|
||
"""
|
||
if not candidates or delete_count == 0:
|
||
return []
|
||
|
||
to_delete = []
|
||
|
||
if mode == "high":
|
||
# 从最高count开始选择
|
||
start_idx = 0
|
||
while start_idx < len(candidates) and len(to_delete) < delete_count:
|
||
# 找到所有count相同的记录
|
||
current_count = candidates[start_idx].count
|
||
same_count_records = []
|
||
idx = start_idx
|
||
while idx < len(candidates) and candidates[idx].count == current_count:
|
||
same_count_records.append(candidates[idx])
|
||
idx += 1
|
||
|
||
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
|
||
needed = delete_count - len(to_delete)
|
||
if len(same_count_records) <= needed:
|
||
to_delete.extend(same_count_records)
|
||
else:
|
||
# 随机选择需要的数量
|
||
to_delete.extend(random.sample(same_count_records, needed))
|
||
|
||
start_idx = idx
|
||
|
||
else: # mode == "low"
|
||
# 从最低count开始选择
|
||
start_idx = len(candidates) - 1
|
||
while start_idx >= 0 and len(to_delete) < delete_count:
|
||
# 找到所有count相同的记录
|
||
current_count = candidates[start_idx].count
|
||
same_count_records = []
|
||
idx = start_idx
|
||
while idx >= 0 and candidates[idx].count == current_count:
|
||
same_count_records.append(candidates[idx])
|
||
idx -= 1
|
||
|
||
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
|
||
needed = delete_count - len(to_delete)
|
||
if len(same_count_records) <= needed:
|
||
to_delete.extend(same_count_records)
|
||
else:
|
||
# 随机选择需要的数量
|
||
to_delete.extend(random.sample(same_count_records, needed))
|
||
|
||
start_idx = idx
|
||
|
||
return to_delete
|