diff --git a/main.py b/main.py index 50cf968..2c71ef1 100644 --- a/main.py +++ b/main.py @@ -18,7 +18,7 @@ async def message_recv(server_connection: Server.ServerConnection): async for raw_message in server_connection: logger.debug( f"{raw_message[:100]}..." - if (len(raw_message) > 100 and global_config.debug_level != "DEBUG") + if (len(raw_message) > 100 and global_config.debug.level != "DEBUG") else raw_message ) decoded_raw_message: dict = json.loads(raw_message) @@ -52,19 +52,23 @@ async def main(): async def napcat_server(): logger.info("正在启动adapter...") - async with Server.serve(message_recv, global_config.server_host, global_config.server_port) as server: - logger.info(f"Adapter已启动,监听地址: ws://{global_config.server_host}:{global_config.server_port}") + async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port) as server: + logger.info( + f"Adapter已启动,监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}" + ) await server.serve_forever() async def graceful_shutdown(): try: logger.info("正在关闭adapter...") - await mmc_stop_com() tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + if not task.done(): + task.cancel() + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15) + await mmc_stop_com() # 后置避免神秘exception + logger.info("Adapter已成功关闭") except Exception as e: logger.error(f"Adapter关闭中出现错误: {e}") diff --git a/src/config/config_base.py b/src/config/config_base.py index fbd3dd9..518f99c 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, fields, MISSING -from typing import TypeVar, Type, Any, get_origin, get_args, Literal +from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, List, Set, Tuple, Union T = TypeVar("T", bound="ConfigBase") @@ -18,16 +18,16 @@ class ConfigBase: """配置类的基类""" @classmethod - def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: """从字典加载配置字段""" if not isinstance(data, dict): raise TypeError(f"Expected a dictionary, got {type(data).__name__}") - init_args: dict[str, Any] = {} + init_args: Dict[str, Any] = {} for f in fields(cls): field_name = f.name - + field_type = f.type if field_name.startswith("_"): # 跳过以 _ 开头的字段 continue @@ -40,14 +40,12 @@ class ConfigBase: raise ValueError(f"Missing required field: '{field_name}'") value = data[field_name] - field_type = f.type - try: init_args[field_name] = cls._convert_field(value, field_type) except TypeError as e: - raise TypeError(f"Field '{field_name}' has a type error: {e}") from e + raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e except Exception as e: - raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e + raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e return cls(**init_args) @@ -61,33 +59,30 @@ class ConfigBase: 3. 对于基础类型(int, str, float, bool),直接转换 4. 对于其他类型,尝试直接转换,如果失败则抛出异常 """ - # 如果是嵌套的 dataclass,递归调用 from_dict 方法 if isinstance(field_type, type) and issubclass(field_type, ConfigBase): - if not isinstance(value, dict): - raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") return field_type.from_dict(value) - # 处理泛型集合类型(list, set, tuple) field_origin_type = get_origin(field_type) - field_type_args = get_args(field_type) + field_args_type = get_args(field_type) + # 处理泛型集合类型(list, set, tuple) if field_origin_type in {list, set, tuple}: # 检查提供的value是否为list if not isinstance(value, list): raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") if field_origin_type is list: - return [cls._convert_field(item, field_type_args[0]) for item in value] - elif field_origin_type is set: - return {cls._convert_field(item, field_type_args[0]) for item in value} - elif field_origin_type is tuple: + return [cls._convert_field(item, field_args_type[0]) for item in value] + if field_origin_type is set: + return {cls._convert_field(item, field_args_type[0]) for item in value} + if field_origin_type is tuple: # 检查提供的value长度是否与类型参数一致 - if len(value) != len(field_type_args): + if len(value) != len(field_args_type): raise TypeError( - f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}" + f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}" ) - return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args)) + return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type)) if field_origin_type is dict: # 检查提供的value是否为dict @@ -95,18 +90,30 @@ class ConfigBase: raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") # 检查字典的键值类型 - if len(field_type_args) != 2: + if len(field_args_type) != 2: raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") - key_type, value_type = field_type_args + key_type, value_type = field_args_type return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} - # 处理基础类型,例如 int, str 等 - if field_origin_type is type(None) and value is None: # 处理Optional类型 - return None + # 处理Optional类型 + if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union + if value is None: + return None + # 如果有数据,检查实际类型 + if type(value) not in field_args_type: + raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}") + return cls._convert_field(value, field_args_type[0]) + + # 处理int, str, float, bool等基础类型 + if field_origin_type is None: + if isinstance(value, field_type): + return field_type(value) + else: + raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}") # 处理Literal类型 - if field_origin_type is Literal or get_origin(field_type) is Literal: + if field_origin_type is Literal: # 获取Literal的允许值 allowed_values = get_args(field_type) if value in allowed_values: @@ -114,14 +121,15 @@ class ConfigBase: else: raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type") - if field_type is Any or isinstance(value, field_type): + # 处理其他类型 + if field_type is Any: return value - # 其他类型,尝试直接转换 + # 其他类型直接转换 try: return field_type(value) except (ValueError, TypeError) as e: - raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e + raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e def __str__(self): """返回配置类的字符串表示""" diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 3119ffb..d8928a8 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -49,16 +49,16 @@ class ChatConfig(ConfigBase): group_list_type: Literal["whitelist", "blacklist"] = "whitelist" """群聊列表类型 白名单/黑名单""" - group_list: list[str] = field(default_factory=[]) + group_list: list[int] = field(default_factory=[]) """群聊列表""" private_list_type: Literal["whitelist", "blacklist"] = "whitelist" """私聊列表类型 白名单/黑名单""" - private_list: list[str] = field(default_factory=[]) + private_list: list[int] = field(default_factory=[]) """私聊列表""" - ban_user_id: list[str] = field(default_factory=[]) + ban_user_id: list[int] = field(default_factory=[]) """被封禁的用户ID列表,封禁后将无法与其进行交互""" enable_poke: bool = True diff --git a/src/recv_handler.py b/src/recv_handler.py index 21ff56f..7cc0a07 100644 --- a/src/recv_handler.py +++ b/src/recv_handler.py @@ -697,7 +697,7 @@ class RecvHandler: user_nickname: str = sender_info.get("nickname", "QQ用户") user_nickname_str = f"【{user_nickname}】:" break_seg = Seg(type="text", data="\n") - message_of_sub_message_list: dict = sub_message.get("message") + message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message") if not message_of_sub_message_list: logger.warning("转发消息内容为空") continue