mirror of https://github.com/Mai-with-u/MaiBot.git
数据库的信息重构为dataclass
parent
d74beef4b4
commit
3481234d2b
|
|
@ -285,10 +285,11 @@ class HeartFChatting:
|
||||||
filter_mai=True,
|
filter_mai=True,
|
||||||
filter_command=True,
|
filter_command=True,
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
temp_recent_messages_dict = [msg.__dict__ for msg in recent_messages_dict]
|
||||||
# 统一的消息处理逻辑
|
# 统一的消息处理逻辑
|
||||||
should_process,interest_value = await self._should_process_messages(recent_messages_dict)
|
should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict)
|
||||||
|
|
||||||
if should_process:
|
if should_process:
|
||||||
self.last_read_time = time.time()
|
self.last_read_time = time.time()
|
||||||
await self._observe(interest_value = interest_value)
|
await self._observe(interest_value = interest_value)
|
||||||
|
|
|
||||||
|
|
@ -346,13 +346,15 @@ class ExpressionLearner:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# 获取上次学习时间
|
# 获取上次学习时间
|
||||||
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive(
|
random_msg_temp = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_learning_time,
|
timestamp_start=self.last_learning_time,
|
||||||
timestamp_end=current_time,
|
timestamp_end=current_time,
|
||||||
limit=num,
|
limit=num,
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
random_msg: Optional[List[Dict[str, Any]]] = [msg.__dict__ for msg in random_msg_temp] if random_msg_temp else None
|
||||||
|
|
||||||
# print(random_msg)
|
# print(random_msg)
|
||||||
if not random_msg or random_msg == []:
|
if not random_msg or random_msg == []:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from rich.traceback import install
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.database.database_model import GraphNodes, GraphEdges # Peewee Models导入
|
from src.common.database.database_model import GraphNodes, GraphEdges # Peewee Models导入
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
|
|
@ -1495,13 +1496,13 @@ class MemoryBuilder:
|
||||||
timestamp_end=current_time,
|
timestamp_end=current_time,
|
||||||
limit=threshold,
|
limit=threshold,
|
||||||
)
|
)
|
||||||
|
tmp_msg = [msg.__dict__ for msg in messages] if messages else []
|
||||||
if messages:
|
if messages:
|
||||||
# 更新最后处理时间
|
# 更新最后处理时间
|
||||||
self.last_processed_time = current_time
|
self.last_processed_time = current_time
|
||||||
self.last_update_time = current_time
|
self.last_update_time = current_time
|
||||||
|
|
||||||
return messages or []
|
return tmp_msg or []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,8 +70,10 @@ class ActionModifier:
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half]
|
||||||
chat_content = build_readable_messages(
|
chat_content = build_readable_messages(
|
||||||
message_list_before_now_half,
|
temp_msg_list_before_now_half,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,7 @@ class ActionPlanner:
|
||||||
self.max_plan_retries = 3
|
self.max_plan_retries = 3
|
||||||
|
|
||||||
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||||
|
# sourcery skip: use-next
|
||||||
"""
|
"""
|
||||||
根据message_id从message_id_list中查找对应的原始消息
|
根据message_id从message_id_list中查找对应的原始消息
|
||||||
|
|
||||||
|
|
@ -120,10 +121,7 @@ class ActionPlanner:
|
||||||
Returns:
|
Returns:
|
||||||
最新的消息字典,如果列表为空则返回None
|
最新的消息字典,如果列表为空则返回None
|
||||||
"""
|
"""
|
||||||
if not message_id_list:
|
return message_id_list[-1].get("message") if message_id_list else None
|
||||||
return None
|
|
||||||
# 假设消息列表是按时间顺序排列的,最后一个是最新的
|
|
||||||
return message_id_list[-1].get("message")
|
|
||||||
|
|
||||||
async def plan(
|
async def plan(
|
||||||
self,
|
self,
|
||||||
|
|
@ -208,22 +206,17 @@ class ActionPlanner:
|
||||||
if target_message is None:
|
if target_message is None:
|
||||||
self.plan_retry_count += 1
|
self.plan_retry_count += 1
|
||||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
|
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
|
||||||
|
# 仍有重试次数
|
||||||
# 如果连续三次plan均为None,输出error并选取最新消息
|
if self.plan_retry_count < self.max_plan_retries:
|
||||||
if self.plan_retry_count >= self.max_plan_retries:
|
|
||||||
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message")
|
|
||||||
target_message = self.get_latest_message(message_id_list)
|
|
||||||
self.plan_retry_count = 0 # 重置计数器
|
|
||||||
else:
|
|
||||||
# 递归重新plan
|
# 递归重新plan
|
||||||
return await self.plan(mode, loop_start_time, available_actions)
|
return await self.plan(mode, loop_start_time, available_actions)
|
||||||
else:
|
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message")
|
||||||
# 成功获取到target_message,重置计数器
|
target_message = self.get_latest_message(message_id_list)
|
||||||
self.plan_retry_count = 0
|
self.plan_retry_count = 0 # 重置计数器
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if action != "no_reply" and action != "reply" and action not in current_available_actions:
|
if action != "no_reply" and action != "reply" and action not in current_available_actions:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -247,28 +240,27 @@ class ActionPlanner:
|
||||||
is_parallel = False
|
is_parallel = False
|
||||||
if mode == ChatMode.NORMAL and action in current_available_actions:
|
if mode == ChatMode.NORMAL and action in current_available_actions:
|
||||||
is_parallel = current_available_actions[action].parallel_action
|
is_parallel = current_available_actions[action].parallel_action
|
||||||
|
|
||||||
|
|
||||||
action_data["loop_start_time"] = loop_start_time
|
action_data["loop_start_time"] = loop_start_time
|
||||||
|
|
||||||
actions = []
|
actions = [
|
||||||
|
{
|
||||||
# 1. 添加Planner取得的动作
|
"action_type": action,
|
||||||
actions.append({
|
"reasoning": reasoning,
|
||||||
"action_type": action,
|
"action_data": action_data,
|
||||||
"reasoning": reasoning,
|
"action_message": target_message,
|
||||||
"action_data": action_data,
|
"available_actions": available_actions,
|
||||||
"action_message": target_message,
|
}
|
||||||
"available_actions": available_actions # 添加这个字段
|
]
|
||||||
})
|
|
||||||
|
|
||||||
if action != "reply" and is_parallel:
|
if action != "reply" and is_parallel:
|
||||||
actions.append({
|
actions.append({
|
||||||
"action_type": "reply",
|
"action_type": "reply",
|
||||||
"action_message": target_message,
|
"action_message": target_message,
|
||||||
"available_actions": available_actions
|
"available_actions": available_actions
|
||||||
})
|
})
|
||||||
|
|
||||||
return actions,target_message
|
return actions,target_message
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -288,9 +280,10 @@ class ActionPlanner:
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.6),
|
limit=int(global_config.chat.max_context_size * 0.6),
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
temp_msg_list_before_now = [msg.__dict__ for msg in message_list_before_now]
|
||||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||||
messages=message_list_before_now,
|
messages=temp_msg_list_before_now,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=self.last_obs_time_mark,
|
read_mark=self.last_obs_time_mark,
|
||||||
truncate=True,
|
truncate=True,
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ def init_prompt():
|
||||||
""",
|
""",
|
||||||
"replyer_prompt",
|
"replyer_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
{expression_habits_block}{tool_info_block}
|
{expression_habits_block}{tool_info_block}
|
||||||
|
|
@ -116,7 +116,6 @@ def init_prompt():
|
||||||
""",
|
""",
|
||||||
"replyer_self_prompt",
|
"replyer_self_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
|
|
@ -179,7 +178,7 @@ class DefaultReplyer:
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt)
|
Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prompt = None
|
prompt = None
|
||||||
selected_expressions = None
|
selected_expressions = None
|
||||||
if available_actions is None:
|
if available_actions is None:
|
||||||
|
|
@ -187,7 +186,7 @@ class DefaultReplyer:
|
||||||
try:
|
try:
|
||||||
# 3. 构建 Prompt
|
# 3. 构建 Prompt
|
||||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||||
prompt,selected_expressions = await self.build_prompt_reply_context(
|
prompt, selected_expressions = await self.build_prompt_reply_context(
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
choosen_actions=choosen_actions,
|
choosen_actions=choosen_actions,
|
||||||
|
|
@ -294,12 +293,12 @@ class DefaultReplyer:
|
||||||
async def build_relation_info(self, sender: str, target: str):
|
async def build_relation_info(self, sender: str, target: str):
|
||||||
if not global_config.relationship.enable_relationship:
|
if not global_config.relationship.enable_relationship:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
if sender == global_config.bot.nickname:
|
if sender == global_config.bot.nickname:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 获取用户ID
|
# 获取用户ID
|
||||||
person = Person(person_name = sender)
|
person = Person(person_name=sender)
|
||||||
if not is_person_known(person_name=sender):
|
if not is_person_known(person_name=sender):
|
||||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||||
|
|
@ -307,6 +306,7 @@ class DefaultReplyer:
|
||||||
return person.build_relationship(points_num=5)
|
return person.build_relationship(points_num=5)
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
|
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
|
||||||
|
# sourcery skip: for-append-to-extend
|
||||||
"""构建表达习惯块
|
"""构建表达习惯块
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -359,7 +359,7 @@ class DefaultReplyer:
|
||||||
Returns:
|
Returns:
|
||||||
str: 记忆信息字符串
|
str: 记忆信息字符串
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not global_config.memory.enable_memory:
|
if not global_config.memory.enable_memory:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
@ -368,7 +368,6 @@ class DefaultReplyer:
|
||||||
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||||
target_message=target, chat_history_prompt=chat_history
|
target_message=target, chat_history_prompt=chat_history
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if global_config.memory.enable_instant_memory:
|
if global_config.memory.enable_instant_memory:
|
||||||
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
|
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
|
||||||
|
|
@ -379,10 +378,9 @@ class DefaultReplyer:
|
||||||
if not running_memories:
|
if not running_memories:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||||
for running_memory in running_memories:
|
for running_memory in running_memories:
|
||||||
keywords,content = running_memory
|
keywords, content = running_memory
|
||||||
memory_str += f"- {keywords}:{content}\n"
|
memory_str += f"- {keywords}:{content}\n"
|
||||||
|
|
||||||
if instant_memory:
|
if instant_memory:
|
||||||
|
|
@ -405,7 +403,6 @@ class DefaultReplyer:
|
||||||
if not enable_tool:
|
if not enable_tool:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用工具执行器获取信息
|
# 使用工具执行器获取信息
|
||||||
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
||||||
|
|
@ -559,16 +556,18 @@ class DefaultReplyer:
|
||||||
# 检查最新五条消息中是否包含bot自己说的消息
|
# 检查最新五条消息中是否包含bot自己说的消息
|
||||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||||
|
|
||||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
# logger.info(f"最新五条消息:{latest_5_messages}")
|
||||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
||||||
|
|
||||||
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
||||||
if not has_bot_message:
|
if not has_bot_message:
|
||||||
core_dialogue_prompt = ""
|
core_dialogue_prompt = ""
|
||||||
else:
|
else:
|
||||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :] # 限制消息数量
|
core_dialogue_list = core_dialogue_list[
|
||||||
|
-int(global_config.chat.max_context_size * 0.6) :
|
||||||
|
] # 限制消息数量
|
||||||
|
|
||||||
core_dialogue_prompt_str = build_readable_messages(
|
core_dialogue_prompt_str = build_readable_messages(
|
||||||
core_dialogue_list,
|
core_dialogue_list,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
|
|
@ -630,12 +629,12 @@ class DefaultReplyer:
|
||||||
mai_think.sender = sender
|
mai_think.sender = sender
|
||||||
mai_think.target = target
|
mai_think.target = target
|
||||||
return mai_think
|
return mai_think
|
||||||
|
|
||||||
|
async def build_actions_prompt(
|
||||||
async def build_actions_prompt(self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None) -> str:
|
self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None
|
||||||
"""构建动作提示
|
) -> str:
|
||||||
"""
|
"""构建动作提示"""
|
||||||
|
|
||||||
action_descriptions = ""
|
action_descriptions = ""
|
||||||
if available_actions:
|
if available_actions:
|
||||||
action_descriptions = "你可以做以下这些动作:\n"
|
action_descriptions = "你可以做以下这些动作:\n"
|
||||||
|
|
@ -643,25 +642,24 @@ class DefaultReplyer:
|
||||||
action_description = action_info.description
|
action_description = action_info.description
|
||||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||||
action_descriptions += "\n"
|
action_descriptions += "\n"
|
||||||
|
|
||||||
choosen_action_descriptions = ""
|
choosen_action_descriptions = ""
|
||||||
if choosen_actions:
|
if choosen_actions:
|
||||||
for action in choosen_actions:
|
for action in choosen_actions:
|
||||||
action_name = action.get('action_type', 'unknown_action')
|
action_name = action.get("action_type", "unknown_action")
|
||||||
if action_name =="reply":
|
if action_name == "reply":
|
||||||
continue
|
continue
|
||||||
action_description = action.get('reason', '无描述')
|
action_description = action.get("reason", "无描述")
|
||||||
reasoning = action.get('reasoning', '无原因')
|
reasoning = action.get("reasoning", "无原因")
|
||||||
|
|
||||||
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
|
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
|
||||||
|
|
||||||
if choosen_action_descriptions:
|
if choosen_action_descriptions:
|
||||||
action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n"
|
action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n"
|
||||||
action_descriptions += choosen_action_descriptions
|
action_descriptions += choosen_action_descriptions
|
||||||
|
|
||||||
return action_descriptions
|
return action_descriptions
|
||||||
|
|
||||||
|
|
||||||
async def build_prompt_reply_context(
|
async def build_prompt_reply_context(
|
||||||
self,
|
self,
|
||||||
extra_info: str = "",
|
extra_info: str = "",
|
||||||
|
|
@ -691,41 +689,44 @@ class DefaultReplyer:
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
is_group_chat = bool(chat_stream.group_info)
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
platform = chat_stream.platform
|
platform = chat_stream.platform
|
||||||
|
|
||||||
if reply_message:
|
if reply_message:
|
||||||
user_id = reply_message.get("user_id","")
|
user_id = reply_message.get("user_id", "")
|
||||||
person = Person(platform=platform, user_id=user_id)
|
person = Person(platform=platform, user_id=user_id)
|
||||||
person_name = person.person_name or user_id
|
person_name = person.person_name or user_id
|
||||||
sender = person_name
|
sender = person_name
|
||||||
target = reply_message.get('processed_plain_text')
|
target = reply_message.get("processed_plain_text")
|
||||||
else:
|
else:
|
||||||
person_name = "用户"
|
person_name = "用户"
|
||||||
sender = "用户"
|
sender = "用户"
|
||||||
target = "消息"
|
target = "消息"
|
||||||
|
|
||||||
|
|
||||||
if global_config.mood.enable_mood:
|
if global_config.mood.enable_mood:
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||||
mood_prompt = chat_mood.mood_state
|
mood_prompt = chat_mood.mood_state
|
||||||
else:
|
else:
|
||||||
mood_prompt = ""
|
mood_prompt = ""
|
||||||
|
|
||||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||||
|
|
||||||
|
# TODO: 修复!
|
||||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=global_config.chat.max_context_size * 1,
|
limit=global_config.chat.max_context_size * 1,
|
||||||
)
|
)
|
||||||
|
temp_msg_list_before_long = [msg.__dict__ for msg in message_list_before_now_long]
|
||||||
|
|
||||||
|
# TODO: 修复!
|
||||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.33),
|
limit=int(global_config.chat.max_context_size * 0.33),
|
||||||
)
|
)
|
||||||
|
temp_msg_list_before_short = [msg.__dict__ for msg in message_list_before_short]
|
||||||
|
|
||||||
chat_talking_prompt_short = build_readable_messages(
|
chat_talking_prompt_short = build_readable_messages(
|
||||||
message_list_before_short,
|
temp_msg_list_before_short,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
|
|
@ -739,12 +740,12 @@ class DefaultReplyer:
|
||||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||||
),
|
),
|
||||||
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
|
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
|
||||||
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
|
self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||||
),
|
),
|
||||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||||
self._time_and_run_task(self.build_actions_prompt(available_actions,choosen_actions), "actions_info"),
|
self._time_and_run_task(self.build_actions_prompt(available_actions, choosen_actions), "actions_info"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 任务名称中英文映射
|
# 任务名称中英文映射
|
||||||
|
|
@ -760,7 +761,7 @@ class DefaultReplyer:
|
||||||
# 处理结果
|
# 处理结果
|
||||||
timing_logs = []
|
timing_logs = []
|
||||||
results_dict = {}
|
results_dict = {}
|
||||||
|
|
||||||
almost_zero_str = ""
|
almost_zero_str = ""
|
||||||
for name, result, duration in task_results:
|
for name, result, duration in task_results:
|
||||||
results_dict[name] = result
|
results_dict[name] = result
|
||||||
|
|
@ -768,7 +769,7 @@ class DefaultReplyer:
|
||||||
if duration < 0.01:
|
if duration < 0.01:
|
||||||
almost_zero_str += f"{chinese_name},"
|
almost_zero_str += f"{chinese_name},"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||||
if duration > 8:
|
if duration > 8:
|
||||||
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||||
|
|
@ -791,9 +792,7 @@ class DefaultReplyer:
|
||||||
|
|
||||||
identity_block = await get_individuality().get_personality_block()
|
identity_block = await get_individuality().get_personality_block()
|
||||||
|
|
||||||
moderation_prompt_block = (
|
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||||
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
|
||||||
)
|
|
||||||
|
|
||||||
if sender:
|
if sender:
|
||||||
if is_group_chat:
|
if is_group_chat:
|
||||||
|
|
@ -801,7 +800,9 @@ class DefaultReplyer:
|
||||||
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。原因是{reply_reason}"
|
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。原因是{reply_reason}"
|
||||||
)
|
)
|
||||||
else: # private chat
|
else: # private chat
|
||||||
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。原因是{reply_reason}"
|
reply_target_block = (
|
||||||
|
f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。原因是{reply_reason}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
reply_target_block = ""
|
reply_target_block = ""
|
||||||
|
|
||||||
|
|
@ -821,10 +822,9 @@ class DefaultReplyer:
|
||||||
# "chat_target_private2", sender_name=chat_target_name
|
# "chat_target_private2", sender_name=chat_target_name
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
# 构建分离的对话 prompt
|
# 构建分离的对话 prompt
|
||||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||||
message_list_before_now_long, user_id, sender
|
temp_msg_list_before_long, user_id, sender
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
||||||
|
|
@ -846,7 +846,7 @@ class DefaultReplyer:
|
||||||
reply_style=global_config.personality.reply_style,
|
reply_style=global_config.personality.reply_style,
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
),selected_expressions
|
), selected_expressions
|
||||||
else:
|
else:
|
||||||
return await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
"replyer_prompt",
|
"replyer_prompt",
|
||||||
|
|
@ -867,7 +867,7 @@ class DefaultReplyer:
|
||||||
reply_style=global_config.personality.reply_style,
|
reply_style=global_config.personality.reply_style,
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
),selected_expressions
|
), selected_expressions
|
||||||
|
|
||||||
async def build_prompt_rewrite_context(
|
async def build_prompt_rewrite_context(
|
||||||
self,
|
self,
|
||||||
|
|
@ -898,8 +898,10 @@ class DefaultReplyer:
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half]
|
||||||
chat_talking_prompt_half = build_readable_messages(
|
chat_talking_prompt_half = build_readable_messages(
|
||||||
message_list_before_now_half,
|
temp_msg_list_before_now_half,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
|
|
@ -912,7 +914,6 @@ class DefaultReplyer:
|
||||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||||
self.build_relation_info(sender, target),
|
self.build_relation_info(sender, target),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||||
|
|
||||||
|
|
@ -1024,7 +1025,9 @@ class DefaultReplyer:
|
||||||
else:
|
else:
|
||||||
logger.debug(f"\n{prompt}\n")
|
logger.debug(f"\n{prompt}\n")
|
||||||
|
|
||||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(prompt)
|
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||||
|
prompt
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"replyer生成内容: {content}")
|
logger.debug(f"replyer生成内容: {content}")
|
||||||
return content, reasoning_content, model_name, tool_calls
|
return content, reasoning_content, model_name, tool_calls
|
||||||
|
|
@ -1034,7 +1037,6 @@ class DefaultReplyer:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
|
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
|
||||||
|
|
||||||
|
|
||||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
# 从LPMM知识库获取知识
|
# 从LPMM知识库获取知识
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,10 @@ from rich.traceback import install
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.database.database_model import ActionRecords
|
from src.common.database.database_model import ActionRecords
|
||||||
from src.common.database.database_model import Images
|
from src.common.database.database_model import Images
|
||||||
from src.person_info.person_info import Person,get_person_id
|
from src.person_info.person_info import Person, get_person_id
|
||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
@ -35,6 +36,7 @@ def replace_user_references_sync(
|
||||||
str: 处理后的内容字符串
|
str: 处理后的内容字符串
|
||||||
"""
|
"""
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
|
|
||||||
def default_resolver(platform: str, user_id: str) -> str:
|
def default_resolver(platform: str, user_id: str) -> str:
|
||||||
# 检查是否是机器人自己
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
|
|
@ -108,6 +110,7 @@ async def replace_user_references_async(
|
||||||
str: 处理后的内容字符串
|
str: 处理后的内容字符串
|
||||||
"""
|
"""
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
|
|
||||||
async def default_resolver(platform: str, user_id: str) -> str:
|
async def default_resolver(platform: str, user_id: str) -> str:
|
||||||
# 检查是否是机器人自己
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
|
|
@ -161,9 +164,7 @@ async def replace_user_references_async(
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp(
|
def get_raw_msg_by_timestamp(timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"):
|
||||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
"""
|
||||||
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
|
|
@ -183,7 +184,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
filter_bot=False,
|
filter_bot=False,
|
||||||
filter_command=False,
|
filter_command=False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||||
|
|
@ -209,7 +210,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
filter_bot=False,
|
filter_bot=False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||||
|
|
@ -218,7 +219,6 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
# 只有当 limit 为 0 时才应用外部 sort
|
# 只有当 limit 为 0 时才应用外部 sort
|
||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
# 直接将 limit_mode 传递给 find_messages
|
# 直接将 limit_mode 传递给 find_messages
|
||||||
|
|
||||||
return find_messages(
|
return find_messages(
|
||||||
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
||||||
)
|
)
|
||||||
|
|
@ -231,7 +231,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
|
||||||
person_ids: List[str],
|
person_ids: List[str],
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||||
|
|
@ -302,7 +302,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_random(
|
def get_raw_msg_by_timestamp_random(
|
||||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||||
"""
|
"""
|
||||||
|
|
@ -312,15 +312,15 @@ def get_raw_msg_by_timestamp_random(
|
||||||
return []
|
return []
|
||||||
# 随机选一条
|
# 随机选一条
|
||||||
msg = random.choice(all_msgs)
|
msg = random.choice(all_msgs)
|
||||||
chat_id = msg["chat_id"]
|
chat_id = msg.chat_id
|
||||||
timestamp_start = msg["time"]
|
timestamp_start = msg.time
|
||||||
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
|
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
|
||||||
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_users(
|
def get_raw_msg_by_timestamp_with_users(
|
||||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||||
|
|
@ -331,7 +331,7 @@ def get_raw_msg_by_timestamp_with_users(
|
||||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[DatabaseMessages]:
|
||||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
"""
|
"""
|
||||||
|
|
@ -340,7 +340,7 @@ def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[
|
||||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[DatabaseMessages]:
|
||||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
"""
|
"""
|
||||||
|
|
@ -349,7 +349,7 @@ def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit
|
||||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
|
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[DatabaseMessages]:
|
||||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,15 @@ import re
|
||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
import jieba
|
import jieba
|
||||||
|
import json
|
||||||
|
import ast
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from maim_message import UserInfo
|
|
||||||
from typing import Optional, Tuple, Dict, List, Any
|
from typing import Optional, Tuple, Dict, List, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
|
|
@ -130,22 +132,29 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
|
||||||
return []
|
return []
|
||||||
|
|
||||||
who_chat_in_group = []
|
who_chat_in_group = []
|
||||||
for msg_db_data in recent_messages:
|
for db_msg in recent_messages:
|
||||||
user_info = UserInfo.from_dict(
|
# user_info = UserInfo.from_dict(
|
||||||
{
|
# {
|
||||||
"platform": msg_db_data["user_platform"],
|
# "platform": msg_db_data["user_platform"],
|
||||||
"user_id": msg_db_data["user_id"],
|
# "user_id": msg_db_data["user_id"],
|
||||||
"user_nickname": msg_db_data["user_nickname"],
|
# "user_nickname": msg_db_data["user_nickname"],
|
||||||
"user_cardname": msg_db_data.get("user_cardname", ""),
|
# "user_cardname": msg_db_data.get("user_cardname", ""),
|
||||||
}
|
# }
|
||||||
)
|
# )
|
||||||
|
# if (
|
||||||
|
# (user_info.platform, user_info.user_id) != sender
|
||||||
|
# and user_info.user_id != global_config.bot.qq_account
|
||||||
|
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||||
|
# and len(who_chat_in_group) < 5
|
||||||
|
# ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||||
|
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
|
||||||
if (
|
if (
|
||||||
(user_info.platform, user_info.user_id) != sender
|
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
|
||||||
and user_info.user_id != global_config.bot.qq_account
|
and db_msg.user_info.user_id != global_config.bot.qq_account
|
||||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group
|
||||||
and len(who_chat_in_group) < 5
|
and len(who_chat_in_group) < 5
|
||||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||||
who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
|
who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname))
|
||||||
|
|
||||||
return who_chat_in_group
|
return who_chat_in_group
|
||||||
|
|
||||||
|
|
@ -555,7 +564,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
||||||
|
|
||||||
# 获取消息内容计算总长度
|
# 获取消息内容计算总长度
|
||||||
messages = find_messages(message_filter=filter_query)
|
messages = find_messages(message_filter=filter_query)
|
||||||
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
|
total_length = sum(len(msg.processed_plain_text or "") for msg in messages)
|
||||||
|
|
||||||
return count, total_length
|
return count, total_length
|
||||||
|
|
||||||
|
|
@ -628,41 +637,34 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||||
user_id: str = user_info.user_id # type: ignore
|
user_id: str = user_info.user_id # type: ignore
|
||||||
|
|
||||||
# Initialize target_info with basic info
|
# Initialize target_info with basic info
|
||||||
target_info = {
|
target_info = TargetPersonInfo(
|
||||||
"platform": platform,
|
platform=platform,
|
||||||
"user_id": user_id,
|
user_id=user_id,
|
||||||
"user_nickname": user_info.user_nickname,
|
user_nickname=user_info.user_nickname, # type: ignore
|
||||||
"person_id": None,
|
person_id=None,
|
||||||
"person_name": None,
|
person_name=None
|
||||||
}
|
)
|
||||||
|
|
||||||
# Try to fetch person info
|
# Try to fetch person info
|
||||||
try:
|
try:
|
||||||
# Assume get_person_id is sync (as per original code), keep using to_thread
|
|
||||||
person = Person(platform=platform, user_id=user_id)
|
person = Person(platform=platform, user_id=user_id)
|
||||||
if not person.is_known:
|
if not person.is_known:
|
||||||
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
|
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
|
||||||
# 如果用户尚未认识,则返回False和None
|
# 如果用户尚未认识,则返回False和None
|
||||||
return False, None
|
return False, None
|
||||||
person_id = person.person_id
|
if person.person_id:
|
||||||
person_name = None
|
target_info.person_id = person.person_id
|
||||||
if person_id:
|
target_info.person_name = person.person_name
|
||||||
# get_value is async, so await it directly
|
|
||||||
person_name = person.person_name
|
|
||||||
|
|
||||||
target_info["person_id"] = person_id
|
|
||||||
target_info["person_name"] = person_name
|
|
||||||
except Exception as person_e:
|
except Exception as person_e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
|
f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_target_info = target_info
|
chat_target_info = target_info.__dict__
|
||||||
else:
|
else:
|
||||||
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
||||||
# Keep defaults on error
|
|
||||||
|
|
||||||
return is_group_chat, chat_target_info
|
return is_group_chat, chat_target_info
|
||||||
|
|
||||||
|
|
@ -771,6 +773,7 @@ def assign_message_ids_flexible(
|
||||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||||
|
|
||||||
def parse_keywords_string(keywords_input) -> list[str]:
|
def parse_keywords_string(keywords_input) -> list[str]:
|
||||||
|
# sourcery skip: use-contextlib-suppress
|
||||||
"""
|
"""
|
||||||
统一的关键词解析函数,支持多种格式的关键词字符串解析
|
统一的关键词解析函数,支持多种格式的关键词字符串解析
|
||||||
|
|
||||||
|
|
@ -802,7 +805,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试作为JSON对象解析(支持 {"keywords": [...]} 格式)
|
# 尝试作为JSON对象解析(支持 {"keywords": [...]} 格式)
|
||||||
import json
|
|
||||||
json_data = json.loads(keywords_str)
|
json_data = json.loads(keywords_str)
|
||||||
if isinstance(json_data, dict) and "keywords" in json_data:
|
if isinstance(json_data, dict) and "keywords" in json_data:
|
||||||
keywords_list = json_data["keywords"]
|
keywords_list = json_data["keywords"]
|
||||||
|
|
@ -816,7 +818,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试使用 ast.literal_eval 解析(支持Python字面量格式)
|
# 尝试使用 ast.literal_eval 解析(支持Python字面量格式)
|
||||||
import ast
|
|
||||||
parsed = ast.literal_eval(keywords_str)
|
parsed = ast.literal_eval(keywords_str)
|
||||||
if isinstance(parsed, list):
|
if isinstance(parsed, list):
|
||||||
return [str(k).strip() for k in parsed if str(k).strip()]
|
return [str(k).strip() for k in parsed if str(k).strip()]
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,10 @@
|
||||||
from enum import Enum
|
from typing import Optional
|
||||||
|
|
||||||
from typing import Optional, Union, Dict, Any, Tuple, List
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatabaseUserInfo:
|
class DatabaseUserInfo:
|
||||||
user_platform: str = field(default_factory=str)
|
platform: str = field(default_factory=str)
|
||||||
user_id: str = field(default_factory=str)
|
user_id: str = field(default_factory=str)
|
||||||
user_nickname: str = field(default_factory=str)
|
user_nickname: str = field(default_factory=str)
|
||||||
user_cardname: Optional[str] = None
|
user_cardname: Optional[str] = None
|
||||||
|
|
@ -84,17 +81,21 @@ class DatabaseMessages:
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
user_nickname=self.user_nickname,
|
user_nickname=self.user_nickname,
|
||||||
user_cardname=self.user_cardname,
|
user_cardname=self.user_cardname,
|
||||||
user_platform=self.user_platform,
|
platform=self.user_platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (self.chat_info_group_id and self.chat_info_group_name):
|
if self.chat_info_group_id and self.chat_info_group_name:
|
||||||
self.group_info = None
|
self.group_info = DatabaseGroupInfo(
|
||||||
|
group_id=self.chat_info_group_id,
|
||||||
|
group_name=self.chat_info_group_name,
|
||||||
|
group_platform=self.chat_info_group_platform,
|
||||||
|
)
|
||||||
|
|
||||||
chat_user_info = DatabaseUserInfo(
|
chat_user_info = DatabaseUserInfo(
|
||||||
user_id=self.chat_info_user_id,
|
user_id=self.chat_info_user_id,
|
||||||
user_nickname=self.chat_info_user_nickname,
|
user_nickname=self.chat_info_user_nickname,
|
||||||
user_cardname=self.chat_info_user_cardname,
|
user_cardname=self.chat_info_user_cardname,
|
||||||
user_platform=self.chat_info_user_platform,
|
platform=self.chat_info_user_platform,
|
||||||
)
|
)
|
||||||
self.chat_info = DatabaseChatInfo(
|
self.chat_info = DatabaseChatInfo(
|
||||||
stream_id=self.chat_info_stream_id,
|
stream_id=self.chat_info_stream_id,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TargetPersonInfo:
|
||||||
|
platform: str = field(default_factory=str)
|
||||||
|
user_id: str = field(default_factory=str)
|
||||||
|
user_nickname: str = field(default_factory=str)
|
||||||
|
person_id: Optional[str] = None
|
||||||
|
person_name: Optional[str] = None
|
||||||
|
|
@ -2,19 +2,20 @@ import traceback
|
||||||
|
|
||||||
from typing import List, Any, Optional
|
from typing import List, Any, Optional
|
||||||
from peewee import Model # 添加 Peewee Model 导入
|
from peewee import Model # 添加 Peewee Model 导入
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.database.database_model import Messages
|
from src.common.database.database_model import Messages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
|
def _model_to_instance(model_instance: Model) -> DatabaseMessages:
|
||||||
"""
|
"""
|
||||||
将 Peewee 模型实例转换为字典。
|
将 Peewee 模型实例转换为字典。
|
||||||
"""
|
"""
|
||||||
return model_instance.__data__
|
return DatabaseMessages(**model_instance.__data__)
|
||||||
|
|
||||||
|
|
||||||
def find_messages(
|
def find_messages(
|
||||||
|
|
@ -24,7 +25,7 @@ def find_messages(
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
filter_bot=False,
|
filter_bot=False,
|
||||||
filter_command=False,
|
filter_command=False,
|
||||||
) -> List[dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
根据提供的过滤器、排序和限制条件查找消息。
|
根据提供的过滤器、排序和限制条件查找消息。
|
||||||
|
|
||||||
|
|
@ -112,7 +113,7 @@ def find_messages(
|
||||||
query = query.order_by(*peewee_sort_terms)
|
query = query.order_by(*peewee_sort_terms)
|
||||||
peewee_results = list(query)
|
peewee_results = list(query)
|
||||||
|
|
||||||
return [_model_to_dict(msg) for msg in peewee_results]
|
return [_model_to_instance(msg) for msg in peewee_results]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = (
|
log_message = (
|
||||||
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||||
|
|
|
||||||
|
|
@ -163,8 +163,10 @@ class ChatAction:
|
||||||
limit=15,
|
limit=15,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
tmp_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
|
|
@ -227,8 +229,10 @@ class ChatAction:
|
||||||
limit=10,
|
limit=10,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
tmp_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
|
|
|
||||||
|
|
@ -166,8 +166,10 @@ class ChatMood:
|
||||||
limit=10,
|
limit=10,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
tmp_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
|
|
@ -245,8 +247,10 @@ class ChatMood:
|
||||||
limit=5,
|
limit=5,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
tmp_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
|
|
|
||||||
|
|
@ -187,22 +187,23 @@ class PromptBuilder:
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||||
|
|
||||||
for msg_dict in message_list_before_now:
|
# TODO: 修复之!
|
||||||
|
for msg in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
msg_user_id = str(msg.user_info.user_id)
|
||||||
if msg_user_id == bot_id:
|
if msg_user_id == bot_id:
|
||||||
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
if msg.reply_to and talk_type == msg.reply_to:
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg.__dict__)
|
||||||
elif msg_dict.get("reply_to") and talk_type != msg_dict.get("reply_to"):
|
elif msg.reply_to and talk_type != msg.reply_to:
|
||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg.__dict__)
|
||||||
# else:
|
# else:
|
||||||
# background_dialogue_list.append(msg_dict)
|
# background_dialogue_list.append(msg_dict)
|
||||||
elif msg_user_id == target_user_id:
|
elif msg_user_id == target_user_id:
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg.__dict__)
|
||||||
else:
|
else:
|
||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg.__dict__)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
|
||||||
|
|
||||||
background_dialogue_prompt = ""
|
background_dialogue_prompt = ""
|
||||||
if background_dialogue_list:
|
if background_dialogue_list:
|
||||||
|
|
@ -257,8 +258,10 @@ class PromptBuilder:
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=20,
|
limit=20,
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in all_dialogue_prompt]
|
||||||
all_dialogue_prompt_str = build_readable_messages(
|
all_dialogue_prompt_str = build_readable_messages(
|
||||||
all_dialogue_prompt,
|
tmp_msgs,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
show_pic=False,
|
show_pic=False,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -99,8 +99,10 @@ class ChatMood:
|
||||||
limit=int(global_config.chat.max_context_size / 3),
|
limit=int(global_config.chat.max_context_size / 3),
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
|
# TODO: 修复!
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
tmp_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
|
|
@ -148,8 +150,10 @@ class ChatMood:
|
||||||
limit=15,
|
limit=15,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
|
# TODO: 修复
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
tmp_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from typing import List, Dict, Any
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
from src.person_info.person_info import Person,get_person_id
|
from src.person_info.person_info import Person, get_person_id
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
get_raw_msg_by_timestamp_with_chat,
|
||||||
|
|
@ -129,7 +129,7 @@ class RelationshipBuilder:
|
||||||
# 获取该消息前5条消息的时间作为潜在的开始时间
|
# 获取该消息前5条消息的时间作为潜在的开始时间
|
||||||
before_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5)
|
before_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5)
|
||||||
if before_messages:
|
if before_messages:
|
||||||
potential_start_time = before_messages[0]["time"]
|
potential_start_time = before_messages[0].time
|
||||||
else:
|
else:
|
||||||
potential_start_time = message_time
|
potential_start_time = message_time
|
||||||
|
|
||||||
|
|
@ -175,7 +175,7 @@ class RelationshipBuilder:
|
||||||
)
|
)
|
||||||
if after_messages and len(after_messages) >= 5:
|
if after_messages and len(after_messages) >= 5:
|
||||||
# 如果有足够的后续消息,使用第5条消息的时间作为结束时间
|
# 如果有足够的后续消息,使用第5条消息的时间作为结束时间
|
||||||
last_segment["end_time"] = after_messages[4]["time"]
|
last_segment["end_time"] = after_messages[4].time
|
||||||
|
|
||||||
# 重新计算当前消息段的消息数量
|
# 重新计算当前消息段的消息数量
|
||||||
last_segment["message_count"] = self._count_messages_in_timerange(
|
last_segment["message_count"] = self._count_messages_in_timerange(
|
||||||
|
|
@ -300,7 +300,6 @@ class RelationshipBuilder:
|
||||||
|
|
||||||
return cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0
|
return cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0
|
||||||
|
|
||||||
|
|
||||||
def get_cache_status(self) -> str:
|
def get_cache_status(self) -> str:
|
||||||
# sourcery skip: merge-list-append, merge-list-appends-into-extend
|
# sourcery skip: merge-list-append, merge-list-appends-into-extend
|
||||||
"""获取缓存状态信息,用于调试和监控"""
|
"""获取缓存状态信息,用于调试和监控"""
|
||||||
|
|
@ -342,13 +341,12 @@ class RelationshipBuilder:
|
||||||
# 统筹各模块协作、对外提供服务接口
|
# 统筹各模块协作、对外提供服务接口
|
||||||
# ================================
|
# ================================
|
||||||
|
|
||||||
async def build_relation(self,immediate_build: str = "",max_build_threshold: int = MAX_MESSAGE_COUNT):
|
async def build_relation(self, immediate_build: str = "", max_build_threshold: int = MAX_MESSAGE_COUNT):
|
||||||
"""构建关系
|
"""构建关系
|
||||||
immediate_build: 立即构建关系,可选值为"all"或person_id
|
immediate_build: 立即构建关系,可选值为"all"或person_id
|
||||||
"""
|
"""
|
||||||
self._cleanup_old_segments()
|
self._cleanup_old_segments()
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
if latest_messages := get_raw_msg_by_timestamp_with_chat(
|
if latest_messages := get_raw_msg_by_timestamp_with_chat(
|
||||||
self.chat_id,
|
self.chat_id,
|
||||||
|
|
@ -358,9 +356,9 @@ class RelationshipBuilder:
|
||||||
):
|
):
|
||||||
# 处理所有新的非bot消息
|
# 处理所有新的非bot消息
|
||||||
for latest_msg in latest_messages:
|
for latest_msg in latest_messages:
|
||||||
user_id = latest_msg.get("user_id")
|
user_id = latest_msg.user_info.user_id
|
||||||
platform = latest_msg.get("user_platform") or latest_msg.get("chat_info_platform")
|
platform = latest_msg.user_info.platform or latest_msg.chat_info.platform
|
||||||
msg_time = latest_msg.get("time", 0)
|
msg_time = latest_msg.time
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_id
|
user_id
|
||||||
|
|
@ -383,8 +381,10 @@ class RelationshipBuilder:
|
||||||
if not person.is_known:
|
if not person.is_known:
|
||||||
continue
|
continue
|
||||||
person_name = person.person_name or person_id
|
person_name = person.person_name or person_id
|
||||||
|
|
||||||
if total_message_count >= max_build_threshold or (total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all")):
|
if total_message_count >= max_build_threshold or (
|
||||||
|
total_message_count >= 5 and immediate_build in [person_id, "all"]
|
||||||
|
):
|
||||||
users_to_build_relationship.append(person_id)
|
users_to_build_relationship.append(person_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 用户 {person_name} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}"
|
f"{self.log_prefix} 用户 {person_name} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}"
|
||||||
|
|
@ -400,12 +400,11 @@ class RelationshipBuilder:
|
||||||
segments = self.person_engaged_cache[person_id]
|
segments = self.person_engaged_cache[person_id]
|
||||||
# 异步执行关系构建
|
# 异步执行关系构建
|
||||||
person = Person(person_id=person_id)
|
person = Person(person_id=person_id)
|
||||||
if person.is_known:
|
if person.is_known:
|
||||||
asyncio.create_task(self.update_impression_on_segments(person_id, self.chat_id, segments))
|
asyncio.create_task(self.update_impression_on_segments(person_id, self.chat_id, segments))
|
||||||
# 移除已处理的用户缓存
|
# 移除已处理的用户缓存
|
||||||
del self.person_engaged_cache[person_id]
|
del self.person_engaged_cache[person_id]
|
||||||
self._save_cache()
|
self._save_cache()
|
||||||
|
|
||||||
|
|
||||||
# ================================
|
# ================================
|
||||||
# 关系构建模块
|
# 关系构建模块
|
||||||
|
|
@ -458,7 +457,7 @@ class RelationshipBuilder:
|
||||||
"user_cardname": "",
|
"user_cardname": "",
|
||||||
"display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
"display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
||||||
"is_action_record": True,
|
"is_action_record": True,
|
||||||
"chat_info_platform": segment_messages[0].get("chat_info_platform", ""),
|
"chat_info_platform": segment_messages[0].chat_info.platform or "",
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
}
|
}
|
||||||
processed_messages.append(gap_message)
|
processed_messages.append(gap_message)
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,10 @@
|
||||||
readable_text = message_api.build_readable_messages(messages)
|
readable_text = message_api.build_readable_messages(messages)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict, Any, Tuple, Optional
|
|
||||||
from src.config.config import global_config
|
|
||||||
import time
|
import time
|
||||||
|
from typing import List, Dict, Any, Tuple, Optional
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_by_timestamp,
|
get_raw_msg_by_timestamp,
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
get_raw_msg_by_timestamp_with_chat,
|
||||||
|
|
@ -36,7 +37,7 @@ from src.chat.utils.chat_message_builder import (
|
||||||
|
|
||||||
def get_messages_by_time(
|
def get_messages_by_time(
|
||||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定时间范围内的消息
|
获取指定时间范围内的消息
|
||||||
|
|
||||||
|
|
@ -70,7 +71,7 @@ def get_messages_by_time_in_chat(
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
filter_mai: bool = False,
|
filter_mai: bool = False,
|
||||||
filter_command: bool = False,
|
filter_command: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定时间范围内的消息
|
获取指定聊天中指定时间范围内的消息
|
||||||
|
|
||||||
|
|
@ -97,7 +98,9 @@ def get_messages_by_time_in_chat(
|
||||||
if not isinstance(chat_id, str):
|
if not isinstance(chat_id, str):
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
raise ValueError("chat_id 必须是字符串类型")
|
||||||
if filter_mai:
|
if filter_mai:
|
||||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command))
|
return filter_mai_messages(
|
||||||
|
get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||||
|
)
|
||||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -109,7 +112,7 @@ def get_messages_by_time_in_chat_inclusive(
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
filter_mai: bool = False,
|
filter_mai: bool = False,
|
||||||
filter_command: bool = False,
|
filter_command: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||||
|
|
||||||
|
|
@ -137,9 +140,13 @@ def get_messages_by_time_in_chat_inclusive(
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
raise ValueError("chat_id 必须是字符串类型")
|
||||||
if filter_mai:
|
if filter_mai:
|
||||||
return filter_mai_messages(
|
return filter_mai_messages(
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
|
chat_id, start_time, end_time, limit, limit_mode, filter_command
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
return get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
|
chat_id, start_time, end_time, limit, limit_mode, filter_command
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_by_time_in_chat_for_users(
|
def get_messages_by_time_in_chat_for_users(
|
||||||
|
|
@ -149,7 +156,7 @@ def get_messages_by_time_in_chat_for_users(
|
||||||
person_ids: List[str],
|
person_ids: List[str],
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定用户在指定时间范围内的消息
|
获取指定聊天中指定用户在指定时间范围内的消息
|
||||||
|
|
||||||
|
|
@ -180,7 +187,7 @@ def get_messages_by_time_in_chat_for_users(
|
||||||
|
|
||||||
def get_random_chat_messages(
|
def get_random_chat_messages(
|
||||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
||||||
|
|
||||||
|
|
@ -208,7 +215,7 @@ def get_random_chat_messages(
|
||||||
|
|
||||||
def get_messages_by_time_for_users(
|
def get_messages_by_time_for_users(
|
||||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定用户在所有聊天中指定时间范围内的消息
|
获取指定用户在所有聊天中指定时间范围内的消息
|
||||||
|
|
||||||
|
|
@ -232,7 +239,7 @@ def get_messages_by_time_for_users(
|
||||||
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
|
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定时间戳之前的消息
|
获取指定时间戳之前的消息
|
||||||
|
|
||||||
|
|
@ -258,7 +265,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
|
||||||
|
|
||||||
def get_messages_before_time_in_chat(
|
def get_messages_before_time_in_chat(
|
||||||
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
|
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定时间戳之前的消息
|
获取指定聊天中指定时间戳之前的消息
|
||||||
|
|
||||||
|
|
@ -287,7 +294,7 @@ def get_messages_before_time_in_chat(
|
||||||
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
|
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定用户在指定时间戳之前的消息
|
获取指定用户在指定时间戳之前的消息
|
||||||
|
|
||||||
|
|
@ -311,7 +318,7 @@ def get_messages_before_time_for_users(timestamp: float, person_ids: List[str],
|
||||||
|
|
||||||
def get_recent_messages(
|
def get_recent_messages(
|
||||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中最近一段时间的消息
|
获取指定聊天中最近一段时间的消息
|
||||||
|
|
||||||
|
|
@ -472,7 +479,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
从消息列表中移除麦麦的消息
|
从消息列表中移除麦麦的消息
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -480,4 +487,4 @@ def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
Returns:
|
Returns:
|
||||||
过滤后的消息列表
|
过滤后的消息列表
|
||||||
"""
|
"""
|
||||||
return [msg for msg in messages if msg.get("user_id") != str(global_config.bot.qq_account)]
|
return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]
|
||||||
|
|
|
||||||
|
|
@ -85,8 +85,10 @@ class EmojiAction(BaseAction):
|
||||||
messages_text = ""
|
messages_text = ""
|
||||||
if recent_messages:
|
if recent_messages:
|
||||||
# 使用message_api构建可读的消息字符串
|
# 使用message_api构建可读的消息字符串
|
||||||
|
# TODO: 修复
|
||||||
|
tmp_msgs = [msg.__dict__ for msg in recent_messages]
|
||||||
messages_text = message_api.build_readable_messages(
|
messages_text = message_api.build_readable_messages(
|
||||||
messages=recent_messages,
|
messages=tmp_msgs,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
truncate=False,
|
truncate=False,
|
||||||
show_actions=False,
|
show_actions=False,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue