mirror of https://github.com/Mai-with-u/MaiBot.git
尝试增加对群聊的记录隔离
parent
1ac5c225af
commit
5b13754504
|
|
@ -91,12 +91,20 @@ class PromptBuilder:
|
||||||
memory_prompt = ''
|
memory_prompt = ''
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 调用 hippocampus 的 get_relevant_memories 方法
|
# 获取群聊ID
|
||||||
|
stream_group_id = None
|
||||||
|
if stream_id:
|
||||||
|
chat_stream = chat_manager.get_stream(stream_id)
|
||||||
|
if chat_stream and chat_stream.group_info:
|
||||||
|
stream_group_id = str(chat_stream.group_info.group_id)
|
||||||
|
|
||||||
|
# 调用 hippocampus 的 get_relevant_memories 方法,添加群聊ID参数
|
||||||
relevant_memories = await hippocampus.get_relevant_memories(
|
relevant_memories = await hippocampus.get_relevant_memories(
|
||||||
text=message_txt,
|
text=message_txt,
|
||||||
max_topics=5,
|
max_topics=5,
|
||||||
similarity_threshold=0.4,
|
similarity_threshold=0.4,
|
||||||
max_memory_num=5
|
max_memory_num=5,
|
||||||
|
group_id=stream_group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if relevant_memories:
|
if relevant_memories:
|
||||||
|
|
|
||||||
|
|
@ -13,15 +13,21 @@ class MessageStorage:
|
||||||
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
|
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
try:
|
try:
|
||||||
|
# 提取群组ID信息,如果存在的话
|
||||||
|
group_id = None
|
||||||
|
if chat_stream.group_info:
|
||||||
|
group_id = str(chat_stream.group_info.group_id)
|
||||||
|
|
||||||
message_data = {
|
message_data = {
|
||||||
"message_id": message.message_info.message_id,
|
"message_id": message.message_info.message_id,
|
||||||
"time": message.message_info.time,
|
"time": message.message_info.time,
|
||||||
"chat_id":chat_stream.stream_id,
|
"chat_id": chat_stream.stream_id,
|
||||||
"chat_info": chat_stream.to_dict(),
|
"chat_info": chat_stream.to_dict(),
|
||||||
"user_info": message.message_info.user_info.to_dict(),
|
"user_info": message.message_info.user_info.to_dict(),
|
||||||
"processed_plain_text": message.processed_plain_text,
|
"processed_plain_text": message.processed_plain_text,
|
||||||
"detailed_plain_text": message.detailed_plain_text,
|
"detailed_plain_text": message.detailed_plain_text,
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
|
"group_id": group_id, # 显式添加group_id字段
|
||||||
}
|
}
|
||||||
self.db.db.messages.insert_one(message_data)
|
self.db.db.messages.insert_one(message_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -104,11 +104,20 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||||
# 转换记录格式
|
# 转换记录格式
|
||||||
formatted_records = []
|
formatted_records = []
|
||||||
for record in chat_records:
|
for record in chat_records:
|
||||||
formatted_records.append({
|
formatted_record = {
|
||||||
'time': record["time"],
|
'time': record["time"],
|
||||||
'chat_id': record["chat_id"],
|
'chat_id': record["chat_id"],
|
||||||
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
|
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
|
||||||
})
|
}
|
||||||
|
|
||||||
|
# 添加group_id信息,如果存在
|
||||||
|
if 'group_id' in record:
|
||||||
|
formatted_record['group_id'] = record['group_id']
|
||||||
|
elif 'chat_info' in record and 'group_info' in record['chat_info'] and record['chat_info']['group_info']:
|
||||||
|
# 从chat_info中提取group_id
|
||||||
|
formatted_record['group_id'] = record['chat_info']['group_info'].get('group_id')
|
||||||
|
|
||||||
|
formatted_records.append(formatted_record)
|
||||||
|
|
||||||
return formatted_records
|
return formatted_records
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,19 @@ class Memory_graph:
|
||||||
created_time=current_time, # 添加创建时间
|
created_time=current_time, # 添加创建时间
|
||||||
last_modified=current_time) # 添加最后修改时间
|
last_modified=current_time) # 添加最后修改时间
|
||||||
|
|
||||||
def add_dot(self, concept, memory):
|
def add_dot(self, concept, memory, group_id=None):
|
||||||
current_time = datetime.datetime.now().timestamp()
|
current_time = datetime.datetime.now().timestamp()
|
||||||
|
|
||||||
|
# 如果memory不是字典格式,将其转换为字典
|
||||||
|
if not isinstance(memory, dict):
|
||||||
|
memory = {
|
||||||
|
'content': memory,
|
||||||
|
'group_id': group_id
|
||||||
|
}
|
||||||
|
# 如果memory是字典但没有group_id,添加group_id
|
||||||
|
elif 'group_id' not in memory and group_id is not None:
|
||||||
|
memory['group_id'] = group_id
|
||||||
|
|
||||||
if concept in self.G:
|
if concept in self.G:
|
||||||
if 'memory_items' in self.G.nodes[concept]:
|
if 'memory_items' in self.G.nodes[concept]:
|
||||||
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
||||||
|
|
@ -218,6 +228,13 @@ class Hippocampus:
|
||||||
if not messages:
|
if not messages:
|
||||||
return set(), {}
|
return set(), {}
|
||||||
|
|
||||||
|
# 提取群聊ID信息
|
||||||
|
group_id = None
|
||||||
|
for msg in messages:
|
||||||
|
if 'group_id' in msg and msg['group_id']:
|
||||||
|
group_id = msg['group_id']
|
||||||
|
break
|
||||||
|
|
||||||
# 合并消息文本,同时保留时间信息
|
# 合并消息文本,同时保留时间信息
|
||||||
input_text = ""
|
input_text = ""
|
||||||
time_info = ""
|
time_info = ""
|
||||||
|
|
@ -267,7 +284,13 @@ class Hippocampus:
|
||||||
for topic, task in tasks:
|
for topic, task in tasks:
|
||||||
response = await task
|
response = await task
|
||||||
if response:
|
if response:
|
||||||
compressed_memory.add((topic, response[0]))
|
# 使用字典结构存储记忆内容与群组ID
|
||||||
|
memory_content = {
|
||||||
|
'content': response[0],
|
||||||
|
'group_id': group_id
|
||||||
|
}
|
||||||
|
compressed_memory.add((topic, memory_content))
|
||||||
|
|
||||||
# 为每个话题查找相似的已存在主题
|
# 为每个话题查找相似的已存在主题
|
||||||
existing_topics = list(self.memory_graph.G.nodes())
|
existing_topics = list(self.memory_graph.G.nodes())
|
||||||
similar_topics = []
|
similar_topics = []
|
||||||
|
|
@ -315,6 +338,11 @@ class Hippocampus:
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
|
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
|
||||||
|
|
||||||
|
# 获取该批次消息的group_id
|
||||||
|
group_id = None
|
||||||
|
if messages and len(messages) > 0 and 'group_id' in messages[0]:
|
||||||
|
group_id = messages[0]['group_id']
|
||||||
|
|
||||||
compress_rate = global_config.memory_compress_rate
|
compress_rate = global_config.memory_compress_rate
|
||||||
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
|
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
|
||||||
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
|
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
|
||||||
|
|
@ -323,7 +351,7 @@ class Hippocampus:
|
||||||
|
|
||||||
for topic, memory in compressed_memory:
|
for topic, memory in compressed_memory:
|
||||||
logger.info(f"添加节点: {topic}")
|
logger.info(f"添加节点: {topic}")
|
||||||
self.memory_graph.add_dot(topic, memory)
|
self.memory_graph.add_dot(topic, memory, group_id)
|
||||||
all_topics.append(topic)
|
all_topics.append(topic)
|
||||||
|
|
||||||
# 连接相似的已存在主题
|
# 连接相似的已存在主题
|
||||||
|
|
@ -841,7 +869,7 @@ class Hippocampus:
|
||||||
return activation
|
return activation
|
||||||
|
|
||||||
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
|
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
|
||||||
max_memory_num: int = 5) -> list:
|
max_memory_num: int = 5, group_id: str = None) -> list:
|
||||||
"""根据输入文本获取相关的记忆内容"""
|
"""根据输入文本获取相关的记忆内容"""
|
||||||
# 识别主题
|
# 识别主题
|
||||||
identified_topics = await self._identify_topics(text)
|
identified_topics = await self._identify_topics(text)
|
||||||
|
|
@ -865,15 +893,29 @@ class Hippocampus:
|
||||||
# 如果记忆条数超过限制,随机选择指定数量的记忆
|
# 如果记忆条数超过限制,随机选择指定数量的记忆
|
||||||
if len(first_layer) > max_memory_num / 2:
|
if len(first_layer) > max_memory_num / 2:
|
||||||
first_layer = random.sample(first_layer, max_memory_num // 2)
|
first_layer = random.sample(first_layer, max_memory_num // 2)
|
||||||
|
|
||||||
# 为每条记忆添加来源主题和相似度信息
|
# 为每条记忆添加来源主题和相似度信息
|
||||||
for memory in first_layer:
|
for memory in first_layer:
|
||||||
relevant_memories.append({
|
# 添加群聊ID筛选
|
||||||
'topic': topic,
|
# 如果memory是字典格式且有群组信息,则进行过滤
|
||||||
'similarity': score,
|
if isinstance(memory, dict) and 'group_id' in memory:
|
||||||
'content': memory
|
# 当前没有指定群聊ID或者记忆来自相同群聊时才添加
|
||||||
})
|
if group_id is None or memory['group_id'] == group_id:
|
||||||
|
relevant_memories.append({
|
||||||
|
'topic': topic,
|
||||||
|
'similarity': score,
|
||||||
|
'content': memory['content'] if 'content' in memory else memory,
|
||||||
|
'group_id': memory.get('group_id')
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# 对于没有群组信息的旧记忆,保持向后兼容
|
||||||
|
relevant_memories.append({
|
||||||
|
'topic': topic,
|
||||||
|
'similarity': score,
|
||||||
|
'content': memory,
|
||||||
|
'group_id': None # 旧数据没有群组信息
|
||||||
|
})
|
||||||
|
|
||||||
# 如果记忆数量超过5个,随机选择5个
|
|
||||||
# 按相似度排序
|
# 按相似度排序
|
||||||
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||||
|
|
||||||
|
|
@ -882,6 +924,34 @@ class Hippocampus:
|
||||||
|
|
||||||
return relevant_memories
|
return relevant_memories
|
||||||
|
|
||||||
|
def get_group_memories(self, group_id: str) -> list:
|
||||||
|
"""获取特定群聊的所有记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: 群聊ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 该群聊的记忆列表,每个记忆包含主题和内容
|
||||||
|
"""
|
||||||
|
all_memories = []
|
||||||
|
all_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||||
|
|
||||||
|
for concept, data in all_nodes:
|
||||||
|
memory_items = data.get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
|
# 筛选出属于指定群聊的记忆
|
||||||
|
for memory in memory_items:
|
||||||
|
if isinstance(memory, dict) and 'group_id' in memory and memory['group_id'] == group_id:
|
||||||
|
# 添加到结果列表
|
||||||
|
all_memories.append({
|
||||||
|
'topic': concept,
|
||||||
|
'content': memory.get('content', str(memory))
|
||||||
|
})
|
||||||
|
|
||||||
|
return all_memories
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
seg_text = list(jieba.cut(text))
|
seg_text = list(jieba.cut(text))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue