mirror of https://github.com/Mai-with-u/MaiBot.git
附属函数参数修改
parent
e1a21c5a45
commit
e8922672aa
|
|
@ -9,7 +9,6 @@ import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import List, Tuple, Set, Coroutine, Any, Dict
|
from typing import List, Tuple, Set, Coroutine, Any, Dict
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from itertools import combinations
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
@ -23,6 +22,8 @@ from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||||
) # 导入 build_readable_messages
|
) # 导入 build_readable_messages
|
||||||
|
|
||||||
|
|
||||||
# 添加cosine_similarity函数
|
# 添加cosine_similarity函数
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2):
|
||||||
"""计算余弦相似度"""
|
"""计算余弦相似度"""
|
||||||
|
|
@ -51,18 +52,9 @@ def calculate_information_content(text):
|
||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("memory")
|
logger = get_logger("memory")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryGraph:
|
class MemoryGraph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
|
|
@ -203,7 +195,9 @@ class MemoryGraph:
|
||||||
整合后的记忆:"""
|
整合后的记忆:"""
|
||||||
|
|
||||||
# 调用LLM进行整合
|
# 调用LLM进行整合
|
||||||
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(integration_prompt)
|
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(
|
||||||
|
integration_prompt
|
||||||
|
)
|
||||||
|
|
||||||
if content and content.strip():
|
if content and content.strip():
|
||||||
integrated_content = content.strip()
|
integrated_content = content.strip()
|
||||||
|
|
@ -238,7 +232,11 @@ class MemoryGraph:
|
||||||
if memory_items:
|
if memory_items:
|
||||||
# 删除整个节点
|
# 删除整个节点
|
||||||
self.G.remove_node(topic)
|
self.G.remove_node(topic)
|
||||||
return f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." if len(memory_items) > 50 else f"删除了节点 {topic} 的完整记忆: {memory_items}"
|
return (
|
||||||
|
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
|
||||||
|
if len(memory_items) > 50
|
||||||
|
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 如果没有记忆项,删除该节点
|
# 如果没有记忆项,删除该节点
|
||||||
self.G.remove_node(topic)
|
self.G.remove_node(topic)
|
||||||
|
|
@ -263,7 +261,9 @@ class Hippocampus:
|
||||||
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
|
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
|
||||||
# 从数据库加载记忆图
|
# 从数据库加载记忆图
|
||||||
self.entorhinal_cortex.sync_memory_from_db()
|
self.entorhinal_cortex.sync_memory_from_db()
|
||||||
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify")
|
self.model_small = LLMRequest(
|
||||||
|
model_set=model_config.model_task_config.utils_small, request_type="memory.modify"
|
||||||
|
)
|
||||||
|
|
||||||
def get_all_node_names(self) -> list:
|
def get_all_node_names(self) -> list:
|
||||||
"""获取记忆图中所有节点的名字列表"""
|
"""获取记忆图中所有节点的名字列表"""
|
||||||
|
|
@ -333,8 +333,6 @@ class Hippocampus:
|
||||||
f"如果确定找不出主题或者没有明显主题,返回<none>。"
|
f"如果确定找不出主题或者没有明显主题,返回<none>。"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -419,15 +417,12 @@ class Hippocampus:
|
||||||
text_length = len(text)
|
text_length = len(text)
|
||||||
topic_num: int | list[int] = 0
|
topic_num: int | list[int] = 0
|
||||||
|
|
||||||
|
|
||||||
words = jieba.cut(text)
|
words = jieba.cut(text)
|
||||||
keywords_lite = [word for word in words if len(word) > 1]
|
keywords_lite = [word for word in words if len(word) > 1]
|
||||||
keywords_lite = list(set(keywords_lite))
|
keywords_lite = list(set(keywords_lite))
|
||||||
if keywords_lite:
|
if keywords_lite:
|
||||||
logger.debug(f"提取关键词极简版: {keywords_lite}")
|
logger.debug(f"提取关键词极简版: {keywords_lite}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if text_length <= 12:
|
if text_length <= 12:
|
||||||
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
|
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
|
||||||
elif text_length <= 20:
|
elif text_length <= 20:
|
||||||
|
|
@ -455,7 +450,7 @@ class Hippocampus:
|
||||||
if keywords:
|
if keywords:
|
||||||
logger.debug(f"提取关键词: {keywords}")
|
logger.debug(f"提取关键词: {keywords}")
|
||||||
|
|
||||||
return keywords,keywords_lite
|
return keywords, keywords_lite
|
||||||
|
|
||||||
async def get_memory_from_topic(
|
async def get_memory_from_topic(
|
||||||
self,
|
self,
|
||||||
|
|
@ -570,15 +565,12 @@ class Hippocampus:
|
||||||
for node, activation in remember_map.items():
|
for node, activation in remember_map.items():
|
||||||
logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
|
logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
|
||||||
node_data = self.memory_graph.G.nodes[node]
|
node_data = self.memory_graph.G.nodes[node]
|
||||||
memory_items = node_data.get("memory_items", "")
|
if memory_items := node_data.get("memory_items", ""):
|
||||||
# 直接使用完整的记忆内容
|
|
||||||
if memory_items:
|
|
||||||
logger.debug("节点包含完整记忆")
|
logger.debug("节点包含完整记忆")
|
||||||
# 计算记忆与关键词的相似度
|
# 计算记忆与关键词的相似度
|
||||||
memory_words = set(jieba.cut(memory_items))
|
memory_words = set(jieba.cut(memory_items))
|
||||||
text_words = set(keywords)
|
text_words = set(keywords)
|
||||||
all_words = memory_words | text_words
|
if all_words := memory_words | text_words:
|
||||||
if all_words:
|
|
||||||
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
|
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
|
||||||
v1 = [1 if word in memory_words else 0 for word in all_words]
|
v1 = [1 if word in memory_words else 0 for word in all_words]
|
||||||
v2 = [1 if word in text_words else 0 for word in all_words]
|
v2 = [1 if word in text_words else 0 for word in all_words]
|
||||||
|
|
@ -613,7 +605,9 @@ class Hippocampus:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str],list[str]]:
|
async def get_activate_from_text(
|
||||||
|
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||||
|
) -> tuple[float, list[str], list[str]]:
|
||||||
"""从文本中提取关键词并获取相关记忆。
|
"""从文本中提取关键词并获取相关记忆。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -627,13 +621,13 @@ class Hippocampus:
|
||||||
float: 激活节点数与总节点数的比值
|
float: 激活节点数与总节点数的比值
|
||||||
list[str]: 有效的关键词
|
list[str]: 有效的关键词
|
||||||
"""
|
"""
|
||||||
keywords,keywords_lite = await self.get_keywords_from_text(text)
|
keywords, keywords_lite = await self.get_keywords_from_text(text)
|
||||||
|
|
||||||
# 过滤掉不存在于记忆图中的关键词
|
# 过滤掉不存在于记忆图中的关键词
|
||||||
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||||
if not valid_keywords:
|
if not valid_keywords:
|
||||||
# logger.info("没有找到有效的关键词节点")
|
# logger.info("没有找到有效的关键词节点")
|
||||||
return 0, keywords,keywords_lite
|
return 0, keywords, keywords_lite
|
||||||
|
|
||||||
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||||
|
|
||||||
|
|
@ -700,7 +694,7 @@ class Hippocampus:
|
||||||
activation_ratio = activation_ratio * 50
|
activation_ratio = activation_ratio * 50
|
||||||
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
|
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
|
||||||
|
|
||||||
return activation_ratio, keywords,keywords_lite
|
return activation_ratio, keywords, keywords_lite
|
||||||
|
|
||||||
|
|
||||||
# 负责海马体与其他部分的交互
|
# 负责海马体与其他部分的交互
|
||||||
|
|
@ -865,7 +859,9 @@ class EntorhinalCortex:
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}秒")
|
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||||
logger.info(f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边")
|
logger.info(
|
||||||
|
f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边"
|
||||||
|
)
|
||||||
|
|
||||||
async def resync_memory_to_db(self):
|
async def resync_memory_to_db(self):
|
||||||
"""清空数据库并重新同步所有记忆数据"""
|
"""清空数据库并重新同步所有记忆数据"""
|
||||||
|
|
@ -999,11 +995,15 @@ class EntorhinalCortex:
|
||||||
last_modified = node.last_modified or current_time
|
last_modified = node.last_modified or current_time
|
||||||
|
|
||||||
# 获取权重属性
|
# 获取权重属性
|
||||||
weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0
|
weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
|
||||||
|
|
||||||
# 添加节点到图中
|
# 添加节点到图中
|
||||||
self.memory_graph.G.add_node(
|
self.memory_graph.G.add_node(
|
||||||
concept, memory_items=memory_items, weight=weight, created_time=created_time, last_modified=last_modified
|
concept,
|
||||||
|
memory_items=memory_items,
|
||||||
|
weight=weight,
|
||||||
|
created_time=created_time,
|
||||||
|
last_modified=last_modified,
|
||||||
)
|
)
|
||||||
loaded_nodes += 1
|
loaded_nodes += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1046,7 +1046,9 @@ class EntorhinalCortex:
|
||||||
logger.info("[数据库] 已为缺失的时间字段进行补充")
|
logger.info("[数据库] 已为缺失的时间字段进行补充")
|
||||||
|
|
||||||
# 输出加载统计信息
|
# 输出加载统计信息
|
||||||
logger.info(f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes} 个")
|
logger.info(
|
||||||
|
f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes} 个"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 负责整合,遗忘,合并记忆
|
# 负责整合,遗忘,合并记忆
|
||||||
|
|
@ -1055,9 +1057,11 @@ class ParahippocampalGyrus:
|
||||||
self.hippocampus = hippocampus
|
self.hippocampus = hippocampus
|
||||||
self.memory_graph = hippocampus.memory_graph
|
self.memory_graph = hippocampus.memory_graph
|
||||||
|
|
||||||
self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify")
|
self.memory_modify_model = LLMRequest(
|
||||||
|
model_set=model_config.model_task_config.utils, request_type="memory.modify"
|
||||||
|
)
|
||||||
|
|
||||||
async def memory_compress(self, messages: list, compress_rate=0.1):
|
async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1):
|
||||||
"""压缩和总结消息内容,生成记忆主题和摘要。
|
"""压缩和总结消息内容,生成记忆主题和摘要。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -1314,8 +1318,6 @@ class ParahippocampalGyrus:
|
||||||
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒")
|
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HippocampusManager:
|
class HippocampusManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._hippocampus: Hippocampus = None # type: ignore
|
self._hippocampus: Hippocampus = None # type: ignore
|
||||||
|
|
@ -1374,7 +1376,10 @@ class HippocampusManager:
|
||||||
logger.info(f"为 {chat_id} 构建记忆,消息数量: {len(messages)}")
|
logger.info(f"为 {chat_id} 构建记忆,消息数量: {len(messages)}")
|
||||||
|
|
||||||
# 调用记忆压缩和构建
|
# 调用记忆压缩和构建
|
||||||
compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress(
|
(
|
||||||
|
compressed_memory,
|
||||||
|
similar_topics_dict,
|
||||||
|
) = await self._hippocampus.parahippocampal_gyrus.memory_compress(
|
||||||
messages, global_config.memory.memory_compress_rate
|
messages, global_config.memory.memory_compress_rate
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1390,10 +1395,11 @@ class HippocampusManager:
|
||||||
if topic != similar_topic:
|
if topic != similar_topic:
|
||||||
strength = int(similarity * 10)
|
strength = int(similarity * 10)
|
||||||
self._hippocampus.memory_graph.G.add_edge(
|
self._hippocampus.memory_graph.G.add_edge(
|
||||||
topic, similar_topic,
|
topic,
|
||||||
|
similar_topic,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
created_time=current_time,
|
created_time=current_time,
|
||||||
last_modified=current_time
|
last_modified=current_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 同步到数据库
|
# 同步到数据库
|
||||||
|
|
@ -1407,7 +1413,6 @@ class HippocampusManager:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def get_memory_from_topic(
|
async def get_memory_from_topic(
|
||||||
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||||
) -> list:
|
) -> list:
|
||||||
|
|
@ -1423,16 +1428,20 @@ class HippocampusManager:
|
||||||
response = []
|
response = []
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
|
async def get_activate_from_text(
|
||||||
|
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||||
|
) -> tuple[float, list[str]]:
|
||||||
"""从文本中获取激活值的公共接口"""
|
"""从文本中获取激活值的公共接口"""
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
try:
|
try:
|
||||||
response, keywords,keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
|
response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(
|
||||||
|
text, max_depth, fast_retrieval
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"文本产生激活值失败: {e}")
|
logger.error(f"文本产生激活值失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return 0.0, [],[]
|
return 0.0, [], []
|
||||||
|
|
||||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||||
"""从关键词获取相关记忆的公共接口"""
|
"""从关键词获取相关记忆的公共接口"""
|
||||||
|
|
@ -1469,7 +1478,7 @@ class MemoryBuilder:
|
||||||
|
|
||||||
# 检查时间间隔
|
# 检查时间间隔
|
||||||
time_diff = current_time - self.last_update_time
|
time_diff = current_time - self.last_update_time
|
||||||
if time_diff < 600 /global_config.memory.memory_build_frequency:
|
if time_diff < 600 / global_config.memory.memory_build_frequency:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查消息数量
|
# 检查消息数量
|
||||||
|
|
@ -1482,30 +1491,27 @@ class MemoryBuilder:
|
||||||
|
|
||||||
logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")
|
logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")
|
||||||
|
|
||||||
if not recent_messages or len(recent_messages) < 30/global_config.memory.memory_build_frequency :
|
if not recent_messages or len(recent_messages) < 30 / global_config.memory.memory_build_frequency:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]:
|
def get_messages_for_memory_build(self, threshold: int = 25) -> List[DatabaseMessages]:
|
||||||
"""获取用于记忆构建的消息"""
|
"""获取用于记忆构建的消息"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_update_time,
|
timestamp_start=self.last_update_time,
|
||||||
timestamp_end=current_time,
|
timestamp_end=current_time,
|
||||||
limit=threshold,
|
limit=threshold,
|
||||||
)
|
)
|
||||||
tmp_msg = [msg.__dict__ for msg in messages] if messages else []
|
|
||||||
if messages:
|
if messages:
|
||||||
# 更新最后处理时间
|
# 更新最后处理时间
|
||||||
self.last_processed_time = current_time
|
self.last_processed_time = current_time
|
||||||
self.last_update_time = current_time
|
self.last_update_time = current_time
|
||||||
|
|
||||||
return tmp_msg or []
|
return messages or []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MemorySegmentManager:
|
class MemorySegmentManager:
|
||||||
|
|
@ -1528,7 +1534,7 @@ class MemorySegmentManager:
|
||||||
builder = self.get_or_create_builder(chat_id)
|
builder = self.get_or_create_builder(chat_id)
|
||||||
return builder.should_trigger_memory_build()
|
return builder.should_trigger_memory_build()
|
||||||
|
|
||||||
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]:
|
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[DatabaseMessages]:
|
||||||
"""获取指定chat_id用于记忆构建的消息"""
|
"""获取指定chat_id用于记忆构建的消息"""
|
||||||
if chat_id not in self.builders:
|
if chat_id not in self.builders:
|
||||||
return []
|
return []
|
||||||
|
|
@ -1537,4 +1543,3 @@ class MemorySegmentManager:
|
||||||
|
|
||||||
# 创建全局实例
|
# 创建全局实例
|
||||||
memory_segment_manager = MemorySegmentManager()
|
memory_segment_manager = MemorySegmentManager()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
|
||||||
from src.chat.utils.utils import parse_keywords_string
|
from src.chat.utils.utils import parse_keywords_string
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
import random
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("memory_activator")
|
logger = get_logger("memory_activator")
|
||||||
|
|
@ -75,8 +75,9 @@ class MemoryActivator:
|
||||||
request_type="memory.selection",
|
request_type="memory.selection",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def activate_memory_with_chat_history(
|
||||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
|
self, target_message, chat_history: List[DatabaseMessages]
|
||||||
|
) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
激活记忆
|
激活记忆
|
||||||
"""
|
"""
|
||||||
|
|
@ -86,8 +87,8 @@ class MemoryActivator:
|
||||||
|
|
||||||
keywords_list = set()
|
keywords_list = set()
|
||||||
|
|
||||||
for msg in chat_history_prompt:
|
for msg in chat_history:
|
||||||
keywords = parse_keywords_string(msg.get("key_words", ""))
|
keywords = parse_keywords_string(msg.key_words)
|
||||||
if keywords:
|
if keywords:
|
||||||
if len(keywords_list) < 30:
|
if len(keywords_list) < 30:
|
||||||
# 最多容纳30个关键词
|
# 最多容纳30个关键词
|
||||||
|
|
@ -112,20 +113,13 @@ class MemoryActivator:
|
||||||
logger.debug("海马体没有返回相关记忆")
|
logger.debug("海马体没有返回相关记忆")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
used_ids = set()
|
used_ids = set()
|
||||||
candidate_memories = []
|
candidate_memories = []
|
||||||
|
|
||||||
# 为每个记忆分配随机ID并过滤相关记忆
|
# 为每个记忆分配随机ID并过滤相关记忆
|
||||||
for memory in related_memory:
|
for memory in related_memory:
|
||||||
keyword, content = memory
|
keyword, content = memory
|
||||||
found = False
|
found = any(kw in content for kw in keywords_list)
|
||||||
for kw in keywords_list:
|
|
||||||
if kw in content:
|
|
||||||
found = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if found:
|
if found:
|
||||||
# 随机分配一个不重复的2位数id
|
# 随机分配一个不重复的2位数id
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -145,12 +139,11 @@ class MemoryActivator:
|
||||||
# 转换为 (keyword, content) 格式
|
# 转换为 (keyword, content) 格式
|
||||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
|
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
|
||||||
|
|
||||||
# 使用 LLM 选择合适的记忆
|
return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)
|
||||||
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
|
|
||||||
|
|
||||||
return selected_memories
|
async def _select_memories_with_llm(
|
||||||
|
self, target_message, chat_history: List[DatabaseMessages], candidate_memories
|
||||||
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
|
) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
使用 LLM 选择合适的记忆
|
使用 LLM 选择合适的记忆
|
||||||
|
|
||||||
|
|
@ -165,14 +158,13 @@ class MemoryActivator:
|
||||||
try:
|
try:
|
||||||
# 构建聊天历史字符串
|
# 构建聊天历史字符串
|
||||||
obs_info_text = build_readable_messages(
|
obs_info_text = build_readable_messages(
|
||||||
chat_history_prompt,
|
chat_history,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 构建记忆信息字符串
|
# 构建记忆信息字符串
|
||||||
memory_lines = []
|
memory_lines = []
|
||||||
for memory in candidate_memories:
|
for memory in candidate_memories:
|
||||||
|
|
@ -193,18 +185,12 @@ class MemoryActivator:
|
||||||
# 获取并格式化 prompt
|
# 获取并格式化 prompt
|
||||||
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
|
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
|
||||||
formatted_prompt = prompt_template.format(
|
formatted_prompt = prompt_template.format(
|
||||||
obs_info_text=obs_info_text,
|
obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
|
||||||
target_message=target_message,
|
|
||||||
memory_info=memory_info
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 调用 LLM
|
# 调用 LLM
|
||||||
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
|
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
|
||||||
formatted_prompt,
|
formatted_prompt, temperature=0.3, max_tokens=150
|
||||||
temperature=0.3,
|
|
||||||
max_tokens=150
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
|
|
@ -221,11 +207,8 @@ class MemoryActivator:
|
||||||
# 解析为 Python 对象
|
# 解析为 Python 对象
|
||||||
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||||
|
|
||||||
# 提取 memory_ids 字段
|
# 提取 memory_ids 字段并解析逗号分隔的编号
|
||||||
memory_ids_str = result.get("memory_ids", "")
|
if memory_ids_str := result.get("memory_ids", ""):
|
||||||
|
|
||||||
# 解析逗号分隔的编号
|
|
||||||
if memory_ids_str:
|
|
||||||
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
|
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
|
||||||
# 过滤掉空字符串和无效编号
|
# 过滤掉空字符串和无效编号
|
||||||
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
|
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
|
||||||
|
|
@ -240,10 +223,9 @@ class MemoryActivator:
|
||||||
selected_memories = []
|
selected_memories = []
|
||||||
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
|
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
|
||||||
|
|
||||||
for memory_id in selected_memory_ids:
|
selected_memories = [
|
||||||
if memory_id in memory_id_to_memory:
|
memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
|
||||||
selected_memories.append(memory_id_to_memory[memory_id])
|
]
|
||||||
|
|
||||||
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
|
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
|
||||||
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
|
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
|
||||||
|
|
||||||
|
|
@ -256,5 +238,4 @@ class MemoryActivator:
|
||||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
|
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.mais4u.mai_think import mai_thinking_manager
|
from src.mais4u.mai_think import mai_thinking_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.individuality.individuality import get_individuality
|
from src.individuality.individuality import get_individuality
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
@ -352,7 +353,7 @@ class DefaultReplyer:
|
||||||
|
|
||||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||||
|
|
||||||
async def build_memory_block(self, chat_history: List[Dict[str, Any]], target: str) -> str:
|
async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
|
||||||
"""构建记忆块
|
"""构建记忆块
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -369,7 +370,7 @@ class DefaultReplyer:
|
||||||
instant_memory = None
|
instant_memory = None
|
||||||
|
|
||||||
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||||
target_message=target, chat_history_prompt=chat_history
|
target_message=target, chat_history=chat_history
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_config.memory.enable_instant_memory:
|
if global_config.memory.enable_instant_memory:
|
||||||
|
|
@ -433,7 +434,7 @@ class DefaultReplyer:
|
||||||
logger.error(f"工具信息获取失败: {e}")
|
logger.error(f"工具信息获取失败: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
def _parse_reply_target(self, target_message: Optional[str]) -> Tuple[str, str]:
|
||||||
"""解析回复目标消息
|
"""解析回复目标消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -514,7 +515,7 @@ class DefaultReplyer:
|
||||||
return name, result, duration
|
return name, result, duration
|
||||||
|
|
||||||
def build_s4u_chat_history_prompts(
|
def build_s4u_chat_history_prompts(
|
||||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
构建 s4u 风格的分离对话 prompt
|
构建 s4u 风格的分离对话 prompt
|
||||||
|
|
@ -530,16 +531,16 @@ class DefaultReplyer:
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
|
|
||||||
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
||||||
for msg_dict in message_list_before_now:
|
for msg in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
msg_user_id = str(msg.user_info.user_id)
|
||||||
reply_to = msg_dict.get("reply_to", "")
|
reply_to = msg.reply_to
|
||||||
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
||||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||||
# bot 和目标用户的对话
|
# bot 和目标用户的对话
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")
|
||||||
|
|
||||||
# 构建背景对话 prompt
|
# 构建背景对话 prompt
|
||||||
all_dialogue_prompt = ""
|
all_dialogue_prompt = ""
|
||||||
|
|
@ -574,7 +575,6 @@ class DefaultReplyer:
|
||||||
core_dialogue_prompt_str = build_readable_messages(
|
core_dialogue_prompt_str = build_readable_messages(
|
||||||
core_dialogue_list,
|
core_dialogue_list,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
truncate=True,
|
truncate=True,
|
||||||
|
|
@ -712,25 +712,20 @@ class DefaultReplyer:
|
||||||
|
|
||||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||||
|
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=global_config.chat.max_context_size * 1,
|
limit=global_config.chat.max_context_size * 1,
|
||||||
)
|
)
|
||||||
temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long]
|
|
||||||
|
|
||||||
# TODO: 修复!
|
|
||||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.33),
|
limit=int(global_config.chat.max_context_size * 0.33),
|
||||||
)
|
)
|
||||||
temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short]
|
|
||||||
|
|
||||||
chat_talking_prompt_short = build_readable_messages(
|
chat_talking_prompt_short = build_readable_messages(
|
||||||
temp_msg_list_before_short,
|
message_list_before_short,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
@ -743,7 +738,7 @@ class DefaultReplyer:
|
||||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||||
),
|
),
|
||||||
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
|
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
|
||||||
self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"),
|
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||||
),
|
),
|
||||||
|
|
@ -827,7 +822,7 @@ class DefaultReplyer:
|
||||||
|
|
||||||
# 构建分离的对话 prompt
|
# 构建分离的对话 prompt
|
||||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||||
temp_msg_list_before_long, user_id, sender
|
message_list_before_now_long, user_id, sender
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
||||||
|
|
@ -901,11 +896,8 @@ class DefaultReplyer:
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||||
)
|
)
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
|
|
||||||
chat_talking_prompt_half = build_readable_messages(
|
chat_talking_prompt_half = build_readable_messages(
|
||||||
temp_msg_list_before_now_half,
|
message_list_before_now_half,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
@ -913,7 +905,7 @@ class DefaultReplyer:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 并行执行2个构建任务
|
# 并行执行2个构建任务
|
||||||
(expression_habits_block, selected_expressions), relation_info = await asyncio.gather(
|
(expression_habits_block, _), relation_info = await asyncio.gather(
|
||||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||||
self.build_relation_info(sender, target),
|
self.build_relation_info(sender, target),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class BaseDataModel:
|
class BaseDataModel:
|
||||||
|
|
@ -7,6 +7,7 @@ class BaseDataModel:
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||||
|
# sourcery skip: assign-if-exp, reintroduce-else
|
||||||
"""
|
"""
|
||||||
将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例
|
将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例
|
||||||
递归转换为普通 dict,不修改原对象。
|
递归转换为普通 dict,不修改原对象。
|
||||||
|
|
|
||||||
|
|
@ -163,11 +163,8 @@ class ChatAction:
|
||||||
limit=15,
|
limit=15,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
tmp_msgs,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
@ -229,11 +226,8 @@ class ChatAction:
|
||||||
limit=10,
|
limit=10,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
tmp_msgs,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
|
||||||
|
|
@ -166,11 +166,9 @@ class ChatMood:
|
||||||
limit=10,
|
limit=10,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
tmp_msgs,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
@ -247,11 +245,9 @@ class ChatMood:
|
||||||
limit=5,
|
limit=5,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
tmp_msgs,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,10 @@ from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||||
from src.chat.express.expression_selector import expression_selector
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from .s4u_mood_manager import mood_manager
|
from .s4u_mood_manager import mood_manager
|
||||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
logger = get_logger("prompt")
|
logger = get_logger("prompt")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -97,12 +101,11 @@ class PromptBuilder:
|
||||||
self.activate_messages = ""
|
self.activate_messages = ""
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
|
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
|
||||||
|
|
||||||
style_habits = []
|
style_habits = []
|
||||||
|
|
||||||
# 使用从处理器传来的选中表达方式
|
# 使用从处理器传来的选中表达方式
|
||||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||||
selected_expressions ,_ = await expression_selector.select_suitable_expressions_llm(
|
selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
|
||||||
chat_stream.stream_id, chat_history, max_num=12, target_message=target
|
chat_stream.stream_id, chat_history, max_num=12, target_message=target
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -122,7 +125,6 @@ class PromptBuilder:
|
||||||
if style_habits_str.strip():
|
if style_habits_str.strip():
|
||||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||||
|
|
||||||
|
|
||||||
return expression_habits_block
|
return expression_habits_block
|
||||||
|
|
||||||
async def build_relation_info(self, chat_stream) -> str:
|
async def build_relation_info(self, chat_stream) -> str:
|
||||||
|
|
@ -148,9 +150,7 @@ class PromptBuilder:
|
||||||
person_ids.append(person_id)
|
person_ids.append(person_id)
|
||||||
|
|
||||||
# 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
|
# 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
|
||||||
relation_info_list = [
|
relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
|
||||||
Person(person_id=person_id).build_relationship() for person_id in person_ids
|
|
||||||
]
|
|
||||||
if relation_info := "".join(relation_info_list):
|
if relation_info := "".join(relation_info_list):
|
||||||
relation_prompt = await global_prompt_manager.format_prompt(
|
relation_prompt = await global_prompt_manager.format_prompt(
|
||||||
"relation_prompt", relation_info=relation_info
|
"relation_prompt", relation_info=relation_info
|
||||||
|
|
@ -176,38 +176,37 @@ class PromptBuilder:
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
|
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
|
||||||
limit=300,
|
limit=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
|
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
|
||||||
|
|
||||||
core_dialogue_list = []
|
core_dialogue_list: List[DatabaseMessages] = []
|
||||||
background_dialogue_list = []
|
background_dialogue_list: List[DatabaseMessages] = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||||
|
|
||||||
# TODO: 修复之!
|
|
||||||
for msg in message_list_before_now:
|
for msg in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
msg_user_id = str(msg.user_info.user_id)
|
msg_user_id = str(msg.user_info.user_id)
|
||||||
if msg_user_id == bot_id:
|
if msg_user_id == bot_id:
|
||||||
if msg.reply_to and talk_type == msg.reply_to:
|
if msg.reply_to and talk_type == msg.reply_to:
|
||||||
core_dialogue_list.append(msg.__dict__)
|
core_dialogue_list.append(msg)
|
||||||
elif msg.reply_to and talk_type != msg.reply_to:
|
elif msg.reply_to and talk_type != msg.reply_to:
|
||||||
background_dialogue_list.append(msg.__dict__)
|
background_dialogue_list.append(msg)
|
||||||
# else:
|
# else:
|
||||||
# background_dialogue_list.append(msg_dict)
|
# background_dialogue_list.append(msg_dict)
|
||||||
elif msg_user_id == target_user_id:
|
elif msg_user_id == target_user_id:
|
||||||
core_dialogue_list.append(msg.__dict__)
|
core_dialogue_list.append(msg)
|
||||||
else:
|
else:
|
||||||
background_dialogue_list.append(msg.__dict__)
|
background_dialogue_list.append(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
|
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
|
||||||
|
|
||||||
background_dialogue_prompt = ""
|
background_dialogue_prompt = ""
|
||||||
if background_dialogue_list:
|
if background_dialogue_list:
|
||||||
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length:]
|
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
|
||||||
background_dialogue_prompt_str = build_readable_messages(
|
background_dialogue_prompt_str = build_readable_messages(
|
||||||
context_msgs,
|
context_msgs,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
|
|
@ -217,10 +216,10 @@ class PromptBuilder:
|
||||||
|
|
||||||
core_msg_str = ""
|
core_msg_str = ""
|
||||||
if core_dialogue_list:
|
if core_dialogue_list:
|
||||||
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length:]
|
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
|
||||||
|
|
||||||
first_msg = core_dialogue_list[0]
|
first_msg = core_dialogue_list[0]
|
||||||
start_speaking_user_id = first_msg.get("user_id")
|
start_speaking_user_id = first_msg.user_info.user_id
|
||||||
if start_speaking_user_id == bot_id:
|
if start_speaking_user_id == bot_id:
|
||||||
last_speaking_user_id = bot_id
|
last_speaking_user_id = bot_id
|
||||||
msg_seg_str = "你的发言:\n"
|
msg_seg_str = "你的发言:\n"
|
||||||
|
|
@ -229,13 +228,13 @@ class PromptBuilder:
|
||||||
last_speaking_user_id = start_speaking_user_id
|
last_speaking_user_id = start_speaking_user_id
|
||||||
msg_seg_str = "对方的发言:\n"
|
msg_seg_str = "对方的发言:\n"
|
||||||
|
|
||||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
|
||||||
|
|
||||||
all_msg_seg_list = []
|
all_msg_seg_list = []
|
||||||
for msg in core_dialogue_list[1:]:
|
for msg in core_dialogue_list[1:]:
|
||||||
speaker = msg.get("user_id")
|
speaker = msg.user_info.user_id
|
||||||
if speaker == last_speaking_user_id:
|
if speaker == last_speaking_user_id:
|
||||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
|
||||||
else:
|
else:
|
||||||
msg_seg_str = f"{msg_seg_str}\n"
|
msg_seg_str = f"{msg_seg_str}\n"
|
||||||
all_msg_seg_list.append(msg_seg_str)
|
all_msg_seg_list.append(msg_seg_str)
|
||||||
|
|
@ -252,23 +251,19 @@ class PromptBuilder:
|
||||||
for msg in all_msg_seg_list:
|
for msg in all_msg_seg_list:
|
||||||
core_msg_str += msg
|
core_msg_str += msg
|
||||||
|
|
||||||
|
all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
|
||||||
all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=20,
|
limit=20,
|
||||||
)
|
)
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in all_dialogue_prompt]
|
|
||||||
all_dialogue_prompt_str = build_readable_messages(
|
all_dialogue_prompt_str = build_readable_messages(
|
||||||
tmp_msgs,
|
all_dialogue_history,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
show_pic=False,
|
show_pic=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
|
||||||
return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str
|
|
||||||
|
|
||||||
def build_gift_info(self, message: MessageRecvS4U):
|
def build_gift_info(self, message: MessageRecvS4U):
|
||||||
if message.is_gift:
|
if message.is_gift:
|
||||||
|
|
@ -283,13 +278,11 @@ class PromptBuilder:
|
||||||
super_chat_manager = get_super_chat_manager()
|
super_chat_manager = get_super_chat_manager()
|
||||||
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
|
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
|
||||||
|
|
||||||
|
|
||||||
async def build_prompt_normal(
|
async def build_prompt_normal(
|
||||||
self,
|
self,
|
||||||
message: MessageRecvS4U,
|
message: MessageRecvS4U,
|
||||||
message_txt: str,
|
message_txt: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
chat_stream = message.chat_stream
|
chat_stream = message.chat_stream
|
||||||
|
|
||||||
person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
|
person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
|
||||||
|
|
@ -303,12 +296,15 @@ class PromptBuilder:
|
||||||
else:
|
else:
|
||||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||||
|
|
||||||
|
|
||||||
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
|
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
|
||||||
self.build_relation_info(chat_stream), self.build_memory_block(message_txt), self.build_expression_habits(chat_stream, message_txt, sender_name)
|
self.build_relation_info(chat_stream),
|
||||||
|
self.build_memory_block(message_txt),
|
||||||
|
self.build_expression_habits(chat_stream, message_txt, sender_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
core_dialogue_prompt, background_dialogue_prompt,all_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
|
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
|
||||||
|
chat_stream, message
|
||||||
|
)
|
||||||
|
|
||||||
gift_info = self.build_gift_info(message)
|
gift_info = self.build_gift_info(message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,11 +99,9 @@ class ChatMood:
|
||||||
limit=int(global_config.chat.max_context_size / 3),
|
limit=int(global_config.chat.max_context_size / 3),
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
# TODO: 修复!
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
tmp_msgs,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
@ -150,11 +148,9 @@ class ChatMood:
|
||||||
limit=15,
|
limit=15,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
# TODO: 修复
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
tmp_msgs,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
read_mark=0.0,
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,21 @@
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from json_repair import repair_json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .person_info import Person
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
import random
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
import json
|
|
||||||
from json_repair import repair_json
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
import traceback
|
from .person_info import Person
|
||||||
|
|
||||||
logger = get_logger("relation")
|
logger = get_logger("relation")
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
|
|
@ -46,7 +49,6 @@ def init_prompt():
|
||||||
"attitude_to_me_prompt",
|
"attitude_to_me_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
你的名字是{bot_name},{bot_name}的别名是{alias_str}。
|
你的名字是{bot_name},{bot_name}的别名是{alias_str}。
|
||||||
|
|
@ -80,6 +82,7 @@ def init_prompt():
|
||||||
"neuroticism_prompt",
|
"neuroticism_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.relationship_llm = LLMRequest(
|
self.relationship_llm = LLMRequest(
|
||||||
|
|
@ -95,18 +98,16 @@ class RelationshipManager:
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"attitude_to_me_prompt",
|
"attitude_to_me_prompt",
|
||||||
bot_name = global_config.bot.nickname,
|
bot_name=global_config.bot.nickname,
|
||||||
alias_str = alias_str,
|
alias_str=alias_str,
|
||||||
person_name = person.person_name,
|
person_name=person.person_name,
|
||||||
nickname = person.nickname,
|
nickname=person.nickname,
|
||||||
readable_messages = readable_messages,
|
readable_messages=readable_messages,
|
||||||
current_time = current_time,
|
current_time=current_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
attitude = repair_json(attitude)
|
attitude = repair_json(attitude)
|
||||||
attitude_data = json.loads(attitude)
|
attitude_data = json.loads(attitude)
|
||||||
|
|
||||||
|
|
@ -119,10 +120,10 @@ class RelationshipManager:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
attitude_score = attitude_data["attitude"]
|
attitude_score = attitude_data["attitude"]
|
||||||
confidence = pow(attitude_data["confidence"],2)
|
confidence = pow(attitude_data["confidence"], 2)
|
||||||
|
|
||||||
new_confidence = total_confidence + confidence
|
new_confidence = total_confidence + confidence
|
||||||
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence
|
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence) / new_confidence
|
||||||
|
|
||||||
person.attitude_to_me = new_attitude_score
|
person.attitude_to_me = new_attitude_score
|
||||||
person.attitude_to_me_confidence = new_confidence
|
person.attitude_to_me_confidence = new_confidence
|
||||||
|
|
@ -138,21 +139,19 @@ class RelationshipManager:
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"neuroticism_prompt",
|
"neuroticism_prompt",
|
||||||
bot_name = global_config.bot.nickname,
|
bot_name=global_config.bot.nickname,
|
||||||
alias_str = alias_str,
|
alias_str=alias_str,
|
||||||
person_name = person.person_name,
|
person_name=person.person_name,
|
||||||
nickname = person.nickname,
|
nickname=person.nickname,
|
||||||
readable_messages = readable_messages,
|
readable_messages=readable_messages,
|
||||||
current_time = current_time,
|
current_time=current_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
|
||||||
# logger.info(f"prompt: {prompt}")
|
# logger.info(f"prompt: {prompt}")
|
||||||
# logger.info(f"neuroticism: {neuroticism}")
|
# logger.info(f"neuroticism: {neuroticism}")
|
||||||
|
|
||||||
|
|
||||||
neuroticism = repair_json(neuroticism)
|
neuroticism = repair_json(neuroticism)
|
||||||
neuroticism_data = json.loads(neuroticism)
|
neuroticism_data = json.loads(neuroticism)
|
||||||
|
|
||||||
|
|
@ -165,19 +164,20 @@ class RelationshipManager:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
neuroticism_score = neuroticism_data["neuroticism"]
|
neuroticism_score = neuroticism_data["neuroticism"]
|
||||||
confidence = pow(neuroticism_data["confidence"],2)
|
confidence = pow(neuroticism_data["confidence"], 2)
|
||||||
|
|
||||||
new_confidence = total_confidence + confidence
|
new_confidence = total_confidence + confidence
|
||||||
|
|
||||||
new_neuroticism_score = (current_neuroticism_score * total_confidence + neuroticism_score * confidence)/new_confidence
|
new_neuroticism_score = (
|
||||||
|
current_neuroticism_score * total_confidence + neuroticism_score * confidence
|
||||||
|
) / new_confidence
|
||||||
|
|
||||||
person.neuroticism = new_neuroticism_score
|
person.neuroticism = new_neuroticism_score
|
||||||
person.neuroticism_confidence = new_confidence
|
person.neuroticism_confidence = new_confidence
|
||||||
|
|
||||||
return person
|
return person
|
||||||
|
|
||||||
|
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
|
||||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
|
|
||||||
"""更新用户印象
|
"""更新用户印象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -202,12 +202,11 @@ class RelationshipManager:
|
||||||
|
|
||||||
# 遍历消息,构建映射
|
# 遍历消息,构建映射
|
||||||
for msg in user_messages:
|
for msg in user_messages:
|
||||||
if msg.get("user_id") == "system":
|
if msg.user_info.user_id == "system":
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
user_id = msg.user_info.user_id
|
||||||
user_id = msg.get("user_id")
|
platform = msg.chat_info.platform
|
||||||
platform = msg.get("chat_info_platform")
|
|
||||||
assert isinstance(user_id, str) and isinstance(platform, str)
|
assert isinstance(user_id, str) and isinstance(platform, str)
|
||||||
msg_person = Person(user_id=user_id, platform=platform)
|
msg_person = Person(user_id=user_id, platform=platform)
|
||||||
|
|
||||||
|
|
@ -244,7 +243,7 @@ class RelationshipManager:
|
||||||
readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")
|
readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")
|
||||||
|
|
||||||
# await self.get_points(
|
# await self.get_points(
|
||||||
# readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person)
|
# readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person)
|
||||||
await self.get_attitude_to_me(readable_messages=readable_messages, timestamp=timestamp, person=person)
|
await self.get_attitude_to_me(readable_messages=readable_messages, timestamp=timestamp, person=person)
|
||||||
await self.get_neuroticism(readable_messages=readable_messages, timestamp=timestamp, person=person)
|
await self.get_neuroticism(readable_messages=readable_messages, timestamp=timestamp, person=person)
|
||||||
|
|
||||||
|
|
@ -253,9 +252,6 @@ class RelationshipManager:
|
||||||
|
|
||||||
person.sync_to_database()
|
person.sync_to_database()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
|
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
|
||||||
"""计算基于时间的权重系数"""
|
"""计算基于时间的权重系数"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -280,6 +276,7 @@ class RelationshipManager:
|
||||||
logger.error(f"计算时间权重失败: {e}")
|
logger.error(f"计算时间权重失败: {e}")
|
||||||
return 0.5 # 发生错误时返回中等权重
|
return 0.5 # 发生错误时返回中等权重
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
relationship_manager = None
|
relationship_manager = None
|
||||||
|
|
@ -290,4 +287,3 @@ def get_relationship_manager():
|
||||||
if relationship_manager is None:
|
if relationship_manager is None:
|
||||||
relationship_manager = RelationshipManager()
|
relationship_manager = RelationshipManager()
|
||||||
return relationship_manager
|
return relationship_manager
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -412,7 +412,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
|
||||||
|
|
||||||
|
|
||||||
def build_readable_messages_to_str(
|
def build_readable_messages_to_str(
|
||||||
messages: List[Dict[str, Any]],
|
messages: List[DatabaseMessages],
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
timestamp_mode: str = "relative",
|
timestamp_mode: str = "relative",
|
||||||
read_mark: float = 0.0,
|
read_mark: float = 0.0,
|
||||||
|
|
@ -440,7 +440,7 @@ def build_readable_messages_to_str(
|
||||||
|
|
||||||
|
|
||||||
async def build_readable_messages_with_details(
|
async def build_readable_messages_with_details(
|
||||||
messages: List[Dict[str, Any]],
|
messages: List[DatabaseMessages],
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
timestamp_mode: str = "relative",
|
timestamp_mode: str = "relative",
|
||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,14 @@ import random
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
# 导入新插件系统
|
# 导入新插件系统
|
||||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
from src.plugin_system import BaseAction, ActionActivationType
|
||||||
|
|
||||||
# 导入依赖的系统组件
|
# 导入依赖的系统组件
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
# 导入API模块 - 标准Python包方式
|
# 导入API模块 - 标准Python包方式
|
||||||
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
||||||
|
|
||||||
# NoReplyAction已集成到heartFC_chat.py中,不再需要导入
|
# NoReplyAction已集成到heartFC_chat.py中,不再需要导入
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
@ -84,11 +85,8 @@ class EmojiAction(BaseAction):
|
||||||
messages_text = ""
|
messages_text = ""
|
||||||
if recent_messages:
|
if recent_messages:
|
||||||
# 使用message_api构建可读的消息字符串
|
# 使用message_api构建可读的消息字符串
|
||||||
# TODO: 修复
|
|
||||||
from src.common.data_models import temporarily_transform_class_to_dict
|
|
||||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in recent_messages]
|
|
||||||
messages_text = message_api.build_readable_messages(
|
messages_text = message_api.build_readable_messages(
|
||||||
messages=tmp_msgs,
|
messages=recent_messages,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
truncate=False,
|
truncate=False,
|
||||||
show_actions=False,
|
show_actions=False,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue