diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index d2a45ca4..101eb822 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -286,7 +286,8 @@ class HeartFChatting: filter_command=True, ) # TODO: 修复! - temp_recent_messages_dict = [msg.__dict__ for msg in recent_messages_dict] + from src.common.data_models import temporarily_transform_class_to_dict + temp_recent_messages_dict = [temporarily_transform_class_to_dict(msg) for msg in recent_messages_dict] # 统一的消息处理逻辑 should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 2e35a423..e5b5eb04 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -353,7 +353,8 @@ class ExpressionLearner: limit=num, ) # TODO: 修复! - random_msg: Optional[List[Dict[str, Any]]] = [msg.__dict__ for msg in random_msg_temp] if random_msg_temp else None + from src.common.data_models import temporarily_transform_class_to_dict + random_msg: Optional[List[Dict[str, Any]]] = [temporarily_transform_class_to_dict(msg) for msg in random_msg_temp] if random_msg_temp else None # print(random_msg) if not random_msg or random_msg == []: diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 1dcd3a19..aa63aa8f 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -71,7 +71,8 @@ class ActionModifier: limit=min(int(global_config.chat.max_context_size * 0.33), 10), ) # TODO: 修复! - temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half] + from src.common.data_models import temporarily_transform_class_to_dict + temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half] chat_content = build_readable_messages( temp_msg_list_before_now_half, replace_bot_name=True, diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 895bf826..32ab828c 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -281,7 +281,8 @@ class ActionPlanner: limit=int(global_config.chat.max_context_size * 0.6), ) # TODO: 修复! - temp_msg_list_before_now = [msg.__dict__ for msg in message_list_before_now] + from src.common.data_models import temporarily_transform_class_to_dict + temp_msg_list_before_now = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_content_block, message_id_list = build_readable_messages_with_id( messages=temp_msg_list_before_now, timestamp_mode="normal_no_YMD", diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 5464b9f5..91d9c687 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -710,12 +710,13 @@ class DefaultReplyer: target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) # TODO: 修复! + from src.common.data_models import temporarily_transform_class_to_dict message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size * 1, ) - temp_msg_list_before_long = [msg.__dict__ for msg in message_list_before_now_long] + temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long] # TODO: 修复! message_list_before_short = get_raw_msg_before_timestamp_with_chat( @@ -723,7 +724,7 @@ class DefaultReplyer: timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), ) - temp_msg_list_before_short = [msg.__dict__ for msg in message_list_before_short] + temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short] chat_talking_prompt_short = build_readable_messages( temp_msg_list_before_short, @@ -899,7 +900,8 @@ class DefaultReplyer: limit=min(int(global_config.chat.max_context_size * 0.33), 15), ) # TODO: 修复! - temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half] + from src.common.data_models import temporarily_transform_class_to_dict + temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half] chat_talking_prompt_half = build_readable_messages( temp_msg_list_before_now_half, replace_bot_name=True, diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index cbf72eeb..c73f1a9e 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -1,3 +1,51 @@ from typing import Dict, Any -def temporarily_transform_class_to_dict(class_instance) -> Dict[str, Any]: - return class_instance.__dict__ \ No newline at end of file + + +class AbstractClassFlag: + pass + + +def temporarily_transform_class_to_dict(obj: Any) -> Any: + """ + 将对象或容器中的 AbstractClassFlag 子类(类对象)或 AbstractClassFlag 实例 + 递归转换为普通 dict,不修改原对象。 + - 对于类对象(isinstance(value, type) 且 issubclass(..., AbstractClassFlag)), + 读取类的 __dict__ 中非 dunder 项并递归转换。 + - 对于实例(isinstance(value, AbstractClassFlag)),读取 vars(instance) 并递归转换。 + """ + + def _transform(value: Any) -> Any: + # 值是类对象且为 AbstractClassFlag 的子类 + if isinstance(value, type) and issubclass(value, AbstractClassFlag): + return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)} + + # 值是 AbstractClassFlag 的实例 + if isinstance(value, AbstractClassFlag): + return {k: _transform(v) for k, v in vars(value).items()} + + # 常见容器类型,递归处理 + if isinstance(value, dict): + return {k: _transform(v) for k, v in value.items()} + if isinstance(value, list): + return [_transform(v) for v in value] + if isinstance(value, tuple): + return tuple(_transform(v) for v in value) + if isinstance(value, set): + return {_transform(v) for v in value} + # 基本类型,直接返回 + return value + + result = _transform(obj) + + def flatten(target_dict: dict): + flat_dict = {} + for k, v in target_dict.items(): + if isinstance(v, dict): + # 递归扁平化子字典 + sub_flat = flatten(v) + flat_dict.update(sub_flat) + else: + flat_dict[k] = v + return flat_dict + + return flatten(result) if isinstance(result, dict) else result diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 91e3f550..77da7f99 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,24 +1,36 @@ -from typing import Optional -from dataclasses import dataclass, field +from typing import Optional, Dict, Any +from dataclasses import dataclass, field, fields, MISSING +from . import AbstractClassFlag @dataclass -class DatabaseUserInfo: +class DatabaseUserInfo(AbstractClassFlag): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) user_cardname: Optional[str] = None + + def __post_init__(self): + assert isinstance(self.platform, str), "platform must be a string" + assert isinstance(self.user_id, str), "user_id must be a string" + assert isinstance(self.user_nickname, str), "user_nickname must be a string" + assert isinstance(self.user_cardname, str) or self.user_cardname is None, "user_cardname must be a string or None" @dataclass -class DatabaseGroupInfo: +class DatabaseGroupInfo(AbstractClassFlag): group_id: str = field(default_factory=str) group_name: str = field(default_factory=str) group_platform: Optional[str] = None + + def __post_init__(self): + assert isinstance(self.group_id, str), "group_id must be a string" + assert isinstance(self.group_name, str), "group_name must be a string" + assert isinstance(self.group_platform, str) or self.group_platform is None, "group_platform must be a string or None" @dataclass -class DatabaseChatInfo: +class DatabaseChatInfo(AbstractClassFlag): stream_id: str = field(default_factory=str) platform: str = field(default_factory=str) create_time: float = field(default_factory=float) @@ -26,12 +38,20 @@ class DatabaseChatInfo: user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo) group_info: Optional[DatabaseGroupInfo] = None + def __post_init__(self): + assert isinstance(self.stream_id, str), "stream_id must be a string" + assert isinstance(self.platform, str), "platform must be a string" + assert isinstance(self.create_time, float), "create_time must be a float" + assert isinstance(self.last_active_time, float), "last_active_time must be a float" + assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance" + assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, "group_info must be a DatabaseGroupInfo instance or None" -@dataclass -class DatabaseMessages: - chat_info: DatabaseChatInfo - user_info: DatabaseUserInfo - group_info: Optional[DatabaseGroupInfo] = None + +@dataclass(init=False) +class DatabaseMessages(AbstractClassFlag): + # chat_info: DatabaseChatInfo + # user_info: DatabaseUserInfo + # group_info: Optional[DatabaseGroupInfo] = None message_id: str = field(default_factory=str) time: float = field(default_factory=float) @@ -44,23 +64,23 @@ class DatabaseMessages: is_mentioned: Optional[bool] = None # 从 chat_info 扁平化而来的字段 - chat_info_stream_id: str = field(default_factory=str) - chat_info_platform: str = field(default_factory=str) - chat_info_user_platform: str = field(default_factory=str) - chat_info_user_id: str = field(default_factory=str) - chat_info_user_nickname: str = field(default_factory=str) - chat_info_user_cardname: Optional[str] = None - chat_info_group_platform: Optional[str] = None - chat_info_group_id: Optional[str] = None - chat_info_group_name: Optional[str] = None - chat_info_create_time: float = field(default_factory=float) - chat_info_last_active_time: float = field(default_factory=float) + # chat_info_stream_id: str = field(default_factory=str) + # chat_info_platform: str = field(default_factory=str) + # chat_info_user_platform: str = field(default_factory=str) + # chat_info_user_id: str = field(default_factory=str) + # chat_info_user_nickname: str = field(default_factory=str) + # chat_info_user_cardname: Optional[str] = None + # chat_info_group_platform: Optional[str] = None + # chat_info_group_id: Optional[str] = None + # chat_info_group_name: Optional[str] = None + # chat_info_create_time: float = field(default_factory=float) + # chat_info_last_active_time: float = field(default_factory=float) # 从顶层 user_info 扁平化而来的字段 (消息发送者信息) - user_platform: str = field(default_factory=str) - user_id: str = field(default_factory=str) - user_nickname: str = field(default_factory=str) - user_cardname: Optional[str] = None + # user_platform: str = field(default_factory=str) + # user_id: str = field(default_factory=str) + # user_nickname: str = field(default_factory=str) + # user_cardname: Optional[str] = None processed_plain_text: Optional[str] = None # 处理后的纯文本消息 display_message: Optional[str] = None # 显示的消息 @@ -76,32 +96,65 @@ class DatabaseMessages: selected_expressions: Optional[str] = None - def __post_init__(self): - self.user_info = DatabaseUserInfo( - user_id=self.user_id, - user_nickname=self.user_nickname, - user_cardname=self.user_cardname, - platform=self.user_platform, - ) + # def __post_init__(self): - if self.chat_info_group_id and self.chat_info_group_name: + # if self.chat_info_group_id and self.chat_info_group_name: + # self.group_info = DatabaseGroupInfo( + # group_id=self.chat_info_group_id, + # group_name=self.chat_info_group_name, + # group_platform=self.chat_info_group_platform, + # ) + + # chat_user_info = DatabaseUserInfo( + # user_id=self.chat_info_user_id, + # user_nickname=self.chat_info_user_nickname, + # user_cardname=self.chat_info_user_cardname, + # platform=self.chat_info_user_platform, + # ) + # self.chat_info = DatabaseChatInfo( + # stream_id=self.chat_info_stream_id, + # platform=self.chat_info_platform, + # create_time=self.chat_info_create_time, + # last_active_time=self.chat_info_last_active_time, + # user_info=chat_user_info, + # group_info=self.group_info, + # ) + def __init__(self, **kwargs: Any): + defined = {f.name: f for f in fields(self.__class__)} + for name, f in defined.items(): + if name in kwargs: + setattr(self, name, kwargs.pop(name)) + elif f.default is not MISSING: + setattr(self, name, f.default) + else: + raise TypeError(f"缺失必需字段: {name}") + + self.user_info = DatabaseUserInfo( + user_id=kwargs.get("user_id"), # type: ignore + user_nickname=kwargs.get("user_nickname"), # type: ignore + user_cardname=kwargs.get("user_cardname"), # type: ignore + platform=kwargs.get("user_platform"), # type: ignore + ) + if kwargs.get("chat_info_group_id") and kwargs.get("chat_info_group_name"): self.group_info = DatabaseGroupInfo( - group_id=self.chat_info_group_id, - group_name=self.chat_info_group_name, - group_platform=self.chat_info_group_platform, + group_id=kwargs.get("chat_info_group_id"), # type: ignore + group_name=kwargs.get("chat_info_group_name"), # type: ignore + group_platform=kwargs.get("chat_info_group_platform"), # type: ignore ) chat_user_info = DatabaseUserInfo( - user_id=self.chat_info_user_id, - user_nickname=self.chat_info_user_nickname, - user_cardname=self.chat_info_user_cardname, - platform=self.chat_info_user_platform, + user_id=kwargs.get("chat_info_user_id"), # type: ignore + user_nickname=kwargs.get("chat_info_user_nickname"), # type: ignore + user_cardname=kwargs.get("chat_info_user_cardname"), # type: ignore + platform=kwargs.get("chat_info_user_platform"), # type: ignore ) + self.chat_info = DatabaseChatInfo( - stream_id=self.chat_info_stream_id, - platform=self.chat_info_platform, - create_time=self.chat_info_create_time, - last_active_time=self.chat_info_last_active_time, + stream_id=kwargs.get("chat_info_stream_id"), # type: ignore + platform=kwargs.get("chat_info_platform"), # type: ignore + create_time=kwargs.get("chat_info_create_time"), # type: ignore + last_active_time=kwargs.get("chat_info_last_active_time"), # type: ignore user_info=chat_user_info, group_info=self.group_info, ) + diff --git a/src/common/logger.py b/src/common/logger.py index 9c454a93..e069765e 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -454,6 +454,11 @@ RESET_COLOR = "\033[0m" def convert_pathname_to_module(logger, method_name, event_dict): # sourcery skip: extract-method, use-string-remove-affix """将 pathname 转换为模块风格的路径""" + if "logger_name" in event_dict and event_dict["logger_name"] == "maim_message": + if "pathname" in event_dict: + del event_dict["pathname"] + event_dict["module"] = "maim_message" + return event_dict if "pathname" in event_dict: pathname = event_dict["pathname"] try: diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index dbcc0809..6dd681ea 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -164,7 +164,8 @@ class ChatAction: limit_mode="last", ) # TODO: 修复! - tmp_msgs = [msg.__dict__ for msg in message_list_before_now] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( tmp_msgs, replace_bot_name=True, @@ -230,7 +231,8 @@ class ChatAction: limit_mode="last", ) # TODO: 修复! - tmp_msgs = [msg.__dict__ for msg in message_list_before_now] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( tmp_msgs, replace_bot_name=True, diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index 5609b5ba..51b53f11 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -167,7 +167,8 @@ class ChatMood: limit_mode="last", ) # TODO: 修复! - tmp_msgs = [msg.__dict__ for msg in message_list_before_now] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( tmp_msgs, replace_bot_name=True, @@ -248,7 +249,8 @@ class ChatMood: limit_mode="last", ) # TODO: 修复! - tmp_msgs = [msg.__dict__ for msg in message_list_before_now] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( tmp_msgs, replace_bot_name=True, diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index f9de2e0c..1727ad28 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -259,7 +259,8 @@ class PromptBuilder: limit=20, ) # TODO: 修复! - tmp_msgs = [msg.__dict__ for msg in all_dialogue_prompt] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in all_dialogue_prompt] all_dialogue_prompt_str = build_readable_messages( tmp_msgs, timestamp_mode="normal_no_YMD", diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 406968ce..4d501beb 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -100,7 +100,8 @@ class ChatMood: limit_mode="last", ) # TODO: 修复! - tmp_msgs = [msg.__dict__ for msg in message_list_before_now] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( tmp_msgs, replace_bot_name=True, @@ -151,7 +152,8 @@ class ChatMood: limit_mode="last", ) # TODO: 修复 - tmp_msgs = [msg.__dict__ for msg in message_list_before_now] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( tmp_msgs, replace_bot_name=True, diff --git a/src/plugins/built_in/emoji_plugin/emoji.py b/src/plugins/built_in/emoji_plugin/emoji.py index 66bd3e77..df38e56f 100644 --- a/src/plugins/built_in/emoji_plugin/emoji.py +++ b/src/plugins/built_in/emoji_plugin/emoji.py @@ -86,7 +86,8 @@ class EmojiAction(BaseAction): if recent_messages: # 使用message_api构建可读的消息字符串 # TODO: 修复 - tmp_msgs = [msg.__dict__ for msg in recent_messages] + from src.common.data_models import temporarily_transform_class_to_dict + tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in recent_messages] messages_text = message_api.build_readable_messages( messages=tmp_msgs, timestamp_mode="normal_no_YMD",