mirror of https://github.com/Mai-with-u/MaiBot.git
使用Sourcery的Test,修复测出来的Bug
parent
afb993e481
commit
13f095f231
|
|
@ -1,105 +1,496 @@
|
|||
# 本文件为测试文件,请忽略Lint error,内含大量的ignore标识
|
||||
|
||||
from typing import Any, Optional, Union, List
|
||||
from pathlib import Path
|
||||
from importlib import util
|
||||
import logging
|
||||
import sys
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 测试环境准备:补全 logger 和 AttrDocBase 依赖
|
||||
# -------------------------------------------------------------
|
||||
|
||||
TEST_ROOT = Path(__file__).parent.parent.absolute().resolve()
|
||||
logger_file = TEST_ROOT / "logger.py"
|
||||
spec = util.spec_from_file_location("src.common.logger", logger_file)
|
||||
module = util.module_from_spec(spec) # type: ignore
|
||||
assert spec is not None and spec.loader is not None
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
sys.modules["src.common.logger"] = module
|
||||
|
||||
# 测试对象导入
|
||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
|
||||
|
||||
from src.config.config_base import ConfigBase, Field # noqa: E402
|
||||
from src.config.config_base import ConfigBase # noqa: E402
|
||||
import src.config.config_base as config_base_module # noqa: E402
|
||||
|
||||
|
||||
class IllegalConfig_Dict(ConfigBase):
|
||||
a: dict = Field(default_factory=dict)
|
||||
class AttrDocBase:
|
||||
"""用于测试的轻量级 AttrDocBase 替身"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# 被 ConfigBase.model_post_init 调用
|
||||
self.__post_init_called__ = True
|
||||
|
||||
|
||||
class IllegalConfig_List(ConfigBase):
|
||||
b: list = Field(default_factory=list)
|
||||
# 打补丁,让 ConfigBase 使用测试替身
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_attrdoc_post_init():
|
||||
orig = config_base_module.AttrDocBase.__post_init__
|
||||
config_base_module.AttrDocBase.__post_init__ = AttrDocBase.__post_init__ # type: ignore
|
||||
yield
|
||||
config_base_module.AttrDocBase.__post_init__ = orig
|
||||
|
||||
|
||||
class IllegalConfig_Set(ConfigBase):
|
||||
c: set = Field(default_factory=set)
|
||||
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
||||
|
||||
|
||||
class IllegalConfig_Tuple(ConfigBase):
|
||||
d: tuple = Field(default_factory=tuple)
|
||||
class TestConfigBase:
|
||||
# ---------------------------------------------------------
|
||||
# happy path:整体 model_post_init 测试
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"model_cls, init_kwargs, expected_fields",
|
||||
[
|
||||
pytest.param(
|
||||
# 简单原子类型字段
|
||||
type(
|
||||
"SimpleAtomic",
|
||||
(ConfigBase,),
|
||||
{
|
||||
"__annotations__": {
|
||||
"a": int,
|
||||
"b": str,
|
||||
"c": bool,
|
||||
"d": float,
|
||||
},
|
||||
"a": Field(default=1),
|
||||
"b": Field(default="x"),
|
||||
"c": Field(default=True),
|
||||
"d": Field(default=1.5),
|
||||
},
|
||||
),
|
||||
{},
|
||||
{"a", "b", "c", "d"},
|
||||
id="happy-simple-atomic-fields",
|
||||
),
|
||||
pytest.param(
|
||||
# list/set/dict 泛型 + 原子内部类型
|
||||
type(
|
||||
"AtomicContainers",
|
||||
(ConfigBase,),
|
||||
{
|
||||
"__annotations__": {
|
||||
"ints": List[int],
|
||||
"names": Set[str],
|
||||
"mapping": Dict[str, int],
|
||||
},
|
||||
"ints": Field(default_factory=lambda: [1, 2]),
|
||||
"names": Field(default_factory=lambda: {"a", "b"}),
|
||||
"mapping": Field(default_factory=lambda: {"x": 1}),
|
||||
},
|
||||
),
|
||||
{},
|
||||
{"ints", "names", "mapping"},
|
||||
id="happy-atomic-containers",
|
||||
),
|
||||
pytest.param(
|
||||
# Optional 原子和 Optional 容器
|
||||
type(
|
||||
"OptionalFields",
|
||||
(ConfigBase,),
|
||||
{
|
||||
"__annotations__": {
|
||||
"maybe_int": Optional[int],
|
||||
"maybe_str_list": Optional[List[str]],
|
||||
},
|
||||
"maybe_int": Field(default=None),
|
||||
"maybe_str_list": Field(default=None),
|
||||
},
|
||||
),
|
||||
{},
|
||||
{"maybe_int", "maybe_str_list"},
|
||||
id="happy-optional-fields",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_model_post_init_happy_paths(self, model_cls, init_kwargs, expected_fields):
|
||||
# Act
|
||||
instance = model_cls(**init_kwargs)
|
||||
|
||||
# Assert
|
||||
for field_name in expected_fields:
|
||||
assert field_name in type(instance).model_fields
|
||||
_ = getattr(instance, field_name)
|
||||
assert getattr(instance, "__post_init_called__", False) is True
|
||||
|
||||
class IllegalConfig_Union(ConfigBase):
|
||||
e: Union[int, str] = Field(default_factory=str)
|
||||
# ---------------------------------------------------------
|
||||
# _get_real_type
|
||||
# ---------------------------------------------------------
|
||||
def test_get_real_type_non_generic_and_generic(self):
|
||||
class Sample(ConfigBase):
|
||||
x: int = 1
|
||||
y: List[int] = Field(default_factory=list)
|
||||
|
||||
instance = Sample()
|
||||
|
||||
class IllegalConfig_Any(ConfigBase):
|
||||
f: Any = Field(default_factory=dict)
|
||||
# Act
|
||||
origin_x, args_x = instance._get_real_type(int)
|
||||
|
||||
# Assert
|
||||
assert origin_x is int
|
||||
assert args_x == ()
|
||||
|
||||
class IllegalConfig_NestedGeneric(ConfigBase):
|
||||
g: list[List[int]] = Field(default_factory=list)
|
||||
# Act
|
||||
origin_y, args_y = instance._get_real_type(List[int])
|
||||
|
||||
# Assert
|
||||
assert origin_y in (list, List)
|
||||
assert args_y == (int,)
|
||||
|
||||
class IllegalConfig_Any_suppress(ConfigBase):
|
||||
f: Any = Field(default_factory=dict)
|
||||
_validate_any: bool = False
|
||||
# ---------------------------------------------------------
|
||||
# _validate_union_type
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"annotation, expect_error, error_fragment, expected_origin_type",
|
||||
[
|
||||
pytest.param(
|
||||
int,
|
||||
False,
|
||||
None,
|
||||
int,
|
||||
id="union-validation-atomic-non-union",
|
||||
),
|
||||
pytest.param(
|
||||
Optional[int],
|
||||
False,
|
||||
None,
|
||||
int,
|
||||
id="union-validation-optional-atomic",
|
||||
),
|
||||
pytest.param(
|
||||
Optional[List[int]],
|
||||
False,
|
||||
None,
|
||||
list,
|
||||
id="union-validation-optional-container",
|
||||
),
|
||||
pytest.param(
|
||||
Union[int, str],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-disallow-non-optional-union",
|
||||
),
|
||||
pytest.param(
|
||||
int | str,
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-pep604-disallow-non-optional-union",
|
||||
),
|
||||
pytest.param(
|
||||
Union[int, None, str],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-disallow-union-more-than-two",
|
||||
),
|
||||
pytest.param(
|
||||
Optional[Union[int, str]],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-disallow-nested-optional-union",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_union_type(self, annotation, expect_error, error_fragment, expected_origin_type):
|
||||
# 这里我们不实例化 Sample,以避免在 __init__/model_post_init 阶段触发验证。
|
||||
# 直接通过一个“哑实例”调用受测方法,仅测试类型注解逻辑。
|
||||
|
||||
class Dummy(ConfigBase):
|
||||
pass
|
||||
|
||||
class SubClass(ConfigBase):
|
||||
x: Optional[int] = Field(default=None)
|
||||
y: list[int] = [123]
|
||||
dummy = Dummy() # 最小初始化,避免字段校验
|
||||
|
||||
field_name = "v"
|
||||
|
||||
class LegalConfig(ConfigBase):
|
||||
a: dict[str, list[int]] = Field(default_factory=dict)
|
||||
b: list[int] = Field(default_factory=list)
|
||||
c: set[str] = Field(default_factory=set)
|
||||
d: Optional[str] = Field(default=None)
|
||||
e: SubClass = Field(default_factory=SubClass)
|
||||
if expect_error:
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_union_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
# Act
|
||||
origin, args, other = dummy._validate_union_type(annotation, field_name)
|
||||
|
||||
# Assert
|
||||
assert origin is expected_origin_type
|
||||
assert other is not None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_class, expected_exception, expected_message",
|
||||
[
|
||||
(IllegalConfig_Dict, TypeError, "必须指定键和值的类型参数"),
|
||||
(IllegalConfig_List, TypeError, "必须指定且仅指定一个类型参数"),
|
||||
(IllegalConfig_Set, TypeError, "必须指定且仅指定一个类型参数"),
|
||||
(IllegalConfig_Tuple, TypeError, "不允许使用 Tuple 类型注解"),
|
||||
(IllegalConfig_Union, TypeError, "不允许使用 Union 类型注解"),
|
||||
(IllegalConfig_Any, TypeError, "不允许使用 Any 类型注解"),
|
||||
(IllegalConfig_NestedGeneric, TypeError, "不允许嵌套泛型类型"),
|
||||
(IllegalConfig_Any_suppress, None, ""),
|
||||
],
|
||||
)
|
||||
def test_illegal_config(config_class, expected_exception, expected_message):
|
||||
# sourcery skip: no-conditionals-in-tests
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception) as exc_info:
|
||||
config_class()
|
||||
assert expected_message in str(exc_info.value)
|
||||
assert expected_exception == exc_info.type
|
||||
else:
|
||||
config_instance = config_class()
|
||||
assert isinstance(config_instance, config_class)
|
||||
# ---------------------------------------------------------
|
||||
# _validate_list_set_type
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"annotation, expect_error, error_fragment",
|
||||
[
|
||||
pytest.param(
|
||||
List[int],
|
||||
False,
|
||||
None,
|
||||
id="listset-validation-list-happy",
|
||||
),
|
||||
pytest.param(
|
||||
Set[str],
|
||||
False,
|
||||
None,
|
||||
id="listset-validation-set-happy",
|
||||
),
|
||||
pytest.param(
|
||||
list,
|
||||
True,
|
||||
"必须指定且仅指定一个类型参数",
|
||||
id="listset-validation-missing-type-arg",
|
||||
),
|
||||
pytest.param(
|
||||
List[int | None],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
id="listset-validation-nested-generic-inner-union",
|
||||
),
|
||||
pytest.param(
|
||||
List[List[int]],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
id="listset-validation-nested-generic-inner-list",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
||||
# 不实例化带有这些字段的模型,避免在 __init__/model_post_init 阶段就失败,
|
||||
# 只测试 _validate_list_set_type 本身的逻辑。
|
||||
|
||||
class Dummy(ConfigBase):
|
||||
pass
|
||||
|
||||
def test_legal_config():
|
||||
config_instance = LegalConfig()
|
||||
assert isinstance(config_instance, LegalConfig)
|
||||
assert isinstance(config_instance.a, dict)
|
||||
assert isinstance(config_instance.b, list)
|
||||
assert isinstance(config_instance.c, set)
|
||||
assert config_instance.d is None
|
||||
assert isinstance(config_instance.e, SubClass)
|
||||
assert config_instance.e.x is None
|
||||
assert isinstance(config_instance.e.y, list)
|
||||
assert config_instance.e.y == [123]
|
||||
dummy = Dummy()
|
||||
|
||||
field_name = "items"
|
||||
|
||||
if expect_error:
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_list_set_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
# Act
|
||||
dummy._validate_list_set_type(annotation, field_name)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# _validate_dict_type
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"annotation, expect_error, error_fragment",
|
||||
[
|
||||
pytest.param(
|
||||
Dict[str, int],
|
||||
False,
|
||||
None,
|
||||
id="dict-validation-happy-atomic",
|
||||
),
|
||||
pytest.param(
|
||||
Dict[str, Any],
|
||||
True,
|
||||
"不允许使用 Any 类型注解",
|
||||
id="dict-validation-any-value-disallowed",
|
||||
),
|
||||
pytest.param(
|
||||
Dict[str, Dict[str, int]],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
id="dict-validation-optional-nested-list",
|
||||
),
|
||||
pytest.param(
|
||||
Dict,
|
||||
True,
|
||||
"必须指定键和值的类型参数",
|
||||
id="dict-validation-missing-args",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
||||
# 同样不通过字段定义来触发 model_post_init,只测试 _validate_dict_type 本身。
|
||||
|
||||
class Dummy(ConfigBase):
|
||||
_validate_any: bool = True
|
||||
|
||||
dummy = Dummy()
|
||||
field_name = "mapping"
|
||||
|
||||
if expect_error:
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
|
||||
# Act
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# _discourage_any_usage
|
||||
# ---------------------------------------------------------
|
||||
def test_discourage_any_usage_raises_when_validate_any_true(self, caplog):
|
||||
class Sample(ConfigBase):
|
||||
_validate_any: bool = True
|
||||
|
||||
instance = Sample()
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
instance._discourage_any_usage("field_x")
|
||||
assert "不允许使用 Any 类型注解" in str(exc_info.value)
|
||||
assert "建议避免使用" not in caplog.text
|
||||
|
||||
def test_discourage_any_usage_logs_when_validate_any_false(self, caplog):
|
||||
class Sample(ConfigBase):
|
||||
_validate_any: bool = False
|
||||
|
||||
instance = Sample()
|
||||
|
||||
# Arrange
|
||||
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
|
||||
|
||||
# Act
|
||||
instance._discourage_any_usage("field_y")
|
||||
|
||||
# Assert
|
||||
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# model_post_init 规则覆盖(错误与边界情况)
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"field_annotation, expect_error, error_fragment, test_id",
|
||||
[
|
||||
(
|
||||
Tuple[int, int],
|
||||
True,
|
||||
"不允许使用 Tuple 类型注解",
|
||||
"model-post-init-disallow-tuple-typing-tuple",
|
||||
),
|
||||
(
|
||||
tuple[int, int],
|
||||
True,
|
||||
"不允许使用 Tuple 类型注解",
|
||||
"model-post-init-disallow-pep604-tuple",
|
||||
),
|
||||
(
|
||||
Union[int, str],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
"model-post-init-disallow-union-field",
|
||||
),
|
||||
(
|
||||
list,
|
||||
True,
|
||||
"必须指定且仅指定一个类型参数",
|
||||
"model-post-init-list-missing-type-arg",
|
||||
),
|
||||
(
|
||||
List[List[int]],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
"model-post-init-list-nested-generic",
|
||||
),
|
||||
(
|
||||
Dict[str, Any],
|
||||
True,
|
||||
"不允许使用 Any 类型注解",
|
||||
"model-post-init-dict-value-any",
|
||||
),
|
||||
(
|
||||
Any,
|
||||
True,
|
||||
"不允许使用 Any 类型注解",
|
||||
"model-post-init-field-any-disallowed",
|
||||
),
|
||||
(
|
||||
Set[int],
|
||||
False,
|
||||
None,
|
||||
"model-post-init-allow-set-int",
|
||||
),
|
||||
(
|
||||
Dict[str, Optional[int]],
|
||||
False,
|
||||
None,
|
||||
"model-post-init-allow-dict-optional-int",
|
||||
),
|
||||
],
|
||||
ids=lambda v: v[3] if isinstance(v, tuple) else v,
|
||||
)
|
||||
def test_model_post_init_type_rules(self, field_annotation, expect_error, error_fragment, test_id):
|
||||
# Arrange
|
||||
attrs = {
|
||||
"__annotations__": {"f": field_annotation},
|
||||
"f": Field(default=None),
|
||||
}
|
||||
model_cls = type("DynamicModel" + test_id.replace("-", "_"), (ConfigBase,), attrs)
|
||||
|
||||
if expect_error:
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model_cls()
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
# Act
|
||||
instance = model_cls()
|
||||
|
||||
# Assert
|
||||
assert hasattr(instance, "f")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 嵌套 ConfigBase & 非支持泛型 origin
|
||||
# ---------------------------------------------------------
|
||||
def test_model_post_init_allows_configbase_nested_class(self):
|
||||
class Child(ConfigBase):
|
||||
value: int = 1
|
||||
|
||||
class Parent(ConfigBase):
|
||||
child: Child = Field(default_factory=Child)
|
||||
|
||||
# Act
|
||||
parent = Parent()
|
||||
|
||||
# Assert
|
||||
assert isinstance(parent.child, Child)
|
||||
|
||||
def test_model_post_init_disallow_non_supported_generic_origin(self):
|
||||
class CustomGeneric(BaseModel):
|
||||
pass
|
||||
|
||||
class Sample(ConfigBase):
|
||||
f: CustomGeneric = Field(default_factory=CustomGeneric)
|
||||
|
||||
# Arrange / Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
Sample()
|
||||
assert "仅允许使用list, set, dict三种泛型类型注解" in str(exc_info.value)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# super().model_post_init 和 AttrDocBase.__post_init__ 调用
|
||||
# ---------------------------------------------------------
|
||||
def test_super_model_post_init_and_attrdoc_post_init_called(self):
|
||||
class Sample(ConfigBase):
|
||||
value: int = 1
|
||||
|
||||
# Act
|
||||
instance = Sample()
|
||||
|
||||
# Assert
|
||||
assert getattr(instance, "__post_init_called__", False) is True
|
||||
|
|
|
|||
|
|
@ -179,6 +179,8 @@ class ConfigBase(BaseModel, AttrDocBase):
|
|||
self._validate_list_set_type(anno, field_name)
|
||||
elif origin_type is Any:
|
||||
self._discourage_any_usage(field_name)
|
||||
elif origin_type in (int, float, str, bool, complex, bytes):
|
||||
return
|
||||
else:
|
||||
raise TypeError(
|
||||
f"类'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
|
||||
|
|
@ -207,7 +209,7 @@ class ConfigBase(BaseModel, AttrDocBase):
|
|||
self._discourage_any_usage(field_name)
|
||||
|
||||
# 非泛型注解视为原子类型,允许
|
||||
if origin_type in (int, float, str, bool, complex, bytes, type(None), Any):
|
||||
if origin_type in (int, float, str, bool, complex, bytes, Any):
|
||||
continue
|
||||
# 允许嵌套的ConfigBase自定义类
|
||||
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore
|
||||
|
|
|
|||
Loading…
Reference in New Issue