mirror of https://github.com/Mai-with-u/MaiBot.git
恢复可用性
parent
3481234d2b
commit
5e26414839
|
|
@ -286,7 +286,8 @@ class HeartFChatting:
|
|||
filter_command=True,
|
||||
)
|
||||
# TODO: 修复!
|
||||
temp_recent_messages_dict = [msg.__dict__ for msg in recent_messages_dict]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_recent_messages_dict = [temporarily_transform_class_to_dict(msg) for msg in recent_messages_dict]
|
||||
# 统一的消息处理逻辑
|
||||
should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -353,7 +353,8 @@ class ExpressionLearner:
|
|||
limit=num,
|
||||
)
|
||||
# TODO: 修复!
|
||||
random_msg: Optional[List[Dict[str, Any]]] = [msg.__dict__ for msg in random_msg_temp] if random_msg_temp else None
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
random_msg: Optional[List[Dict[str, Any]]] = [temporarily_transform_class_to_dict(msg) for msg in random_msg_temp] if random_msg_temp else None
|
||||
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
|
|
|
|||
|
|
@ -71,7 +71,8 @@ class ActionModifier:
|
|||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
)
|
||||
# TODO: 修复!
|
||||
temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
|
||||
chat_content = build_readable_messages(
|
||||
temp_msg_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
|
|
|
|||
|
|
@ -281,7 +281,8 @@ class ActionPlanner:
|
|||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
# TODO: 修复!
|
||||
temp_msg_list_before_now = [msg.__dict__ for msg in message_list_before_now]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_msg_list_before_now = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=temp_msg_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
|
|
|
|||
|
|
@ -710,12 +710,13 @@ class DefaultReplyer:
|
|||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
# TODO: 修复!
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size * 1,
|
||||
)
|
||||
temp_msg_list_before_long = [msg.__dict__ for msg in message_list_before_now_long]
|
||||
temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long]
|
||||
|
||||
# TODO: 修复!
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
|
|
@ -723,7 +724,7 @@ class DefaultReplyer:
|
|||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
temp_msg_list_before_short = [msg.__dict__ for msg in message_list_before_short]
|
||||
temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short]
|
||||
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
temp_msg_list_before_short,
|
||||
|
|
@ -899,7 +900,8 @@ class DefaultReplyer:
|
|||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
# TODO: 修复!
|
||||
temp_msg_list_before_now_half = [msg.__dict__ for msg in message_list_before_now_half]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
temp_msg_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,51 @@
|
|||
from typing import Dict, Any
|
||||
def temporarily_transform_class_to_dict(class_instance) -> Dict[str, Any]:
|
||||
return class_instance.__dict__
|
||||
|
||||
|
||||
class AbstractClassFlag:
|
||||
pass
|
||||
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
"""
|
||||
将对象或容器中的 AbstractClassFlag 子类(类对象)或 AbstractClassFlag 实例
|
||||
递归转换为普通 dict,不修改原对象。
|
||||
- 对于类对象(isinstance(value, type) 且 issubclass(..., AbstractClassFlag)),
|
||||
读取类的 __dict__ 中非 dunder 项并递归转换。
|
||||
- 对于实例(isinstance(value, AbstractClassFlag)),读取 vars(instance) 并递归转换。
|
||||
"""
|
||||
|
||||
def _transform(value: Any) -> Any:
|
||||
# 值是类对象且为 AbstractClassFlag 的子类
|
||||
if isinstance(value, type) and issubclass(value, AbstractClassFlag):
|
||||
return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)}
|
||||
|
||||
# 值是 AbstractClassFlag 的实例
|
||||
if isinstance(value, AbstractClassFlag):
|
||||
return {k: _transform(v) for k, v in vars(value).items()}
|
||||
|
||||
# 常见容器类型,递归处理
|
||||
if isinstance(value, dict):
|
||||
return {k: _transform(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_transform(v) for v in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_transform(v) for v in value)
|
||||
if isinstance(value, set):
|
||||
return {_transform(v) for v in value}
|
||||
# 基本类型,直接返回
|
||||
return value
|
||||
|
||||
result = _transform(obj)
|
||||
|
||||
def flatten(target_dict: dict):
|
||||
flat_dict = {}
|
||||
for k, v in target_dict.items():
|
||||
if isinstance(v, dict):
|
||||
# 递归扁平化子字典
|
||||
sub_flat = flatten(v)
|
||||
flat_dict.update(sub_flat)
|
||||
else:
|
||||
flat_dict[k] = v
|
||||
return flat_dict
|
||||
|
||||
return flatten(result) if isinstance(result, dict) else result
|
||||
|
|
|
|||
|
|
@ -1,24 +1,36 @@
|
|||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass, field, fields, MISSING
|
||||
|
||||
from . import AbstractClassFlag
|
||||
|
||||
@dataclass
|
||||
class DatabaseUserInfo:
|
||||
class DatabaseUserInfo(AbstractClassFlag):
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.platform, str), "platform must be a string"
|
||||
assert isinstance(self.user_id, str), "user_id must be a string"
|
||||
assert isinstance(self.user_nickname, str), "user_nickname must be a string"
|
||||
assert isinstance(self.user_cardname, str) or self.user_cardname is None, "user_cardname must be a string or None"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseGroupInfo:
|
||||
class DatabaseGroupInfo(AbstractClassFlag):
|
||||
group_id: str = field(default_factory=str)
|
||||
group_name: str = field(default_factory=str)
|
||||
group_platform: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.group_id, str), "group_id must be a string"
|
||||
assert isinstance(self.group_name, str), "group_name must be a string"
|
||||
assert isinstance(self.group_platform, str) or self.group_platform is None, "group_platform must be a string or None"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseChatInfo:
|
||||
class DatabaseChatInfo(AbstractClassFlag):
|
||||
stream_id: str = field(default_factory=str)
|
||||
platform: str = field(default_factory=str)
|
||||
create_time: float = field(default_factory=float)
|
||||
|
|
@ -26,12 +38,20 @@ class DatabaseChatInfo:
|
|||
user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
|
||||
group_info: Optional[DatabaseGroupInfo] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.stream_id, str), "stream_id must be a string"
|
||||
assert isinstance(self.platform, str), "platform must be a string"
|
||||
assert isinstance(self.create_time, float), "create_time must be a float"
|
||||
assert isinstance(self.last_active_time, float), "last_active_time must be a float"
|
||||
assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance"
|
||||
assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, "group_info must be a DatabaseGroupInfo instance or None"
|
||||
|
||||
@dataclass
|
||||
class DatabaseMessages:
|
||||
chat_info: DatabaseChatInfo
|
||||
user_info: DatabaseUserInfo
|
||||
group_info: Optional[DatabaseGroupInfo] = None
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseMessages(AbstractClassFlag):
|
||||
# chat_info: DatabaseChatInfo
|
||||
# user_info: DatabaseUserInfo
|
||||
# group_info: Optional[DatabaseGroupInfo] = None
|
||||
|
||||
message_id: str = field(default_factory=str)
|
||||
time: float = field(default_factory=float)
|
||||
|
|
@ -44,23 +64,23 @@ class DatabaseMessages:
|
|||
is_mentioned: Optional[bool] = None
|
||||
|
||||
# 从 chat_info 扁平化而来的字段
|
||||
chat_info_stream_id: str = field(default_factory=str)
|
||||
chat_info_platform: str = field(default_factory=str)
|
||||
chat_info_user_platform: str = field(default_factory=str)
|
||||
chat_info_user_id: str = field(default_factory=str)
|
||||
chat_info_user_nickname: str = field(default_factory=str)
|
||||
chat_info_user_cardname: Optional[str] = None
|
||||
chat_info_group_platform: Optional[str] = None
|
||||
chat_info_group_id: Optional[str] = None
|
||||
chat_info_group_name: Optional[str] = None
|
||||
chat_info_create_time: float = field(default_factory=float)
|
||||
chat_info_last_active_time: float = field(default_factory=float)
|
||||
# chat_info_stream_id: str = field(default_factory=str)
|
||||
# chat_info_platform: str = field(default_factory=str)
|
||||
# chat_info_user_platform: str = field(default_factory=str)
|
||||
# chat_info_user_id: str = field(default_factory=str)
|
||||
# chat_info_user_nickname: str = field(default_factory=str)
|
||||
# chat_info_user_cardname: Optional[str] = None
|
||||
# chat_info_group_platform: Optional[str] = None
|
||||
# chat_info_group_id: Optional[str] = None
|
||||
# chat_info_group_name: Optional[str] = None
|
||||
# chat_info_create_time: float = field(default_factory=float)
|
||||
# chat_info_last_active_time: float = field(default_factory=float)
|
||||
|
||||
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
|
||||
user_platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
# user_platform: str = field(default_factory=str)
|
||||
# user_id: str = field(default_factory=str)
|
||||
# user_nickname: str = field(default_factory=str)
|
||||
# user_cardname: Optional[str] = None
|
||||
|
||||
processed_plain_text: Optional[str] = None # 处理后的纯文本消息
|
||||
display_message: Optional[str] = None # 显示的消息
|
||||
|
|
@ -76,32 +96,65 @@ class DatabaseMessages:
|
|||
|
||||
selected_expressions: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.user_info = DatabaseUserInfo(
|
||||
user_id=self.user_id,
|
||||
user_nickname=self.user_nickname,
|
||||
user_cardname=self.user_cardname,
|
||||
platform=self.user_platform,
|
||||
)
|
||||
# def __post_init__(self):
|
||||
|
||||
if self.chat_info_group_id and self.chat_info_group_name:
|
||||
# if self.chat_info_group_id and self.chat_info_group_name:
|
||||
# self.group_info = DatabaseGroupInfo(
|
||||
# group_id=self.chat_info_group_id,
|
||||
# group_name=self.chat_info_group_name,
|
||||
# group_platform=self.chat_info_group_platform,
|
||||
# )
|
||||
|
||||
# chat_user_info = DatabaseUserInfo(
|
||||
# user_id=self.chat_info_user_id,
|
||||
# user_nickname=self.chat_info_user_nickname,
|
||||
# user_cardname=self.chat_info_user_cardname,
|
||||
# platform=self.chat_info_user_platform,
|
||||
# )
|
||||
# self.chat_info = DatabaseChatInfo(
|
||||
# stream_id=self.chat_info_stream_id,
|
||||
# platform=self.chat_info_platform,
|
||||
# create_time=self.chat_info_create_time,
|
||||
# last_active_time=self.chat_info_last_active_time,
|
||||
# user_info=chat_user_info,
|
||||
# group_info=self.group_info,
|
||||
# )
|
||||
def __init__(self, **kwargs: Any):
|
||||
defined = {f.name: f for f in fields(self.__class__)}
|
||||
for name, f in defined.items():
|
||||
if name in kwargs:
|
||||
setattr(self, name, kwargs.pop(name))
|
||||
elif f.default is not MISSING:
|
||||
setattr(self, name, f.default)
|
||||
else:
|
||||
raise TypeError(f"缺失必需字段: {name}")
|
||||
|
||||
self.user_info = DatabaseUserInfo(
|
||||
user_id=kwargs.get("user_id"), # type: ignore
|
||||
user_nickname=kwargs.get("user_nickname"), # type: ignore
|
||||
user_cardname=kwargs.get("user_cardname"), # type: ignore
|
||||
platform=kwargs.get("user_platform"), # type: ignore
|
||||
)
|
||||
if kwargs.get("chat_info_group_id") and kwargs.get("chat_info_group_name"):
|
||||
self.group_info = DatabaseGroupInfo(
|
||||
group_id=self.chat_info_group_id,
|
||||
group_name=self.chat_info_group_name,
|
||||
group_platform=self.chat_info_group_platform,
|
||||
group_id=kwargs.get("chat_info_group_id"), # type: ignore
|
||||
group_name=kwargs.get("chat_info_group_name"), # type: ignore
|
||||
group_platform=kwargs.get("chat_info_group_platform"), # type: ignore
|
||||
)
|
||||
|
||||
chat_user_info = DatabaseUserInfo(
|
||||
user_id=self.chat_info_user_id,
|
||||
user_nickname=self.chat_info_user_nickname,
|
||||
user_cardname=self.chat_info_user_cardname,
|
||||
platform=self.chat_info_user_platform,
|
||||
user_id=kwargs.get("chat_info_user_id"), # type: ignore
|
||||
user_nickname=kwargs.get("chat_info_user_nickname"), # type: ignore
|
||||
user_cardname=kwargs.get("chat_info_user_cardname"), # type: ignore
|
||||
platform=kwargs.get("chat_info_user_platform"), # type: ignore
|
||||
)
|
||||
|
||||
self.chat_info = DatabaseChatInfo(
|
||||
stream_id=self.chat_info_stream_id,
|
||||
platform=self.chat_info_platform,
|
||||
create_time=self.chat_info_create_time,
|
||||
last_active_time=self.chat_info_last_active_time,
|
||||
stream_id=kwargs.get("chat_info_stream_id"), # type: ignore
|
||||
platform=kwargs.get("chat_info_platform"), # type: ignore
|
||||
create_time=kwargs.get("chat_info_create_time"), # type: ignore
|
||||
last_active_time=kwargs.get("chat_info_last_active_time"), # type: ignore
|
||||
user_info=chat_user_info,
|
||||
group_info=self.group_info,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -454,6 +454,11 @@ RESET_COLOR = "\033[0m"
|
|||
def convert_pathname_to_module(logger, method_name, event_dict):
|
||||
# sourcery skip: extract-method, use-string-remove-affix
|
||||
"""将 pathname 转换为模块风格的路径"""
|
||||
if "logger_name" in event_dict and event_dict["logger_name"] == "maim_message":
|
||||
if "pathname" in event_dict:
|
||||
del event_dict["pathname"]
|
||||
event_dict["module"] = "maim_message"
|
||||
return event_dict
|
||||
if "pathname" in event_dict:
|
||||
pathname = event_dict["pathname"]
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -164,7 +164,8 @@ class ChatAction:
|
|||
limit_mode="last",
|
||||
)
|
||||
# TODO: 修复!
|
||||
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
tmp_msgs,
|
||||
replace_bot_name=True,
|
||||
|
|
@ -230,7 +231,8 @@ class ChatAction:
|
|||
limit_mode="last",
|
||||
)
|
||||
# TODO: 修复!
|
||||
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
tmp_msgs,
|
||||
replace_bot_name=True,
|
||||
|
|
|
|||
|
|
@ -167,7 +167,8 @@ class ChatMood:
|
|||
limit_mode="last",
|
||||
)
|
||||
# TODO: 修复!
|
||||
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
tmp_msgs,
|
||||
replace_bot_name=True,
|
||||
|
|
@ -248,7 +249,8 @@ class ChatMood:
|
|||
limit_mode="last",
|
||||
)
|
||||
# TODO: 修复!
|
||||
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
tmp_msgs,
|
||||
replace_bot_name=True,
|
||||
|
|
|
|||
|
|
@ -259,7 +259,8 @@ class PromptBuilder:
|
|||
limit=20,
|
||||
)
|
||||
# TODO: 修复!
|
||||
tmp_msgs = [msg.__dict__ for msg in all_dialogue_prompt]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in all_dialogue_prompt]
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
tmp_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
|
|
|
|||
|
|
@ -100,7 +100,8 @@ class ChatMood:
|
|||
limit_mode="last",
|
||||
)
|
||||
# TODO: 修复!
|
||||
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
tmp_msgs,
|
||||
replace_bot_name=True,
|
||||
|
|
@ -151,7 +152,8 @@ class ChatMood:
|
|||
limit_mode="last",
|
||||
)
|
||||
# TODO: 修复
|
||||
tmp_msgs = [msg.__dict__ for msg in message_list_before_now]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
tmp_msgs,
|
||||
replace_bot_name=True,
|
||||
|
|
|
|||
|
|
@ -86,7 +86,8 @@ class EmojiAction(BaseAction):
|
|||
if recent_messages:
|
||||
# 使用message_api构建可读的消息字符串
|
||||
# TODO: 修复
|
||||
tmp_msgs = [msg.__dict__ for msg in recent_messages]
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in recent_messages]
|
||||
messages_text = message_api.build_readable_messages(
|
||||
messages=tmp_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
|
|
|
|||
Loading…
Reference in New Issue