尝试增加对群聊的记录隔离

pull/248/head
Cindy-Master 2025-03-12 09:28:06 +08:00 committed by GitHub
parent 1ac5c225af
commit 5b13754504
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 313 additions and 220 deletions

View File

@ -91,12 +91,20 @@ class PromptBuilder:
memory_prompt = '' memory_prompt = ''
start_time = time.time() start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法 # 获取群聊ID
stream_group_id = None
if stream_id:
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream and chat_stream.group_info:
stream_group_id = str(chat_stream.group_info.group_id)
# 调用 hippocampus 的 get_relevant_memories 方法添加群聊ID参数
relevant_memories = await hippocampus.get_relevant_memories( relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt, text=message_txt,
max_topics=5, max_topics=5,
similarity_threshold=0.4, similarity_threshold=0.4,
max_memory_num=5 max_memory_num=5,
group_id=stream_group_id
) )
if relevant_memories: if relevant_memories:

View File

@ -13,15 +13,21 @@ class MessageStorage:
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None: async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
try: try:
# 提取群组ID信息如果存在的话
group_id = None
if chat_stream.group_info:
group_id = str(chat_stream.group_info.group_id)
message_data = { message_data = {
"message_id": message.message_info.message_id, "message_id": message.message_info.message_id,
"time": message.message_info.time, "time": message.message_info.time,
"chat_id":chat_stream.stream_id, "chat_id": chat_stream.stream_id,
"chat_info": chat_stream.to_dict(), "chat_info": chat_stream.to_dict(),
"user_info": message.message_info.user_info.to_dict(), "user_info": message.message_info.user_info.to_dict(),
"processed_plain_text": message.processed_plain_text, "processed_plain_text": message.processed_plain_text,
"detailed_plain_text": message.detailed_plain_text, "detailed_plain_text": message.detailed_plain_text,
"topic": topic, "topic": topic,
"group_id": group_id, # 显式添加group_id字段
} }
self.db.db.messages.insert_one(message_data) self.db.db.messages.insert_one(message_data)
except Exception: except Exception:

View File

@ -104,11 +104,20 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
# 转换记录格式 # 转换记录格式
formatted_records = [] formatted_records = []
for record in chat_records: for record in chat_records:
formatted_records.append({ formatted_record = {
'time': record["time"], 'time': record["time"],
'chat_id': record["chat_id"], 'chat_id': record["chat_id"],
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容 'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
}) }
# 添加group_id信息如果存在
if 'group_id' in record:
formatted_record['group_id'] = record['group_id']
elif 'chat_info' in record and 'group_info' in record['chat_info'] and record['chat_info']['group_info']:
# 从chat_info中提取group_id
formatted_record['group_id'] = record['chat_info']['group_info'].get('group_id')
formatted_records.append(formatted_record)
return formatted_records return formatted_records

View File

@ -44,9 +44,19 @@ class Memory_graph:
created_time=current_time, # 添加创建时间 created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间 last_modified=current_time) # 添加最后修改时间
def add_dot(self, concept, memory): def add_dot(self, concept, memory, group_id=None):
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
# 如果memory不是字典格式将其转换为字典
if not isinstance(memory, dict):
memory = {
'content': memory,
'group_id': group_id
}
# 如果memory是字典但没有group_id添加group_id
elif 'group_id' not in memory and group_id is not None:
memory['group_id'] = group_id
if concept in self.G: if concept in self.G:
if 'memory_items' in self.G.nodes[concept]: if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list): if not isinstance(self.G.nodes[concept]['memory_items'], list):
@ -218,6 +228,13 @@ class Hippocampus:
if not messages: if not messages:
return set(), {} return set(), {}
# 提取群聊ID信息
group_id = None
for msg in messages:
if 'group_id' in msg and msg['group_id']:
group_id = msg['group_id']
break
# 合并消息文本,同时保留时间信息 # 合并消息文本,同时保留时间信息
input_text = "" input_text = ""
time_info = "" time_info = ""
@ -267,7 +284,13 @@ class Hippocampus:
for topic, task in tasks: for topic, task in tasks:
response = await task response = await task
if response: if response:
compressed_memory.add((topic, response[0])) # 使用字典结构存储记忆内容与群组ID
memory_content = {
'content': response[0],
'group_id': group_id
}
compressed_memory.add((topic, memory_content))
# 为每个话题查找相似的已存在主题 # 为每个话题查找相似的已存在主题
existing_topics = list(self.memory_graph.G.nodes()) existing_topics = list(self.memory_graph.G.nodes())
similar_topics = [] similar_topics = []
@ -315,6 +338,11 @@ class Hippocampus:
bar = '' * filled_length + '-' * (bar_length - filled_length) bar = '' * filled_length + '-' * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
# 获取该批次消息的group_id
group_id = None
if messages and len(messages) > 0 and 'group_id' in messages[0]:
group_id = messages[0]['group_id']
compress_rate = global_config.memory_compress_rate compress_rate = global_config.memory_compress_rate
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}") logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
@ -323,7 +351,7 @@ class Hippocampus:
for topic, memory in compressed_memory: for topic, memory in compressed_memory:
logger.info(f"添加节点: {topic}") logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory) self.memory_graph.add_dot(topic, memory, group_id)
all_topics.append(topic) all_topics.append(topic)
# 连接相似的已存在主题 # 连接相似的已存在主题
@ -841,7 +869,7 @@ class Hippocampus:
return activation return activation
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
max_memory_num: int = 5) -> list: max_memory_num: int = 5, group_id: str = None) -> list:
"""根据输入文本获取相关的记忆内容""" """根据输入文本获取相关的记忆内容"""
# 识别主题 # 识别主题
identified_topics = await self._identify_topics(text) identified_topics = await self._identify_topics(text)
@ -865,15 +893,29 @@ class Hippocampus:
# 如果记忆条数超过限制,随机选择指定数量的记忆 # 如果记忆条数超过限制,随机选择指定数量的记忆
if len(first_layer) > max_memory_num / 2: if len(first_layer) > max_memory_num / 2:
first_layer = random.sample(first_layer, max_memory_num // 2) first_layer = random.sample(first_layer, max_memory_num // 2)
# 为每条记忆添加来源主题和相似度信息 # 为每条记忆添加来源主题和相似度信息
for memory in first_layer: for memory in first_layer:
relevant_memories.append({ # 添加群聊ID筛选
'topic': topic, # 如果memory是字典格式且有群组信息则进行过滤
'similarity': score, if isinstance(memory, dict) and 'group_id' in memory:
'content': memory # 当前没有指定群聊ID或者记忆来自相同群聊时才添加
}) if group_id is None or memory['group_id'] == group_id:
relevant_memories.append({
'topic': topic,
'similarity': score,
'content': memory['content'] if 'content' in memory else memory,
'group_id': memory.get('group_id')
})
else:
# 对于没有群组信息的旧记忆,保持向后兼容
relevant_memories.append({
'topic': topic,
'similarity': score,
'content': memory,
'group_id': None # 旧数据没有群组信息
})
# 如果记忆数量超过5个,随机选择5个
# 按相似度排序 # 按相似度排序
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
@ -882,6 +924,34 @@ class Hippocampus:
return relevant_memories return relevant_memories
def get_group_memories(self, group_id: str) -> list:
"""获取特定群聊的所有记忆
Args:
group_id: 群聊ID
Returns:
list: 该群聊的记忆列表每个记忆包含主题和内容
"""
all_memories = []
all_nodes = list(self.memory_graph.G.nodes(data=True))
for concept, data in all_nodes:
memory_items = data.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 筛选出属于指定群聊的记忆
for memory in memory_items:
if isinstance(memory, dict) and 'group_id' in memory and memory['group_id'] == group_id:
# 添加到结果列表
all_memories.append({
'topic': concept,
'content': memory.get('content', str(memory))
})
return all_memories
def segment_text(text): def segment_text(text):
seg_text = list(jieba.cut(text)) seg_text = list(jieba.cut(text))