数据库的信息重构为dataclass

pull/1185/head
UnCLAS-Prommer 2025-08-17 17:11:32 +08:00
parent d74beef4b4
commit 3481234d2b
No known key found for this signature in database
18 changed files with 243 additions and 206 deletions

View File

@ -285,10 +285,11 @@ class HeartFChatting:
filter_mai=True, filter_mai=True,
filter_command=True, filter_command=True,
) )
# TODO: 修复!
temp_recent_messages_dict = [msg.__dict__ for msg in recent_messages_dict]
# 统一的消息处理逻辑 # 统一的消息处理逻辑
should_process,interest_value = await self._should_process_messages(recent_messages_dict) should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict)
if should_process: if should_process:
self.last_read_time = time.time() self.last_read_time = time.time()
await self._observe(interest_value = interest_value) await self._observe(interest_value = interest_value)

View File

@ -346,13 +346,15 @@ class ExpressionLearner:
current_time = time.time() current_time = time.time()
# 获取上次学习时间 # 获取上次学习时间
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive( random_msg_temp = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_learning_time, timestamp_start=self.last_learning_time,
timestamp_end=current_time, timestamp_end=current_time,
limit=num, limit=num,
) )
# TODO: 修复!
random_msg: Optional[List[Dict[str, Any]]] = [msg.__dict__ for msg in random_msg_temp] if random_msg_temp else None
# print(random_msg) # print(random_msg)
if not random_msg or random_msg == []: if not random_msg or random_msg == []:
return None return None

View File

@ -16,6 +16,7 @@ from rich.traceback import install
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import GraphNodes, GraphEdges # Peewee Models导入 from src.common.database.database_model import GraphNodes, GraphEdges # Peewee Models导入
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
@ -1495,13 +1496,13 @@ class MemoryBuilder:
timestamp_end=current_time, timestamp_end=current_time,
limit=threshold, limit=threshold,
) )
tmp_msg = [msg.__dict__ for msg in messages] if messages else []
if messages: if messages:
# 更新最后处理时间 # 更新最后处理时间
self.last_processed_time = current_time self.last_processed_time = current_time
self.last_update_time = current_time self.last_update_time = current_time
return messages or [] return tmp_msg or []

View File

@ -70,8 +70,10 @@ class ActionModifier:
timestamp=time.time(), timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 10), limit=min(int(global_config.chat.max_context_size * 0.33), 10),
) )
# TODO: 修复!
temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half]
chat_content = build_readable_messages( chat_content = build_readable_messages(
message_list_before_now_half, temp_msg_list_before_now_half,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="relative", timestamp_mode="relative",

View File

@ -95,6 +95,7 @@ class ActionPlanner:
self.max_plan_retries = 3 self.max_plan_retries = 3
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
# sourcery skip: use-next
""" """
根据message_id从message_id_list中查找对应的原始消息 根据message_id从message_id_list中查找对应的原始消息
@ -120,10 +121,7 @@ class ActionPlanner:
Returns: Returns:
最新的消息字典如果列表为空则返回None 最新的消息字典如果列表为空则返回None
""" """
if not message_id_list: return message_id_list[-1].get("message") if message_id_list else None
return None
# 假设消息列表是按时间顺序排列的,最后一个是最新的
return message_id_list[-1].get("message")
async def plan( async def plan(
self, self,
@ -208,22 +206,17 @@ class ActionPlanner:
if target_message is None: if target_message is None:
self.plan_retry_count += 1 self.plan_retry_count += 1
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}") logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
# 仍有重试次数
# 如果连续三次plan均为None输出error并选取最新消息 if self.plan_retry_count < self.max_plan_retries:
if self.plan_retry_count >= self.max_plan_retries:
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message")
target_message = self.get_latest_message(message_id_list)
self.plan_retry_count = 0 # 重置计数器
else:
# 递归重新plan # 递归重新plan
return await self.plan(mode, loop_start_time, available_actions) return await self.plan(mode, loop_start_time, available_actions)
else: logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message")
# 成功获取到target_message重置计数器 target_message = self.get_latest_message(message_id_list)
self.plan_retry_count = 0 self.plan_retry_count = 0 # 重置计数器
else: else:
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id") logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
if action != "no_reply" and action != "reply" and action not in current_available_actions: if action != "no_reply" and action != "reply" and action not in current_available_actions:
logger.warning( logger.warning(
@ -247,28 +240,27 @@ class ActionPlanner:
is_parallel = False is_parallel = False
if mode == ChatMode.NORMAL and action in current_available_actions: if mode == ChatMode.NORMAL and action in current_available_actions:
is_parallel = current_available_actions[action].parallel_action is_parallel = current_available_actions[action].parallel_action
action_data["loop_start_time"] = loop_start_time action_data["loop_start_time"] = loop_start_time
actions = [] actions = [
{
# 1. 添加Planner取得的动作 "action_type": action,
actions.append({ "reasoning": reasoning,
"action_type": action, "action_data": action_data,
"reasoning": reasoning, "action_message": target_message,
"action_data": action_data, "available_actions": available_actions,
"action_message": target_message, }
"available_actions": available_actions # 添加这个字段 ]
})
if action != "reply" and is_parallel: if action != "reply" and is_parallel:
actions.append({ actions.append({
"action_type": "reply", "action_type": "reply",
"action_message": target_message, "action_message": target_message,
"available_actions": available_actions "available_actions": available_actions
}) })
return actions,target_message return actions,target_message
@ -288,9 +280,10 @@ class ActionPlanner:
timestamp=time.time(), timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6), limit=int(global_config.chat.max_context_size * 0.6),
) )
# TODO: 修复!
temp_msg_list_before_now = [msg.__dict__ for msg in message_list_before_now]
chat_content_block, message_id_list = build_readable_messages_with_id( chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now, messages=temp_msg_list_before_now,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=self.last_obs_time_mark, read_mark=self.last_obs_time_mark,
truncate=True, truncate=True,

View File

@ -91,7 +91,7 @@ def init_prompt():
""", """,
"replyer_prompt", "replyer_prompt",
) )
Prompt( Prompt(
""" """
{expression_habits_block}{tool_info_block} {expression_habits_block}{tool_info_block}
@ -116,7 +116,6 @@ def init_prompt():
""", """,
"replyer_self_prompt", "replyer_self_prompt",
) )
Prompt( Prompt(
""" """
@ -179,7 +178,7 @@ class DefaultReplyer:
Returns: Returns:
Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt) Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt)
""" """
prompt = None prompt = None
selected_expressions = None selected_expressions = None
if available_actions is None: if available_actions is None:
@ -187,7 +186,7 @@ class DefaultReplyer:
try: try:
# 3. 构建 Prompt # 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt,selected_expressions = await self.build_prompt_reply_context( prompt, selected_expressions = await self.build_prompt_reply_context(
extra_info=extra_info, extra_info=extra_info,
available_actions=available_actions, available_actions=available_actions,
choosen_actions=choosen_actions, choosen_actions=choosen_actions,
@ -294,12 +293,12 @@ class DefaultReplyer:
async def build_relation_info(self, sender: str, target: str): async def build_relation_info(self, sender: str, target: str):
if not global_config.relationship.enable_relationship: if not global_config.relationship.enable_relationship:
return "" return ""
if sender == global_config.bot.nickname: if sender == global_config.bot.nickname:
return "" return ""
# 获取用户ID # 获取用户ID
person = Person(person_name = sender) person = Person(person_name=sender)
if not is_person_known(person_name=sender): if not is_person_known(person_name=sender):
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取") logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。" return f"你完全不认识{sender}不理解ta的相关信息。"
@ -307,6 +306,7 @@ class DefaultReplyer:
return person.build_relationship(points_num=5) return person.build_relationship(points_num=5)
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]: async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块 """构建表达习惯块
Args: Args:
@ -359,7 +359,7 @@ class DefaultReplyer:
Returns: Returns:
str: 记忆信息字符串 str: 记忆信息字符串
""" """
if not global_config.memory.enable_memory: if not global_config.memory.enable_memory:
return "" return ""
@ -368,7 +368,6 @@ class DefaultReplyer:
running_memories = await self.memory_activator.activate_memory_with_chat_history( running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history_prompt=chat_history target_message=target, chat_history_prompt=chat_history
) )
if global_config.memory.enable_instant_memory: if global_config.memory.enable_instant_memory:
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history)) asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
@ -379,10 +378,9 @@ class DefaultReplyer:
if not running_memories: if not running_memories:
return "" return ""
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories: for running_memory in running_memories:
keywords,content = running_memory keywords, content = running_memory
memory_str += f"- {keywords}{content}\n" memory_str += f"- {keywords}{content}\n"
if instant_memory: if instant_memory:
@ -405,7 +403,6 @@ class DefaultReplyer:
if not enable_tool: if not enable_tool:
return "" return ""
try: try:
# 使用工具执行器获取信息 # 使用工具执行器获取信息
tool_results, _, _ = await self.tool_executor.execute_from_chat_message( tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
@ -559,16 +556,18 @@ class DefaultReplyer:
# 检查最新五条消息中是否包含bot自己说的消息 # 检查最新五条消息中是否包含bot自己说的消息
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
# logger.info(f"最新五条消息:{latest_5_messages}") # logger.info(f"最新五条消息:{latest_5_messages}")
# logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}") # logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}")
# 如果最新五条消息中不包含bot的消息则返回空字符串 # 如果最新五条消息中不包含bot的消息则返回空字符串
if not has_bot_message: if not has_bot_message:
core_dialogue_prompt = "" core_dialogue_prompt = ""
else: else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :] # 限制消息数量 core_dialogue_list = core_dialogue_list[
-int(global_config.chat.max_context_size * 0.6) :
] # 限制消息数量
core_dialogue_prompt_str = build_readable_messages( core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list, core_dialogue_list,
replace_bot_name=True, replace_bot_name=True,
@ -630,12 +629,12 @@ class DefaultReplyer:
mai_think.sender = sender mai_think.sender = sender
mai_think.target = target mai_think.target = target
return mai_think return mai_think
async def build_actions_prompt(
async def build_actions_prompt(self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None) -> str: self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None
"""构建动作提示 ) -> str:
""" """构建动作提示"""
action_descriptions = "" action_descriptions = ""
if available_actions: if available_actions:
action_descriptions = "你可以做以下这些动作:\n" action_descriptions = "你可以做以下这些动作:\n"
@ -643,25 +642,24 @@ class DefaultReplyer:
action_description = action_info.description action_description = action_info.description
action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += f"- {action_name}: {action_description}\n"
action_descriptions += "\n" action_descriptions += "\n"
choosen_action_descriptions = "" choosen_action_descriptions = ""
if choosen_actions: if choosen_actions:
for action in choosen_actions: for action in choosen_actions:
action_name = action.get('action_type', 'unknown_action') action_name = action.get("action_type", "unknown_action")
if action_name =="reply": if action_name == "reply":
continue continue
action_description = action.get('reason', '无描述') action_description = action.get("reason", "无描述")
reasoning = action.get('reasoning', '无原因') reasoning = action.get("reasoning", "无原因")
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n" choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
if choosen_action_descriptions: if choosen_action_descriptions:
action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n" action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n"
action_descriptions += choosen_action_descriptions action_descriptions += choosen_action_descriptions
return action_descriptions return action_descriptions
async def build_prompt_reply_context( async def build_prompt_reply_context(
self, self,
extra_info: str = "", extra_info: str = "",
@ -691,41 +689,44 @@ class DefaultReplyer:
chat_id = chat_stream.stream_id chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info) is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform platform = chat_stream.platform
if reply_message: if reply_message:
user_id = reply_message.get("user_id","") user_id = reply_message.get("user_id", "")
person = Person(platform=platform, user_id=user_id) person = Person(platform=platform, user_id=user_id)
person_name = person.person_name or user_id person_name = person.person_name or user_id
sender = person_name sender = person_name
target = reply_message.get('processed_plain_text') target = reply_message.get("processed_plain_text")
else: else:
person_name = "用户" person_name = "用户"
sender = "用户" sender = "用户"
target = "消息" target = "消息"
if global_config.mood.enable_mood: if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id) chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state mood_prompt = chat_mood.mood_state
else: else:
mood_prompt = "" mood_prompt = ""
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
# TODO: 修复!
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp=time.time(), timestamp=time.time(),
limit=global_config.chat.max_context_size * 1, limit=global_config.chat.max_context_size * 1,
) )
temp_msg_list_before_long = [msg.__dict__ for msg in message_list_before_now_long]
# TODO: 修复!
message_list_before_short = get_raw_msg_before_timestamp_with_chat( message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp=time.time(), timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33), limit=int(global_config.chat.max_context_size * 0.33),
) )
temp_msg_list_before_short = [msg.__dict__ for msg in message_list_before_short]
chat_talking_prompt_short = build_readable_messages( chat_talking_prompt_short = build_readable_messages(
message_list_before_short, temp_msg_list_before_short,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="relative", timestamp_mode="relative",
@ -739,12 +740,12 @@ class DefaultReplyer:
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
), ),
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"), self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"), self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"),
self._time_and_run_task( self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
), ),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"), self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions,choosen_actions), "actions_info"), self._time_and_run_task(self.build_actions_prompt(available_actions, choosen_actions), "actions_info"),
) )
# 任务名称中英文映射 # 任务名称中英文映射
@ -760,7 +761,7 @@ class DefaultReplyer:
# 处理结果 # 处理结果
timing_logs = [] timing_logs = []
results_dict = {} results_dict = {}
almost_zero_str = "" almost_zero_str = ""
for name, result, duration in task_results: for name, result, duration in task_results:
results_dict[name] = result results_dict[name] = result
@ -768,7 +769,7 @@ class DefaultReplyer:
if duration < 0.01: if duration < 0.01:
almost_zero_str += f"{chinese_name}," almost_zero_str += f"{chinese_name},"
continue continue
timing_logs.append(f"{chinese_name}: {duration:.1f}s") timing_logs.append(f"{chinese_name}: {duration:.1f}s")
if duration > 8: if duration > 8:
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s请使用更快的模型") logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s请使用更快的模型")
@ -791,9 +792,7 @@ class DefaultReplyer:
identity_block = await get_individuality().get_personality_block() identity_block = await get_individuality().get_personality_block()
moderation_prompt_block = ( moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
)
if sender: if sender:
if is_group_chat: if is_group_chat:
@ -801,7 +800,9 @@ class DefaultReplyer:
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。原因是{reply_reason}" f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。原因是{reply_reason}"
) )
else: # private chat else: # private chat
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。原因是{reply_reason}" reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。原因是{reply_reason}"
)
else: else:
reply_target_block = "" reply_target_block = ""
@ -821,10 +822,9 @@ class DefaultReplyer:
# "chat_target_private2", sender_name=chat_target_name # "chat_target_private2", sender_name=chat_target_name
# ) # )
# 构建分离的对话 prompt # 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts( core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
message_list_before_now_long, user_id, sender temp_msg_list_before_long, user_id, sender
) )
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform: if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
@ -846,7 +846,7 @@ class DefaultReplyer:
reply_style=global_config.personality.reply_style, reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt, keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block, moderation_prompt=moderation_prompt_block,
),selected_expressions ), selected_expressions
else: else:
return await global_prompt_manager.format_prompt( return await global_prompt_manager.format_prompt(
"replyer_prompt", "replyer_prompt",
@ -867,7 +867,7 @@ class DefaultReplyer:
reply_style=global_config.personality.reply_style, reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt, keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block, moderation_prompt=moderation_prompt_block,
),selected_expressions ), selected_expressions
async def build_prompt_rewrite_context( async def build_prompt_rewrite_context(
self, self,
@ -898,8 +898,10 @@ class DefaultReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15), limit=min(int(global_config.chat.max_context_size * 0.33), 15),
) )
# TODO: 修复!
temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half]
chat_talking_prompt_half = build_readable_messages( chat_talking_prompt_half = build_readable_messages(
message_list_before_now_half, temp_msg_list_before_now_half,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="relative", timestamp_mode="relative",
@ -912,7 +914,6 @@ class DefaultReplyer:
self.build_expression_habits(chat_talking_prompt_half, target), self.build_expression_habits(chat_talking_prompt_half, target),
self.build_relation_info(sender, target), self.build_relation_info(sender, target),
) )
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target) keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
@ -1024,7 +1025,9 @@ class DefaultReplyer:
else: else:
logger.debug(f"\n{prompt}\n") logger.debug(f"\n{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(prompt) content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
logger.debug(f"replyer生成内容: {content}") logger.debug(f"replyer生成内容: {content}")
return content, reasoning_content, model_name, tool_calls return content, reasoning_content, model_name, tool_calls
@ -1034,7 +1037,6 @@ class DefaultReplyer:
start_time = time.time() start_time = time.time()
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识 # 从LPMM知识库获取知识
try: try:

View File

@ -7,9 +7,10 @@ from rich.traceback import install
from src.config.config import global_config from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import ActionRecords from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images from src.common.database.database_model import Images
from src.person_info.person_info import Person,get_person_id from src.person_info.person_info import Person, get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
install(extra_lines=3) install(extra_lines=3)
@ -35,6 +36,7 @@ def replace_user_references_sync(
str: 处理后的内容字符串 str: 处理后的内容字符串
""" """
if name_resolver is None: if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str: def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己 # 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
@ -108,6 +110,7 @@ async def replace_user_references_async(
str: 处理后的内容字符串 str: 处理后的内容字符串
""" """
if name_resolver is None: if name_resolver is None:
async def default_resolver(platform: str, user_id: str) -> str: async def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己 # 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
@ -161,9 +164,7 @@ async def replace_user_references_async(
return content return content
def get_raw_msg_by_timestamp( def get_raw_msg_by_timestamp(timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"):
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
""" """
获取从指定时间戳到指定时间戳的消息按时间升序排序返回消息列表 获取从指定时间戳到指定时间戳的消息按时间升序排序返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
@ -183,7 +184,7 @@ def get_raw_msg_by_timestamp_with_chat(
limit_mode: str = "latest", limit_mode: str = "latest",
filter_bot=False, filter_bot=False,
filter_command=False, filter_command=False,
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 """获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest' limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest'
@ -209,7 +210,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
filter_bot=False, filter_bot=False,
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表 """获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest' limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest'
@ -218,7 +219,6 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
# 只有当 limit 为 0 时才应用外部 sort # 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages # 直接将 limit_mode 传递给 find_messages
return find_messages( return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
) )
@ -231,7 +231,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
person_ids: List[str], person_ids: List[str],
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 """获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest' limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest'
@ -302,7 +302,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
def get_raw_msg_by_timestamp_random( def get_raw_msg_by_timestamp_random(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息 先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
""" """
@ -312,15 +312,15 @@ def get_raw_msg_by_timestamp_random(
return [] return []
# 随机选一条 # 随机选一条
msg = random.choice(all_msgs) msg = random.choice(all_msgs)
chat_id = msg["chat_id"] chat_id = msg.chat_id
timestamp_start = msg["time"] timestamp_start = msg.time
# 用 chat_id 获取该聊天在指定时间戳范围内的消息 # 用 chat_id 获取该聊天在指定时间戳范围内的消息
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
def get_raw_msg_by_timestamp_with_users( def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest' limit_mode: limit > 0 时生效 'earliest' 表示获取最早的记录 'latest' 表示获取最新的记录默认为 'latest'
@ -331,7 +331,7 @@ def get_raw_msg_by_timestamp_with_users(
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表 """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
""" """
@ -340,7 +340,7 @@ def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表 """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
""" """
@ -349,7 +349,7 @@ def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表 """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
""" """

View File

@ -3,13 +3,15 @@ import re
import string import string
import time import time
import jieba import jieba
import json
import ast
import numpy as np import numpy as np
from collections import Counter from collections import Counter
from maim_message import UserInfo
from typing import Optional, Tuple, Dict, List, Any from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
@ -130,22 +132,29 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
return [] return []
who_chat_in_group = [] who_chat_in_group = []
for msg_db_data in recent_messages: for db_msg in recent_messages:
user_info = UserInfo.from_dict( # user_info = UserInfo.from_dict(
{ # {
"platform": msg_db_data["user_platform"], # "platform": msg_db_data["user_platform"],
"user_id": msg_db_data["user_id"], # "user_id": msg_db_data["user_id"],
"user_nickname": msg_db_data["user_nickname"], # "user_nickname": msg_db_data["user_nickname"],
"user_cardname": msg_db_data.get("user_cardname", ""), # "user_cardname": msg_db_data.get("user_cardname", ""),
} # }
) # )
# if (
# (user_info.platform, user_info.user_id) != sender
# and user_info.user_id != global_config.bot.qq_account
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
# and len(who_chat_in_group) < 5
# ): # 排除重复排除消息发送者排除bot限制加载的关系数目
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
if ( if (
(user_info.platform, user_info.user_id) != sender (db_msg.user_info.platform, db_msg.user_info.user_id) != sender
and user_info.user_id != global_config.bot.qq_account and db_msg.user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5 and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目 ): # 排除重复排除消息发送者排除bot限制加载的关系数目
who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname)) who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname))
return who_chat_in_group return who_chat_in_group
@ -555,7 +564,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# 获取消息内容计算总长度 # 获取消息内容计算总长度
messages = find_messages(message_filter=filter_query) messages = find_messages(message_filter=filter_query)
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages) total_length = sum(len(msg.processed_plain_text or "") for msg in messages)
return count, total_length return count, total_length
@ -628,41 +637,34 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
user_id: str = user_info.user_id # type: ignore user_id: str = user_info.user_id # type: ignore
# Initialize target_info with basic info # Initialize target_info with basic info
target_info = { target_info = TargetPersonInfo(
"platform": platform, platform=platform,
"user_id": user_id, user_id=user_id,
"user_nickname": user_info.user_nickname, user_nickname=user_info.user_nickname, # type: ignore
"person_id": None, person_id=None,
"person_name": None, person_name=None
} )
# Try to fetch person info # Try to fetch person info
try: try:
# Assume get_person_id is sync (as per original code), keep using to_thread
person = Person(platform=platform, user_id=user_id) person = Person(platform=platform, user_id=user_id)
if not person.is_known: if not person.is_known:
logger.warning(f"用户 {user_info.user_nickname} 尚未认识") logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
# 如果用户尚未认识则返回False和None # 如果用户尚未认识则返回False和None
return False, None return False, None
person_id = person.person_id if person.person_id:
person_name = None target_info.person_id = person.person_id
if person_id: target_info.person_name = person.person_name
# get_value is async, so await it directly
person_name = person.person_name
target_info["person_id"] = person_id
target_info["person_name"] = person_name
except Exception as person_e: except Exception as person_e:
logger.warning( logger.warning(
f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}" f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
) )
chat_target_info = target_info chat_target_info = target_info.__dict__
else: else:
logger.warning(f"无法获取 chat_stream for {chat_id} in utils") logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
except Exception as e: except Exception as e:
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True) logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
# Keep defaults on error
return is_group_chat, chat_target_info return is_group_chat, chat_target_info
@ -771,6 +773,7 @@ def assign_message_ids_flexible(
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}] # # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]: def parse_keywords_string(keywords_input) -> list[str]:
# sourcery skip: use-contextlib-suppress
""" """
统一的关键词解析函数支持多种格式的关键词字符串解析 统一的关键词解析函数支持多种格式的关键词字符串解析
@ -802,7 +805,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
try: try:
# 尝试作为JSON对象解析支持 {"keywords": [...]} 格式) # 尝试作为JSON对象解析支持 {"keywords": [...]} 格式)
import json
json_data = json.loads(keywords_str) json_data = json.loads(keywords_str)
if isinstance(json_data, dict) and "keywords" in json_data: if isinstance(json_data, dict) and "keywords" in json_data:
keywords_list = json_data["keywords"] keywords_list = json_data["keywords"]
@ -816,7 +818,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
try: try:
# 尝试使用 ast.literal_eval 解析支持Python字面量格式 # 尝试使用 ast.literal_eval 解析支持Python字面量格式
import ast
parsed = ast.literal_eval(keywords_str) parsed = ast.literal_eval(keywords_str)
if isinstance(parsed, list): if isinstance(parsed, list):
return [str(k).strip() for k in parsed if str(k).strip()] return [str(k).strip() for k in parsed if str(k).strip()]

View File

@ -1,13 +1,10 @@
from enum import Enum from typing import Optional
from typing import Optional, Union, Dict, Any, Tuple, List
from dataclasses import dataclass, field from dataclasses import dataclass, field
@dataclass @dataclass
class DatabaseUserInfo: class DatabaseUserInfo:
user_platform: str = field(default_factory=str) platform: str = field(default_factory=str)
user_id: str = field(default_factory=str) user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str) user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None user_cardname: Optional[str] = None
@ -84,17 +81,21 @@ class DatabaseMessages:
user_id=self.user_id, user_id=self.user_id,
user_nickname=self.user_nickname, user_nickname=self.user_nickname,
user_cardname=self.user_cardname, user_cardname=self.user_cardname,
user_platform=self.user_platform, platform=self.user_platform,
) )
if not (self.chat_info_group_id and self.chat_info_group_name): if self.chat_info_group_id and self.chat_info_group_name:
self.group_info = None self.group_info = DatabaseGroupInfo(
group_id=self.chat_info_group_id,
group_name=self.chat_info_group_name,
group_platform=self.chat_info_group_platform,
)
chat_user_info = DatabaseUserInfo( chat_user_info = DatabaseUserInfo(
user_id=self.chat_info_user_id, user_id=self.chat_info_user_id,
user_nickname=self.chat_info_user_nickname, user_nickname=self.chat_info_user_nickname,
user_cardname=self.chat_info_user_cardname, user_cardname=self.chat_info_user_cardname,
user_platform=self.chat_info_user_platform, platform=self.chat_info_user_platform,
) )
self.chat_info = DatabaseChatInfo( self.chat_info = DatabaseChatInfo(
stream_id=self.chat_info_stream_id, stream_id=self.chat_info_stream_id,

View File

@ -0,0 +1,10 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TargetPersonInfo:
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
person_id: Optional[str] = None
person_name: Optional[str] = None

View File

@ -2,19 +2,20 @@ import traceback
from typing import List, Any, Optional from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入 from peewee import Model # 添加 Peewee Model 导入
from src.config.config import global_config
from src.config.config import global_config
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import Messages from src.common.database.database_model import Messages
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def _model_to_dict(model_instance: Model) -> dict[str, Any]: def _model_to_instance(model_instance: Model) -> DatabaseMessages:
""" """
Peewee 模型实例转换为字典 Peewee 模型实例转换为字典
""" """
return model_instance.__data__ return DatabaseMessages(**model_instance.__data__)
def find_messages( def find_messages(
@ -24,7 +25,7 @@ def find_messages(
limit_mode: str = "latest", limit_mode: str = "latest",
filter_bot=False, filter_bot=False,
filter_command=False, filter_command=False,
) -> List[dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
根据提供的过滤器排序和限制条件查找消息 根据提供的过滤器排序和限制条件查找消息
@ -112,7 +113,7 @@ def find_messages(
query = query.order_by(*peewee_sort_terms) query = query.order_by(*peewee_sort_terms)
peewee_results = list(query) peewee_results = list(query)
return [_model_to_dict(msg) for msg in peewee_results] return [_model_to_instance(msg) for msg in peewee_results]
except Exception as e: except Exception as e:
log_message = ( log_message = (
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"

View File

@ -163,8 +163,10 @@ class ChatAction:
limit=15, limit=15,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
message_list_before_now, tmp_msgs,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
@ -227,8 +229,10 @@ class ChatAction:
limit=10, limit=10,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
message_list_before_now, tmp_msgs,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",

View File

@ -166,8 +166,10 @@ class ChatMood:
limit=10, limit=10,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
message_list_before_now, tmp_msgs,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
@ -245,8 +247,10 @@ class ChatMood:
limit=5, limit=5,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
message_list_before_now, tmp_msgs,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",

View File

@ -187,22 +187,23 @@ class PromptBuilder:
bot_id = str(global_config.bot.qq_account) bot_id = str(global_config.bot.qq_account)
target_user_id = str(message.chat_stream.user_info.user_id) target_user_id = str(message.chat_stream.user_info.user_id)
for msg_dict in message_list_before_now: # TODO: 修复之!
for msg in message_list_before_now:
try: try:
msg_user_id = str(msg_dict.get("user_id")) msg_user_id = str(msg.user_info.user_id)
if msg_user_id == bot_id: if msg_user_id == bot_id:
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"): if msg.reply_to and talk_type == msg.reply_to:
core_dialogue_list.append(msg_dict) core_dialogue_list.append(msg.__dict__)
elif msg_dict.get("reply_to") and talk_type != msg_dict.get("reply_to"): elif msg.reply_to and talk_type != msg.reply_to:
background_dialogue_list.append(msg_dict) background_dialogue_list.append(msg.__dict__)
# else: # else:
# background_dialogue_list.append(msg_dict) # background_dialogue_list.append(msg_dict)
elif msg_user_id == target_user_id: elif msg_user_id == target_user_id:
core_dialogue_list.append(msg_dict) core_dialogue_list.append(msg.__dict__)
else: else:
background_dialogue_list.append(msg_dict) background_dialogue_list.append(msg.__dict__)
except Exception as e: except Exception as e:
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}") logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
background_dialogue_prompt = "" background_dialogue_prompt = ""
if background_dialogue_list: if background_dialogue_list:
@ -257,8 +258,10 @@ class PromptBuilder:
timestamp=time.time(), timestamp=time.time(),
limit=20, limit=20,
) )
# TODO: 修复!
tmp_msgs = [msg.__dict__ for msg in all_dialogue_prompt]
all_dialogue_prompt_str = build_readable_messages( all_dialogue_prompt_str = build_readable_messages(
all_dialogue_prompt, tmp_msgs,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
show_pic=False, show_pic=False,
) )

View File

@ -99,8 +99,10 @@ class ChatMood:
limit=int(global_config.chat.max_context_size / 3), limit=int(global_config.chat.max_context_size / 3),
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
message_list_before_now, tmp_msgs,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
@ -148,8 +150,10 @@ class ChatMood:
limit=15, limit=15,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
message_list_before_now, tmp_msgs,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",

View File

@ -7,7 +7,7 @@ from typing import List, Dict, Any
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.person_info.relationship_manager import get_relationship_manager from src.person_info.relationship_manager import get_relationship_manager
from src.person_info.person_info import Person,get_person_id from src.person_info.person_info import Person, get_person_id
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat,
@ -129,7 +129,7 @@ class RelationshipBuilder:
# 获取该消息前5条消息的时间作为潜在的开始时间 # 获取该消息前5条消息的时间作为潜在的开始时间
before_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5) before_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5)
if before_messages: if before_messages:
potential_start_time = before_messages[0]["time"] potential_start_time = before_messages[0].time
else: else:
potential_start_time = message_time potential_start_time = message_time
@ -175,7 +175,7 @@ class RelationshipBuilder:
) )
if after_messages and len(after_messages) >= 5: if after_messages and len(after_messages) >= 5:
# 如果有足够的后续消息使用第5条消息的时间作为结束时间 # 如果有足够的后续消息使用第5条消息的时间作为结束时间
last_segment["end_time"] = after_messages[4]["time"] last_segment["end_time"] = after_messages[4].time
# 重新计算当前消息段的消息数量 # 重新计算当前消息段的消息数量
last_segment["message_count"] = self._count_messages_in_timerange( last_segment["message_count"] = self._count_messages_in_timerange(
@ -300,7 +300,6 @@ class RelationshipBuilder:
return cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0 return cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0
def get_cache_status(self) -> str: def get_cache_status(self) -> str:
# sourcery skip: merge-list-append, merge-list-appends-into-extend # sourcery skip: merge-list-append, merge-list-appends-into-extend
"""获取缓存状态信息,用于调试和监控""" """获取缓存状态信息,用于调试和监控"""
@ -342,13 +341,12 @@ class RelationshipBuilder:
# 统筹各模块协作、对外提供服务接口 # 统筹各模块协作、对外提供服务接口
# ================================ # ================================
async def build_relation(self,immediate_build: str = "",max_build_threshold: int = MAX_MESSAGE_COUNT): async def build_relation(self, immediate_build: str = "", max_build_threshold: int = MAX_MESSAGE_COUNT):
"""构建关系 """构建关系
immediate_build: 立即构建关系可选值为"all"或person_id immediate_build: 立即构建关系可选值为"all"或person_id
""" """
self._cleanup_old_segments() self._cleanup_old_segments()
current_time = time.time() current_time = time.time()
if latest_messages := get_raw_msg_by_timestamp_with_chat( if latest_messages := get_raw_msg_by_timestamp_with_chat(
self.chat_id, self.chat_id,
@ -358,9 +356,9 @@ class RelationshipBuilder:
): ):
# 处理所有新的非bot消息 # 处理所有新的非bot消息
for latest_msg in latest_messages: for latest_msg in latest_messages:
user_id = latest_msg.get("user_id") user_id = latest_msg.user_info.user_id
platform = latest_msg.get("user_platform") or latest_msg.get("chat_info_platform") platform = latest_msg.user_info.platform or latest_msg.chat_info.platform
msg_time = latest_msg.get("time", 0) msg_time = latest_msg.time
if ( if (
user_id user_id
@ -383,8 +381,10 @@ class RelationshipBuilder:
if not person.is_known: if not person.is_known:
continue continue
person_name = person.person_name or person_id person_name = person.person_name or person_id
if total_message_count >= max_build_threshold or (total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all")): if total_message_count >= max_build_threshold or (
total_message_count >= 5 and immediate_build in [person_id, "all"]
):
users_to_build_relationship.append(person_id) users_to_build_relationship.append(person_id)
logger.info( logger.info(
f"{self.log_prefix} 用户 {person_name} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}" f"{self.log_prefix} 用户 {person_name} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}"
@ -400,12 +400,11 @@ class RelationshipBuilder:
segments = self.person_engaged_cache[person_id] segments = self.person_engaged_cache[person_id]
# 异步执行关系构建 # 异步执行关系构建
person = Person(person_id=person_id) person = Person(person_id=person_id)
if person.is_known: if person.is_known:
asyncio.create_task(self.update_impression_on_segments(person_id, self.chat_id, segments)) asyncio.create_task(self.update_impression_on_segments(person_id, self.chat_id, segments))
# 移除已处理的用户缓存 # 移除已处理的用户缓存
del self.person_engaged_cache[person_id] del self.person_engaged_cache[person_id]
self._save_cache() self._save_cache()
# ================================ # ================================
# 关系构建模块 # 关系构建模块
@ -458,7 +457,7 @@ class RelationshipBuilder:
"user_cardname": "", "user_cardname": "",
"display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...", "display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...",
"is_action_record": True, "is_action_record": True,
"chat_info_platform": segment_messages[0].get("chat_info_platform", ""), "chat_info_platform": segment_messages[0].chat_info.platform or "",
"chat_id": chat_id, "chat_id": chat_id,
} }
processed_messages.append(gap_message) processed_messages.append(gap_message)

View File

@ -8,9 +8,10 @@
readable_text = message_api.build_readable_messages(messages) readable_text = message_api.build_readable_messages(messages)
""" """
from typing import List, Dict, Any, Tuple, Optional
from src.config.config import global_config
import time import time
from typing import List, Dict, Any, Tuple, Optional
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp, get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat,
@ -36,7 +37,7 @@ from src.chat.utils.chat_message_builder import (
def get_messages_by_time( def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
获取指定时间范围内的消息 获取指定时间范围内的消息
@ -70,7 +71,7 @@ def get_messages_by_time_in_chat(
limit_mode: str = "latest", limit_mode: str = "latest",
filter_mai: bool = False, filter_mai: bool = False,
filter_command: bool = False, filter_command: bool = False,
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
获取指定聊天中指定时间范围内的消息 获取指定聊天中指定时间范围内的消息
@ -97,7 +98,9 @@ def get_messages_by_time_in_chat(
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
if filter_mai: if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)) return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
)
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
@ -109,7 +112,7 @@ def get_messages_by_time_in_chat_inclusive(
limit_mode: str = "latest", limit_mode: str = "latest",
filter_mai: bool = False, filter_mai: bool = False,
filter_command: bool = False, filter_command: bool = False,
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
获取指定聊天中指定时间范围内的消息包含边界 获取指定聊天中指定时间范围内的消息包含边界
@ -137,9 +140,13 @@ def get_messages_by_time_in_chat_inclusive(
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
if filter_mai: if filter_mai:
return filter_mai_messages( return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command) get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id, start_time, end_time, limit, limit_mode, filter_command
)
) )
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command) return get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id, start_time, end_time, limit, limit_mode, filter_command
)
def get_messages_by_time_in_chat_for_users( def get_messages_by_time_in_chat_for_users(
@ -149,7 +156,7 @@ def get_messages_by_time_in_chat_for_users(
person_ids: List[str], person_ids: List[str],
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
获取指定聊天中指定用户在指定时间范围内的消息 获取指定聊天中指定用户在指定时间范围内的消息
@ -180,7 +187,7 @@ def get_messages_by_time_in_chat_for_users(
def get_random_chat_messages( def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
随机选择一个聊天返回该聊天在指定时间范围内的消息 随机选择一个聊天返回该聊天在指定时间范围内的消息
@ -208,7 +215,7 @@ def get_random_chat_messages(
def get_messages_by_time_for_users( def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
获取指定用户在所有聊天中指定时间范围内的消息 获取指定用户在所有聊天中指定时间范围内的消息
@ -232,7 +239,7 @@ def get_messages_by_time_for_users(
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]: def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
""" """
获取指定时间戳之前的消息 获取指定时间戳之前的消息
@ -258,7 +265,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
def get_messages_before_time_in_chat( def get_messages_before_time_in_chat(
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
获取指定聊天中指定时间戳之前的消息 获取指定聊天中指定时间戳之前的消息
@ -287,7 +294,7 @@ def get_messages_before_time_in_chat(
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]: def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[DatabaseMessages]:
""" """
获取指定用户在指定时间戳之前的消息 获取指定用户在指定时间戳之前的消息
@ -311,7 +318,7 @@ def get_messages_before_time_for_users(timestamp: float, person_ids: List[str],
def get_recent_messages( def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]: ) -> List[DatabaseMessages]:
""" """
获取指定聊天中最近一段时间的消息 获取指定聊天中最近一段时间的消息
@ -472,7 +479,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
# ============================================================================= # =============================================================================
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
""" """
从消息列表中移除麦麦的消息 从消息列表中移除麦麦的消息
Args: Args:
@ -480,4 +487,4 @@ def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
Returns: Returns:
过滤后的消息列表 过滤后的消息列表
""" """
return [msg for msg in messages if msg.get("user_id") != str(global_config.bot.qq_account)] return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]

View File

@ -85,8 +85,10 @@ class EmojiAction(BaseAction):
messages_text = "" messages_text = ""
if recent_messages: if recent_messages:
# 使用message_api构建可读的消息字符串 # 使用message_api构建可读的消息字符串
# TODO: 修复
tmp_msgs = [msg.__dict__ for msg in recent_messages]
messages_text = message_api.build_readable_messages( messages_text = message_api.build_readable_messages(
messages=recent_messages, messages=tmp_msgs,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
truncate=False, truncate=False,
show_actions=False, show_actions=False,