MaiBot/src/config/config_base.py

294 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import ast
import inspect
import types
from dataclasses import dataclass, field
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field
from typing import Any, Dict, List, Literal, Set, Tuple, Union, cast, get_args, get_origin
__all__ = ["ConfigBase", "Field", "AttributeData"]
from src.common.logger import get_logger
logger = get_logger("ConfigBase")
@dataclass
class AttributeData:
missing_attributes: list[str] = field(default_factory=list)
"""缺失的属性列表"""
redundant_attributes: list[str] = field(default_factory=list)
"""多余的属性列表"""
class AttrDocBase:
"""解析字段说明的基类"""
field_docs: dict[str, str] = {}
def __post_init__(self, allow_extra_methods: bool = False):
self.field_docs = self._get_field_docs(allow_extra_methods) # 全局仅获取一次并保留
def _get_field_docs(self, allow_extra_methods: bool) -> dict[str, str]:
"""
获取字段的说明字符串
:param cls: 配置类
:return: 字段说明字典,键为字段名,值为说明字符串
"""
# 获取类的源代码文本
class_source = self._get_class_source()
# 解析源代码,找到对应的类定义节点
class_node = self._find_class_node(class_source)
# 从类定义节点中提取字段文档
return self._extract_field_docs(class_node, allow_extra_methods)
@classmethod
def get_class_field_docs(cls) -> dict[str, str]:
class_source = cls._get_class_source()
class_node = cls._find_class_node(class_source)
return AttrDocBase._extract_field_docs(
cast(AttrDocBase, cast(Any, cls)),
class_node,
allow_extra_methods=False,
)
@classmethod
def _get_class_source(cls) -> str:
"""获取类定义所在文件的完整源代码"""
# 使用 inspect 模块获取类定义所在的文件路径
class_file = inspect.getfile(cls)
# 读取文件内容并以 UTF-8 编码返回
return Path(class_file).read_text(encoding="utf-8")
@classmethod
def _find_class_node(cls, class_source: str) -> ast.ClassDef:
"""在源代码中找到类定义的AST节点"""
tree = ast.parse(class_source)
# 遍历 AST 中的所有节点
for node in ast.walk(tree):
# 查找类定义节点,且类名与当前类名匹配
if isinstance(node, ast.ClassDef) and node.name == cls.__name__:
"""类名匹配,返回节点"""
return node
# 如果没有找到匹配的类定义,抛出异常
raise AttributeError(f"Class {cls.__name__} not found in source.")
def _extract_field_docs(self, class_node: ast.ClassDef, allow_extra_methods: bool) -> dict[str, str]:
"""从类的 AST 节点中提取字段的文档字符串"""
# sourcery skip: merge-nested-ifs
doc_dict: dict[str, str] = {}
class_body = class_node.body # 类属性节点列表
for i in range(len(class_body)):
body_item = class_body[i]
if not allow_extra_methods:
# 检查是否有非 model_post_init 的方法定义,如果有则抛出异常
# 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
"""检验ConfigBase子类中是否有除model_post_init以外的方法规范配置类的定义"""
raise AttributeError(
f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
) from None
# 检查当前语句是否为带注解的赋值语句 (类型注解的字段定义)
# 并且下一个语句存在
if (
i + 1 < len(class_body)
and isinstance(body_item, ast.AnnAssign) # 例如: field_name: int = 10
and isinstance(body_item.target, ast.Name) # 目标是一个简单的名称
):
"""字段定义后紧跟的字符串表达式即为字段说明"""
expr_item = class_body[i + 1]
# 检查下一个语句是否为字符串常量表达式 (文档字符串)
if (
isinstance(expr_item, ast.Expr) # 表达式语句
and isinstance(expr_item.value, ast.Constant) # 常量值
and isinstance(expr_item.value.value, str) # 字符串常量
):
doc_string = expr_item.value.value.strip() # 获取说明字符串并去除首尾空白
processed_doc_lines = [line.strip() for line in doc_string.splitlines()] # 多行处理
# 删除开头的所有空行
while processed_doc_lines and not processed_doc_lines[0]:
processed_doc_lines.pop(0)
# 删除结尾的所有空行
while processed_doc_lines and not processed_doc_lines[-1]:
processed_doc_lines.pop()
# 将处理后的行重新组合,并存入字典
# 键是字段名,值是清理后的文档字符串
doc_dict[body_item.target.id] = "\n".join(processed_doc_lines)
return doc_dict
class ConfigBase(BaseModel, AttrDocBase):
model_config = ConfigDict(validate_assignment=True, extra="forbid")
_validate_any: bool = True # 是否验证 Any 类型的使用,默认为 True
suppress_any_warning: bool = False # 是否抑制 Any 类型使用的警告,默认为 False仅仅在_validate_any 为 False 时生效
@classmethod
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
"""从字典创建配置对象,并收集缺失和多余的属性信息"""
class_fields = set(cls.model_fields.keys())
class_fields.remove("field_docs") # 忽略 field_docs 字段
if "_validate_any" in class_fields:
class_fields.remove("_validate_any") # 忽略 _validate_any 字段
if "suppress_any_warning" in class_fields:
class_fields.remove("suppress_any_warning") # 忽略 suppress_any_warning 字
for class_field in class_fields:
if class_field not in data:
attribute_data.missing_attributes.append(class_field) # 记录缺失的属性
cleaned_data_list: list[str] = []
for data_field in data:
if data_field not in class_fields:
cleaned_data_list.append(data_field)
attribute_data.redundant_attributes.append(data_field) # 记录多余的属性
for redundant_field in cleaned_data_list:
data.pop(redundant_field) # 移除多余的属性
# 对于是ConfigBase子类的字段递归调用from_dict
class_field_infos = dict(cls.model_fields.items())
for field_data in data:
if info := class_field_infos.get(field_data):
field_type = info.annotation
if inspect.isclass(field_type) and issubclass(field_type, ConfigBase):
data[field_data] = field_type.from_dict(attribute_data, data[field_data])
if get_origin(field_type) in {list, List}:
elem_type = get_args(field_type)[0]
if inspect.isclass(elem_type) and issubclass(elem_type, ConfigBase):
data[field_data] = [elem_type.from_dict(attribute_data, item) for item in data[field_data]]
# 没有set因为ConfigBase is not Hashable
if get_origin(field_type) in {dict, Dict}:
val_type = get_args(field_type)[1]
if inspect.isclass(val_type) and issubclass(val_type, ConfigBase):
data[field_data] = {
key: val_type.from_dict(attribute_data, val) for key, val in data[field_data].items()
}
return cls(**data)
def _discourage_any_usage(self, field_name: str) -> None:
"""警告使用 Any 类型的字段可被suppress"""
if self._validate_any:
raise TypeError(f"字段'{field_name}'中不允许使用 Any 类型注解")
if not self.suppress_any_warning:
logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议使用更具体的类型注解以提高类型安全性")
def _get_real_type(self, annotation: type[Any] | Any | None):
"""获取真实类型,处理 dict 等没有参数的情况"""
origin_type = get_origin(annotation)
args_type = get_args(annotation)
if origin_type is None:
origin_type = annotation
args_type = ()
return origin_type, args_type
def _validate_union_type(self, annotation: type[Any] | Any | None, field_name: str):
"""
验证 Union 类型的使用可被suppress
明确禁止 Union / PEP 604 的 | 表示法
允许 Optional[T](即 Union[T, None]"""
origin, args = self._get_real_type(annotation)
other = annotation
if origin in (Union, types.UnionType):
if len(args) != 2 or all(a is not type(None) for a in args):
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中不允许使用 Union 类型注解")
# 将注解替换为 Optional 的内部类型,继续后续校验(允许原子或容器类型)
other = args[0] if args[1] is type(None) else args[1]
origin, args = self._get_real_type(other)
if origin in (Union, types.UnionType):
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套使用 Union/Optional 类型注解")
return origin, args, other
def _validate_list_set_type(self, annotation: Any | None, field_name: str):
"""验证 list/set 类型的使用"""
origin, args = self._get_real_type(annotation)
if origin in (list, set, List, Set):
if len(args) != 1:
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中必须指定且仅指定一个类型参数,使用了: {annotation}"
)
elem = args[0]
if elem is Any:
self._discourage_any_usage(field_name)
elif get_origin(elem) is not None:
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
)
elif inspect.isclass(elem) and issubclass(elem, ConfigBase) and origin in (set, Set):
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许使用 ConfigBase 子类作为 set 元素类型: {annotation}。ConfigBase is not Hashable。"
)
def _validate_dict_type(self, annotation: Any | None, field_name: str):
"""验证 dict 类型的使用"""
_, args = self._get_real_type(annotation)
if len(args) != 2:
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中必须指定键和值的类型参数: {annotation}")
_, val_t = args
if val_t is Any:
self._discourage_any_usage(field_name)
if get_origin(val_t):
origin_type = get_origin(val_t)
if origin_type is None:
return
origin_type, _, anno = self._validate_union_type(val_t, field_name)
if origin_type in (list, set, List, Set):
self._validate_list_set_type(anno, field_name)
elif origin_type is Any:
self._discourage_any_usage(field_name)
elif origin_type in (int, float, str, bool, complex, bytes):
return
else:
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
)
def model_post_init(self, context: Any = None) -> None:
"""验证字段的类型注解
规则:
- 允许原子注解(非泛型,且不为 Any
- 允许 list[T], set[T],其中 T 为原子注解
- 允许 dict[K, V],其中 K、V 为原子注解
- 禁止使用 Union不包含 Optional和 tuple及 Tuple
- 禁止嵌套泛型(例如 list[list[int]])和使用 Any
"""
for field_name, field_info in type(self).model_fields.items():
annotation = field_info.annotation
origin_type, _ = self._get_real_type(annotation)
# 处理 Union (含Optional) 类型
origin_type, _, annotation = self._validate_union_type(annotation, field_name)
# 禁止 tuple / Tuple
if origin_type in (tuple, Tuple):
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中不允许使用 Tuple 类型注解")
# 处理 Any 类型
if origin_type is Any:
self._discourage_any_usage(field_name)
# 非泛型注解视为原子类型,允许
if origin_type in (int, float, str, bool, complex, bytes, Any):
continue
# 允许嵌套的ConfigBase自定义类
if isinstance(origin_type, type) and issubclass(cast(type, origin_type), ConfigBase):
continue
# 只允许 list, set, dict 三类泛型
if origin_type not in (list, set, dict, List, Set, Dict, Literal):
raise TypeError(
f"仅允许使用list, set, dict三种泛型类型注解'{type(self).__name__}'字段'{field_name}'中使用了: {annotation}"
)
# list/set: 必须指定且仅指定一个类型参数,且参数为原子类型
if origin_type in (list, set, List, Set):
self._validate_list_set_type(annotation, field_name)
# dict: 必须指定两个类型参数,且 key/value 为原子类型或者set/list类型
if origin_type in (dict, Dict):
self._validate_dict_type(annotation, field_name)
super().model_post_init(context)
super().__post_init__() # 获取字段说明