修复Config类型没转换的问题

pull/34/head
UnCLAS-Prommer 2025-06-15 16:40:25 +08:00
parent 36305f226c
commit 81a71af4aa
4 changed files with 51 additions and 39 deletions

16
main.py
View File

@ -18,7 +18,7 @@ async def message_recv(server_connection: Server.ServerConnection):
async for raw_message in server_connection: async for raw_message in server_connection:
logger.debug( logger.debug(
f"{raw_message[:100]}..." f"{raw_message[:100]}..."
if (len(raw_message) > 100 and global_config.debug_level != "DEBUG") if (len(raw_message) > 100 and global_config.debug.level != "DEBUG")
else raw_message else raw_message
) )
decoded_raw_message: dict = json.loads(raw_message) decoded_raw_message: dict = json.loads(raw_message)
@ -52,19 +52,23 @@ async def main():
async def napcat_server(): async def napcat_server():
logger.info("正在启动adapter...") logger.info("正在启动adapter...")
async with Server.serve(message_recv, global_config.server_host, global_config.server_port) as server: async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port) as server:
logger.info(f"Adapter已启动监听地址: ws://{global_config.server_host}:{global_config.server_port}") logger.info(
f"Adapter已启动监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}"
)
await server.serve_forever() await server.serve_forever()
async def graceful_shutdown(): async def graceful_shutdown():
try: try:
logger.info("正在关闭adapter...") logger.info("正在关闭adapter...")
await mmc_stop_com()
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks: for task in tasks:
task.cancel() if not task.done():
await asyncio.gather(*tasks, return_exceptions=True) task.cancel()
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15)
await mmc_stop_com() # 后置避免神秘exception
logger.info("Adapter已成功关闭")
except Exception as e: except Exception as e:
logger.error(f"Adapter关闭中出现错误: {e}") logger.error(f"Adapter关闭中出现错误: {e}")

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, fields, MISSING from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, List, Set, Tuple, Union
T = TypeVar("T", bound="ConfigBase") T = TypeVar("T", bound="ConfigBase")
@ -18,16 +18,16 @@ class ConfigBase:
"""配置类的基类""" """配置类的基类"""
@classmethod @classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T: def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
"""从字典加载配置字段""" """从字典加载配置字段"""
if not isinstance(data, dict): if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}") raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {} init_args: Dict[str, Any] = {}
for f in fields(cls): for f in fields(cls):
field_name = f.name field_name = f.name
field_type = f.type
if field_name.startswith("_"): if field_name.startswith("_"):
# 跳过以 _ 开头的字段 # 跳过以 _ 开头的字段
continue continue
@ -40,14 +40,12 @@ class ConfigBase:
raise ValueError(f"Missing required field: '{field_name}'") raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name] value = data[field_name]
field_type = f.type
try: try:
init_args[field_name] = cls._convert_field(value, field_type) init_args[field_name] = cls._convert_field(value, field_type)
except TypeError as e: except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
return cls(**init_args) return cls(**init_args)
@ -61,33 +59,30 @@ class ConfigBase:
3. 对于基础类型int, str, float, bool直接转换 3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型尝试直接转换如果失败则抛出异常 4. 对于其他类型尝试直接转换如果失败则抛出异常
""" """
# 如果是嵌套的 dataclass递归调用 from_dict 方法 # 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase): if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value) return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type) field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type) field_args_type = get_args(field_type)
# 处理泛型集合类型list, set, tuple
if field_origin_type in {list, set, tuple}: if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list # 检查提供的value是否为list
if not isinstance(value, list): if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list: if field_origin_type is list:
return [cls._convert_field(item, field_type_args[0]) for item in value] return [cls._convert_field(item, field_args_type[0]) for item in value]
elif field_origin_type is set: if field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value} return {cls._convert_field(item, field_args_type[0]) for item in value}
elif field_origin_type is tuple: if field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致 # 检查提供的value长度是否与类型参数一致
if len(value) != len(field_type_args): if len(value) != len(field_args_type):
raise TypeError( raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}" f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
) )
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args)) return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
if field_origin_type is dict: if field_origin_type is dict:
# 检查提供的value是否为dict # 检查提供的value是否为dict
@ -95,18 +90,30 @@ class ConfigBase:
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型 # 检查字典的键值类型
if len(field_type_args) != 2: if len(field_args_type) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args key_type, value_type = field_args_type
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等 # 处理Optional类型
if field_origin_type is type(None) and value is None: # 处理Optional类型 if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
return None if value is None:
return None
# 如果有数据,检查实际类型
if type(value) not in field_args_type:
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
return cls._convert_field(value, field_args_type[0])
# 处理int, str, float, bool等基础类型
if field_origin_type is None:
if isinstance(value, field_type):
return field_type(value)
else:
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
# 处理Literal类型 # 处理Literal类型
if field_origin_type is Literal or get_origin(field_type) is Literal: if field_origin_type is Literal:
# 获取Literal的允许值 # 获取Literal的允许值
allowed_values = get_args(field_type) allowed_values = get_args(field_type)
if value in allowed_values: if value in allowed_values:
@ -114,14 +121,15 @@ class ConfigBase:
else: else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type") raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
if field_type is Any or isinstance(value, field_type): # 处理其他类型
if field_type is Any:
return value return value
# 其他类型,尝试直接转换 # 其他类型直接转换
try: try:
return field_type(value) return field_type(value)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
def __str__(self): def __str__(self):
"""返回配置类的字符串表示""" """返回配置类的字符串表示"""

View File

@ -49,16 +49,16 @@ class ChatConfig(ConfigBase):
group_list_type: Literal["whitelist", "blacklist"] = "whitelist" group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""群聊列表类型 白名单/黑名单""" """群聊列表类型 白名单/黑名单"""
group_list: list[str] = field(default_factory=[]) group_list: list[int] = field(default_factory=[])
"""群聊列表""" """群聊列表"""
private_list_type: Literal["whitelist", "blacklist"] = "whitelist" private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""私聊列表类型 白名单/黑名单""" """私聊列表类型 白名单/黑名单"""
private_list: list[str] = field(default_factory=[]) private_list: list[int] = field(default_factory=[])
"""私聊列表""" """私聊列表"""
ban_user_id: list[str] = field(default_factory=[]) ban_user_id: list[int] = field(default_factory=[])
"""被封禁的用户ID列表封禁后将无法与其进行交互""" """被封禁的用户ID列表封禁后将无法与其进行交互"""
enable_poke: bool = True enable_poke: bool = True

View File

@ -697,7 +697,7 @@ class RecvHandler:
user_nickname: str = sender_info.get("nickname", "QQ用户") user_nickname: str = sender_info.get("nickname", "QQ用户")
user_nickname_str = f"{user_nickname}】:" user_nickname_str = f"{user_nickname}】:"
break_seg = Seg(type="text", data="\n") break_seg = Seg(type="text", data="\n")
message_of_sub_message_list: dict = sub_message.get("message") message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
if not message_of_sub_message_list: if not message_of_sub_message_list:
logger.warning("转发消息内容为空") logger.warning("转发消息内容为空")
continue continue