使用Sourcery的Test,修复测出来的Bug

r-dev
UnCLAS-Prommer 2026-01-12 21:48:54 +08:00
parent afb993e481
commit 13f095f231
No known key found for this signature in database
2 changed files with 461 additions and 68 deletions

View File

@ -1,105 +1,496 @@
# 本文件为测试文件请忽略Lint error内含大量的ignore标识
from typing import Any, Optional, Union, List
from pathlib import Path
from importlib import util
import logging
import sys
from importlib import util
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import pytest
from pydantic import BaseModel, Field
# -------------------------------------------------------------
# 测试环境准备:补全 logger 和 AttrDocBase 依赖
# -------------------------------------------------------------
TEST_ROOT = Path(__file__).parent.parent.absolute().resolve()
logger_file = TEST_ROOT / "logger.py"
spec = util.spec_from_file_location("src.common.logger", logger_file)
module = util.module_from_spec(spec) # type: ignore
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module) # type: ignore
sys.modules["src.common.logger"] = module
# 测试对象导入
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.config.config_base import ConfigBase, Field # noqa: E402
from src.config.config_base import ConfigBase # noqa: E402
import src.config.config_base as config_base_module # noqa: E402
class IllegalConfig_Dict(ConfigBase):
a: dict = Field(default_factory=dict)
class AttrDocBase:
"""用于测试的轻量级 AttrDocBase 替身"""
def __post_init__(self) -> None:
# 被 ConfigBase.model_post_init 调用
self.__post_init_called__ = True
class IllegalConfig_List(ConfigBase):
b: list = Field(default_factory=list)
# 打补丁,让 ConfigBase 使用测试替身
@pytest.fixture(autouse=True)
def patch_attrdoc_post_init():
orig = config_base_module.AttrDocBase.__post_init__
config_base_module.AttrDocBase.__post_init__ = AttrDocBase.__post_init__ # type: ignore
yield
config_base_module.AttrDocBase.__post_init__ = orig
class IllegalConfig_Set(ConfigBase):
c: set = Field(default_factory=set)
config_base_module.logger = logging.getLogger("config_base_test_logger")
class IllegalConfig_Tuple(ConfigBase):
d: tuple = Field(default_factory=tuple)
class TestConfigBase:
# ---------------------------------------------------------
# happy path整体 model_post_init 测试
# ---------------------------------------------------------
@pytest.mark.parametrize(
"model_cls, init_kwargs, expected_fields",
[
pytest.param(
# 简单原子类型字段
type(
"SimpleAtomic",
(ConfigBase,),
{
"__annotations__": {
"a": int,
"b": str,
"c": bool,
"d": float,
},
"a": Field(default=1),
"b": Field(default="x"),
"c": Field(default=True),
"d": Field(default=1.5),
},
),
{},
{"a", "b", "c", "d"},
id="happy-simple-atomic-fields",
),
pytest.param(
# list/set/dict 泛型 + 原子内部类型
type(
"AtomicContainers",
(ConfigBase,),
{
"__annotations__": {
"ints": List[int],
"names": Set[str],
"mapping": Dict[str, int],
},
"ints": Field(default_factory=lambda: [1, 2]),
"names": Field(default_factory=lambda: {"a", "b"}),
"mapping": Field(default_factory=lambda: {"x": 1}),
},
),
{},
{"ints", "names", "mapping"},
id="happy-atomic-containers",
),
pytest.param(
# Optional 原子和 Optional 容器
type(
"OptionalFields",
(ConfigBase,),
{
"__annotations__": {
"maybe_int": Optional[int],
"maybe_str_list": Optional[List[str]],
},
"maybe_int": Field(default=None),
"maybe_str_list": Field(default=None),
},
),
{},
{"maybe_int", "maybe_str_list"},
id="happy-optional-fields",
),
],
)
def test_model_post_init_happy_paths(self, model_cls, init_kwargs, expected_fields):
# Act
instance = model_cls(**init_kwargs)
# Assert
for field_name in expected_fields:
assert field_name in type(instance).model_fields
_ = getattr(instance, field_name)
assert getattr(instance, "__post_init_called__", False) is True
class IllegalConfig_Union(ConfigBase):
e: Union[int, str] = Field(default_factory=str)
# ---------------------------------------------------------
# _get_real_type
# ---------------------------------------------------------
def test_get_real_type_non_generic_and_generic(self):
class Sample(ConfigBase):
x: int = 1
y: List[int] = Field(default_factory=list)
instance = Sample()
class IllegalConfig_Any(ConfigBase):
f: Any = Field(default_factory=dict)
# Act
origin_x, args_x = instance._get_real_type(int)
# Assert
assert origin_x is int
assert args_x == ()
class IllegalConfig_NestedGeneric(ConfigBase):
g: list[List[int]] = Field(default_factory=list)
# Act
origin_y, args_y = instance._get_real_type(List[int])
# Assert
assert origin_y in (list, List)
assert args_y == (int,)
class IllegalConfig_Any_suppress(ConfigBase):
f: Any = Field(default_factory=dict)
_validate_any: bool = False
# ---------------------------------------------------------
# _validate_union_type
# ---------------------------------------------------------
@pytest.mark.parametrize(
"annotation, expect_error, error_fragment, expected_origin_type",
[
pytest.param(
int,
False,
None,
int,
id="union-validation-atomic-non-union",
),
pytest.param(
Optional[int],
False,
None,
int,
id="union-validation-optional-atomic",
),
pytest.param(
Optional[List[int]],
False,
None,
list,
id="union-validation-optional-container",
),
pytest.param(
Union[int, str],
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-disallow-non-optional-union",
),
pytest.param(
int | str,
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-pep604-disallow-non-optional-union",
),
pytest.param(
Union[int, None, str],
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-disallow-union-more-than-two",
),
pytest.param(
Optional[Union[int, str]],
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-disallow-nested-optional-union",
),
],
)
def test_validate_union_type(self, annotation, expect_error, error_fragment, expected_origin_type):
# 这里我们不实例化 Sample以避免在 __init__/model_post_init 阶段触发验证。
# 直接通过一个“哑实例”调用受测方法,仅测试类型注解逻辑。
class Dummy(ConfigBase):
pass
class SubClass(ConfigBase):
x: Optional[int] = Field(default=None)
y: list[int] = [123]
dummy = Dummy() # 最小初始化,避免字段校验
field_name = "v"
class LegalConfig(ConfigBase):
a: dict[str, list[int]] = Field(default_factory=dict)
b: list[int] = Field(default_factory=list)
c: set[str] = Field(default_factory=set)
d: Optional[str] = Field(default=None)
e: SubClass = Field(default_factory=SubClass)
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_union_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
origin, args, other = dummy._validate_union_type(annotation, field_name)
# Assert
assert origin is expected_origin_type
assert other is not None
@pytest.mark.parametrize(
"config_class, expected_exception, expected_message",
[
(IllegalConfig_Dict, TypeError, "必须指定键和值的类型参数"),
(IllegalConfig_List, TypeError, "必须指定且仅指定一个类型参数"),
(IllegalConfig_Set, TypeError, "必须指定且仅指定一个类型参数"),
(IllegalConfig_Tuple, TypeError, "不允许使用 Tuple 类型注解"),
(IllegalConfig_Union, TypeError, "不允许使用 Union 类型注解"),
(IllegalConfig_Any, TypeError, "不允许使用 Any 类型注解"),
(IllegalConfig_NestedGeneric, TypeError, "不允许嵌套泛型类型"),
(IllegalConfig_Any_suppress, None, ""),
],
)
def test_illegal_config(config_class, expected_exception, expected_message):
# sourcery skip: no-conditionals-in-tests
if expected_exception:
with pytest.raises(expected_exception) as exc_info:
config_class()
assert expected_message in str(exc_info.value)
assert expected_exception == exc_info.type
else:
config_instance = config_class()
assert isinstance(config_instance, config_class)
# ---------------------------------------------------------
# _validate_list_set_type
# ---------------------------------------------------------
@pytest.mark.parametrize(
"annotation, expect_error, error_fragment",
[
pytest.param(
List[int],
False,
None,
id="listset-validation-list-happy",
),
pytest.param(
Set[str],
False,
None,
id="listset-validation-set-happy",
),
pytest.param(
list,
True,
"必须指定且仅指定一个类型参数",
id="listset-validation-missing-type-arg",
),
pytest.param(
List[int | None],
True,
"不允许嵌套泛型类型",
id="listset-validation-nested-generic-inner-union",
),
pytest.param(
List[List[int]],
True,
"不允许嵌套泛型类型",
id="listset-validation-nested-generic-inner-list",
),
],
)
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
# 不实例化带有这些字段的模型,避免在 __init__/model_post_init 阶段就失败,
# 只测试 _validate_list_set_type 本身的逻辑。
class Dummy(ConfigBase):
pass
def test_legal_config():
config_instance = LegalConfig()
assert isinstance(config_instance, LegalConfig)
assert isinstance(config_instance.a, dict)
assert isinstance(config_instance.b, list)
assert isinstance(config_instance.c, set)
assert config_instance.d is None
assert isinstance(config_instance.e, SubClass)
assert config_instance.e.x is None
assert isinstance(config_instance.e.y, list)
assert config_instance.e.y == [123]
dummy = Dummy()
field_name = "items"
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_list_set_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
dummy._validate_list_set_type(annotation, field_name)
# ---------------------------------------------------------
# _validate_dict_type
# ---------------------------------------------------------
@pytest.mark.parametrize(
"annotation, expect_error, error_fragment",
[
pytest.param(
Dict[str, int],
False,
None,
id="dict-validation-happy-atomic",
),
pytest.param(
Dict[str, Any],
True,
"不允许使用 Any 类型注解",
id="dict-validation-any-value-disallowed",
),
pytest.param(
Dict[str, Dict[str, int]],
True,
"不允许嵌套泛型类型",
id="dict-validation-optional-nested-list",
),
pytest.param(
Dict,
True,
"必须指定键和值的类型参数",
id="dict-validation-missing-args",
),
],
)
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
# 同样不通过字段定义来触发 model_post_init只测试 _validate_dict_type 本身。
class Dummy(ConfigBase):
_validate_any: bool = True
dummy = Dummy()
field_name = "mapping"
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_dict_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
dummy._validate_dict_type(annotation, field_name)
# ---------------------------------------------------------
# _discourage_any_usage
# ---------------------------------------------------------
def test_discourage_any_usage_raises_when_validate_any_true(self, caplog):
class Sample(ConfigBase):
_validate_any: bool = True
instance = Sample()
# Act / Assert
with pytest.raises(TypeError) as exc_info:
instance._discourage_any_usage("field_x")
assert "不允许使用 Any 类型注解" in str(exc_info.value)
assert "建议避免使用" not in caplog.text
def test_discourage_any_usage_logs_when_validate_any_false(self, caplog):
class Sample(ConfigBase):
_validate_any: bool = False
instance = Sample()
# Arrange
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
# Act
instance._discourage_any_usage("field_y")
# Assert
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
# ---------------------------------------------------------
# model_post_init 规则覆盖(错误与边界情况)
# ---------------------------------------------------------
@pytest.mark.parametrize(
"field_annotation, expect_error, error_fragment, test_id",
[
(
Tuple[int, int],
True,
"不允许使用 Tuple 类型注解",
"model-post-init-disallow-tuple-typing-tuple",
),
(
tuple[int, int],
True,
"不允许使用 Tuple 类型注解",
"model-post-init-disallow-pep604-tuple",
),
(
Union[int, str],
True,
"不允许使用 Union 类型注解",
"model-post-init-disallow-union-field",
),
(
list,
True,
"必须指定且仅指定一个类型参数",
"model-post-init-list-missing-type-arg",
),
(
List[List[int]],
True,
"不允许嵌套泛型类型",
"model-post-init-list-nested-generic",
),
(
Dict[str, Any],
True,
"不允许使用 Any 类型注解",
"model-post-init-dict-value-any",
),
(
Any,
True,
"不允许使用 Any 类型注解",
"model-post-init-field-any-disallowed",
),
(
Set[int],
False,
None,
"model-post-init-allow-set-int",
),
(
Dict[str, Optional[int]],
False,
None,
"model-post-init-allow-dict-optional-int",
),
],
ids=lambda v: v[3] if isinstance(v, tuple) else v,
)
def test_model_post_init_type_rules(self, field_annotation, expect_error, error_fragment, test_id):
# Arrange
attrs = {
"__annotations__": {"f": field_annotation},
"f": Field(default=None),
}
model_cls = type("DynamicModel" + test_id.replace("-", "_"), (ConfigBase,), attrs)
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
model_cls()
assert error_fragment in str(exc_info.value)
else:
# Act
instance = model_cls()
# Assert
assert hasattr(instance, "f")
# ---------------------------------------------------------
# 嵌套 ConfigBase & 非支持泛型 origin
# ---------------------------------------------------------
def test_model_post_init_allows_configbase_nested_class(self):
class Child(ConfigBase):
value: int = 1
class Parent(ConfigBase):
child: Child = Field(default_factory=Child)
# Act
parent = Parent()
# Assert
assert isinstance(parent.child, Child)
def test_model_post_init_disallow_non_supported_generic_origin(self):
class CustomGeneric(BaseModel):
pass
class Sample(ConfigBase):
f: CustomGeneric = Field(default_factory=CustomGeneric)
# Arrange / Act / Assert
with pytest.raises(TypeError) as exc_info:
Sample()
assert "仅允许使用list, set, dict三种泛型类型注解" in str(exc_info.value)
# ---------------------------------------------------------
# super().model_post_init 和 AttrDocBase.__post_init__ 调用
# ---------------------------------------------------------
def test_super_model_post_init_and_attrdoc_post_init_called(self):
class Sample(ConfigBase):
value: int = 1
# Act
instance = Sample()
# Assert
assert getattr(instance, "__post_init_called__", False) is True

View File

@ -179,6 +179,8 @@ class ConfigBase(BaseModel, AttrDocBase):
self._validate_list_set_type(anno, field_name)
elif origin_type is Any:
self._discourage_any_usage(field_name)
elif origin_type in (int, float, str, bool, complex, bytes):
return
else:
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
@ -207,7 +209,7 @@ class ConfigBase(BaseModel, AttrDocBase):
self._discourage_any_usage(field_name)
# 非泛型注解视为原子类型,允许
if origin_type in (int, float, str, bool, complex, bytes, type(None), Any):
if origin_type in (int, float, str, bool, complex, bytes, Any):
continue
# 允许嵌套的ConfigBase自定义类
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore