Update: 去重机制改进,新增memorized属性

pull/354/head
MuWinds 2025-03-14 09:06:27 +08:00
parent c84fe4f8ed
commit e9538096e0
3 changed files with 35 additions and 53 deletions

View File

@ -78,41 +78,47 @@ def calculate_information_content(text):
def get_closest_chat_from_db(length: int, timestamp: str) -> list:
    """Fetch the chat messages that follow the record closest to ``timestamp``.

    The newest message whose ``time`` is less than or equal to ``timestamp``
    is used as an anchor. Up to ``length`` later messages with the same
    ``chat_id`` are then read in ascending ``time`` order. Every message that
    is returned has its ``memorized`` read counter incremented in the
    database; messages whose counter is already greater than 1 are skipped.

    Args:
        length: Maximum number of messages to fetch after the anchor record.
        timestamp: Value compared against the ``time`` field of the stored
            messages. NOTE(review): annotated as ``str``, but callers appear
            to pass a numeric epoch value -- confirm the stored type.

    Returns:
        list: One dict per accepted message with the keys ``time``,
        ``chat_id`` and ``detailed_plain_text``. Empty when no anchor record
        exists or when every candidate message was skipped.
    """
    # Backfill the read counter on documents that do not have one yet.
    db.messages.update_many(
        {"memorized": {"$exists": False}},
        {"$set": {"memorized": 0}},
    )

    closest_record = db.messages.find_one(
        {"time": {"$lte": timestamp}},
        sort=[("time", -1)],
    )
    if not closest_record:
        return []

    # Messages strictly after the anchor, restricted to the same chat.
    records = list(
        db.messages.find(
            {
                "time": {"$gt": closest_record["time"]},
                "chat_id": closest_record["chat_id"],
            }
        ).sort("time", 1).limit(length)
    )

    chat_records = []
    for record in records:
        # .get() keeps a document without the counter from raising KeyError.
        if record.get("memorized", 0) > 1:
            # NOTE(review): the message below says "read 1 time", but the
            # condition only skips once the counter exceeds 1 -- confirm
            # which of the two is intended.
            print("消息已读取1次跳过")
            continue

        # $inc performs the increment on the server, so two concurrent
        # readers cannot overwrite each other's update the way a
        # read-then-$set sequence can.
        db.messages.update_one(
            {"_id": record["_id"]},
            {"$inc": {"memorized": 1}},
        )

        chat_records.append({
            "time": record["time"],
            "chat_id": record["chat_id"],
            # Messages without text yield "" instead of raising KeyError.
            "detailed_plain_text": record.get("detailed_plain_text", ""),
        })

    return chat_records
async def get_recent_group_messages(chat_id:str, limit: int = 12) -> list:

View File

@ -192,52 +192,24 @@ class Hippocampus:
"""
current_timestamp = datetime.datetime.now().timestamp()
chat_samples = []
added_ids = set() # 全局已添加消息ID集合用于去重
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
# 过滤已存在的消息
filtered = []
for msg in messages:
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
if msg_id not in added_ids:
filtered.append(msg)
added_ids.add(msg_id)
if filtered:
chat_samples.append(filtered)
chat_samples.append(messages)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
# 过滤已存在的消息
filtered = []
for msg in messages:
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
if msg_id not in added_ids:
filtered.append(msg)
added_ids.add(msg_id)
if filtered:
chat_samples.append(filtered)
chat_samples.append(messages)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
# 过滤已存在的消息
filtered = []
for msg in messages:
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
if msg_id not in added_ids:
filtered.append(msg)
added_ids.add(msg_id)
if filtered:
chat_samples.append(filtered)
chat_samples.append(messages)
return chat_samples
async def memory_compress(self, messages: list, compress_rate=0.1):

View File

@ -56,6 +56,10 @@ def get_closest_chat_from_db(length: int, timestamp: str):
list: 消息记录字典列表每个字典包含消息内容和时间信息
"""
chat_records = []
db.messages.update_many(
{"memorized": {"$exists": False}},
{"$set": {"memorized": 0}}
)
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4: