mirror of https://github.com/Mai-with-u/MaiBot.git
Update:获取记忆样本增加去重机制
parent
32c6d7ea55
commit
c84fe4f8ed
|
|
@ -105,11 +105,11 @@ def get_closest_chat_from_db(length: int, timestamp: str):
|
||||||
formatted_records = []
|
formatted_records = []
|
||||||
for record in chat_records:
|
for record in chat_records:
|
||||||
formatted_records.append({
|
formatted_records.append({
|
||||||
|
'message_id': record['_id'],
|
||||||
'time': record["time"],
|
'time': record["time"],
|
||||||
'chat_id': record["chat_id"],
|
'chat_id': record["chat_id"],
|
||||||
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
|
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
|
||||||
})
|
})
|
||||||
|
|
||||||
return formatted_records
|
return formatted_records
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -192,26 +192,52 @@ class Hippocampus:
|
||||||
"""
|
"""
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
chat_samples = []
|
chat_samples = []
|
||||||
|
added_ids = set() # 全局已添加消息ID集合,用于去重
|
||||||
# 短期:1h 中期:4h 长期:24h
|
# 短期:1h 中期:4h 长期:24h
|
||||||
for _ in range(time_frequency.get('near')):
|
for _ in range(time_frequency.get('near')):
|
||||||
random_time = current_timestamp - random.randint(1, 3600)
|
random_time = current_timestamp - random.randint(1, 3600)
|
||||||
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||||
if messages:
|
if messages:
|
||||||
chat_samples.append(messages)
|
# 过滤已存在的消息
|
||||||
|
filtered = []
|
||||||
|
for msg in messages:
|
||||||
|
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
|
||||||
|
if msg_id not in added_ids:
|
||||||
|
filtered.append(msg)
|
||||||
|
added_ids.add(msg_id)
|
||||||
|
|
||||||
|
if filtered:
|
||||||
|
chat_samples.append(filtered)
|
||||||
|
|
||||||
for _ in range(time_frequency.get('mid')):
|
for _ in range(time_frequency.get('mid')):
|
||||||
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
||||||
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||||
if messages:
|
if messages:
|
||||||
chat_samples.append(messages)
|
# 过滤已存在的消息
|
||||||
|
filtered = []
|
||||||
|
for msg in messages:
|
||||||
|
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
|
||||||
|
if msg_id not in added_ids:
|
||||||
|
filtered.append(msg)
|
||||||
|
added_ids.add(msg_id)
|
||||||
|
|
||||||
|
if filtered:
|
||||||
|
chat_samples.append(filtered)
|
||||||
|
|
||||||
for _ in range(time_frequency.get('far')):
|
for _ in range(time_frequency.get('far')):
|
||||||
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
||||||
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||||
if messages:
|
if messages:
|
||||||
chat_samples.append(messages)
|
# 过滤已存在的消息
|
||||||
|
filtered = []
|
||||||
|
for msg in messages:
|
||||||
|
msg_id = msg.get("message_id") # 假设消息中存在唯一标识字段
|
||||||
|
if msg_id not in added_ids:
|
||||||
|
filtered.append(msg)
|
||||||
|
added_ids.add(msg_id)
|
||||||
|
|
||||||
|
if filtered:
|
||||||
|
chat_samples.append(filtered)
|
||||||
return chat_samples
|
return chat_samples
|
||||||
|
|
||||||
async def memory_compress(self, messages: list, compress_rate=0.1):
|
async def memory_compress(self, messages: list, compress_rate=0.1):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue