Update:获取记忆样本增加去重机制

pull/354/head
MuWinds 2025-03-13 23:55:09 +08:00
parent 32c6d7ea55
commit c84fe4f8ed
2 changed files with 31 additions and 5 deletions

View File

@ -105,11 +105,11 @@ def get_closest_chat_from_db(length: int, timestamp: str):
formatted_records = []
for record in chat_records:
formatted_records.append({
'message_id': record['_id'],
'time': record["time"],
'chat_id': record["chat_id"],
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
})
return formatted_records
return []

View File

@ -192,26 +192,52 @@ class Hippocampus:
"""
current_timestamp = datetime.datetime.now().timestamp()
chat_samples = []
added_ids = set() # 全局已添加消息ID集合用于去重
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
# 过滤已存在的消息
filtered = []
for msg in messages:
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
if msg_id not in added_ids:
filtered.append(msg)
added_ids.add(msg_id)
if filtered:
chat_samples.append(filtered)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
# 过滤已存在的消息
filtered = []
for msg in messages:
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
if msg_id not in added_ids:
filtered.append(msg)
added_ids.add(msg_id)
if filtered:
chat_samples.append(filtered)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
# 过滤已存在的消息
filtered = []
for msg in messages:
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
if msg_id not in added_ids:
filtered.append(msg)
added_ids.add(msg_id)
if filtered:
chat_samples.append(filtered)
return chat_samples
async def memory_compress(self, messages: list, compress_rate=0.1):