MaiBot/pytests/config_test/test_config_base.py

535 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import sys
from importlib import util
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import pytest
from pydantic import BaseModel, Field
# -------------------------------------------------------------
# Test environment setup: provide the logger and AttrDocBase
# dependencies that ConfigBase needs.
# -------------------------------------------------------------
TEST_ROOT = Path(__file__).parent.parent.absolute().resolve()
logger_file = TEST_ROOT / "logger.py"
# Load the test-local logger stub under the project module name so that a later
# ``import src.common.logger`` (presumably done by src.config.config_base)
# resolves to this stub via sys.modules.
spec = util.spec_from_file_location("src.common.logger", logger_file)
# Validate the spec *before* using it: spec_from_file_location may return None,
# and module_from_spec(None) would fail with an unhelpful AttributeError.
assert spec is not None and spec.loader is not None
module = util.module_from_spec(spec)
# Register before executing, as in the importlib "import a source file
# directly" recipe, so the module is discoverable while its body runs.
sys.modules["src.common.logger"] = module
spec.loader.exec_module(module)
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
# Make both ``src.config...`` and bare ``config_base`` style imports resolvable.
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.config.config_base import ConfigBase # noqa: E402
import src.config.config_base as config_base_module # noqa: E402
class AttrDocBase:
    """Lightweight test double standing in for the project's AttrDocBase."""

    def __post_init__(self) -> None:
        # ConfigBase.model_post_init calls this hook; leave a marker on the
        # instance so tests can check that the call really happened.
        setattr(self, "__post_init_called__", True)
# Patch config_base so ConfigBase uses the test doubles.
@pytest.fixture(autouse=True)
def patch_attrdoc_post_init():
    """Route ConfigBase through the test doubles for the duration of each test.

    Replaces ``config_base_module.AttrDocBase.__post_init__`` with the local
    stand-in's ``__post_init__``. Points ``config_base_module.logger`` at a
    stdlib logger named ``config_base_test_logger``, which is the name the
    caplog-based tests filter on. Both originals are restored afterwards.
    """
    missing = object()
    orig_post_init = config_base_module.AttrDocBase.__post_init__
    orig_logger = getattr(config_base_module, "logger", missing)
    config_base_module.AttrDocBase.__post_init__ = AttrDocBase.__post_init__  # type: ignore
    # Install the test logger *before* the test body runs. It used to be
    # assigned only after ``yield`` (i.e. at teardown), so the first test in the
    # session logged through the real logger and caplog assertions depended on
    # test order.
    config_base_module.logger = logging.getLogger("config_base_test_logger")
    try:
        yield
    finally:
        config_base_module.AttrDocBase.__post_init__ = orig_post_init
        # Restore the logger too, so the patch does not leak into other modules.
        if orig_logger is missing:
            del config_base_module.logger
        else:
            config_base_module.logger = orig_logger
class SimpleClass(ConfigBase):
    """Small ConfigBase subclass used as the element/value type in the List/Set/Dict validation cases."""
    a: int = 1
    b: str = "test"
class TestConfigBase:
    """Tests for ConfigBase's type-annotation validation and post-init hooks."""

    # ---------------------------------------------------------
    # Happy path: end-to-end model_post_init
    # ---------------------------------------------------------
    @pytest.mark.parametrize(
        "model_cls, init_kwargs, expected_fields",
        [
            pytest.param(
                # Simple atomic-type fields
                type(
                    "SimpleAtomic",
                    (ConfigBase,),
                    {
                        "__annotations__": {
                            "a": int,
                            "b": str,
                            "c": bool,
                            "d": float,
                        },
                        "a": Field(default=1),
                        "b": Field(default="x"),
                        "c": Field(default=True),
                        "d": Field(default=1.5),
                    },
                ),
                {},
                {"a", "b", "c", "d"},
                id="happy-simple-atomic-fields",
            ),
            pytest.param(
                # list/set/dict generics with atomic element types
                type(
                    "AtomicContainers",
                    (ConfigBase,),
                    {
                        "__annotations__": {
                            "ints": List[int],
                            "names": Set[str],
                            "mapping": Dict[str, int],
                        },
                        "ints": Field(default_factory=lambda: [1, 2]),
                        "names": Field(default_factory=lambda: {"a", "b"}),
                        "mapping": Field(default_factory=lambda: {"x": 1}),
                    },
                ),
                {},
                {"ints", "names", "mapping"},
                id="happy-atomic-containers",
            ),
            pytest.param(
                # Optional atomic and Optional container
                type(
                    "OptionalFields",
                    (ConfigBase,),
                    {
                        "__annotations__": {
                            "maybe_int": Optional[int],
                            "maybe_str_list": Optional[List[str]],
                        },
                        "maybe_int": Field(default=None),
                        "maybe_str_list": Field(default=None),
                    },
                ),
                {},
                {"maybe_int", "maybe_str_list"},
                id="happy-optional-fields",
            ),
        ],
    )
    def test_model_post_init_happy_paths(self, model_cls, init_kwargs, expected_fields):
        """Valid models instantiate, expose every field and trigger the AttrDocBase hook."""
        # Act
        instance = model_cls(**init_kwargs)
        # Assert
        for field_name in expected_fields:
            assert field_name in type(instance).model_fields
            _ = getattr(instance, field_name)
        assert getattr(instance, "__post_init_called__", False) is True

    # ---------------------------------------------------------
    # _get_real_type
    # ---------------------------------------------------------
    def test_get_real_type_non_generic_and_generic(self):
        """_get_real_type returns (origin, args) for both plain and generic annotations."""

        class Sample(ConfigBase):
            x: int = 1
            y: List[int] = Field(default_factory=list)

        instance = Sample()
        # Act
        origin_x, args_x = instance._get_real_type(int)
        # Assert
        assert origin_x is int
        assert args_x == ()
        # Act
        origin_y, args_y = instance._get_real_type(List[int])
        # Assert
        assert origin_y in (list, List)
        assert args_y == (int,)

    # ---------------------------------------------------------
    # _validate_union_type
    # ---------------------------------------------------------
    @pytest.mark.parametrize(
        "annotation, expect_error, error_fragment, expected_origin_type",
        [
            pytest.param(
                int,
                False,
                None,
                int,
                id="union-validation-atomic-non-union",
            ),
            pytest.param(
                Optional[int],
                False,
                None,
                int,
                id="union-validation-optional-atomic",
            ),
            pytest.param(
                Optional[List[int]],
                False,
                None,
                list,
                id="union-validation-optional-container",
            ),
            pytest.param(
                Union[int, str],
                True,
                "不允许使用 Union 类型注解",
                None,
                id="union-validation-disallow-non-optional-union",
            ),
            pytest.param(
                int | str,
                True,
                "不允许使用 Union 类型注解",
                None,
                id="union-validation-pep604-disallow-non-optional-union",
            ),
            pytest.param(
                Union[int, None, str],
                True,
                "不允许使用 Union 类型注解",
                None,
                id="union-validation-disallow-union-more-than-two",
            ),
            pytest.param(
                Optional[Union[int, str]],
                True,
                "不允许使用 Union 类型注解",
                None,
                id="union-validation-disallow-nested-optional-union",
            ),
        ],
    )
    def test_validate_union_type(self, annotation, expect_error, error_fragment, expected_origin_type):
        """Optional[X] is accepted and unwrapped; any other Union (incl. PEP 604 ``|``) raises TypeError."""
        # The annotation under test is deliberately NOT declared as a field, so
        # __init__/model_post_init never validates it; the method is called on a
        # field-less dummy instance to exercise only the annotation logic.
        class Dummy(ConfigBase):
            pass

        dummy = Dummy()  # minimal instance: no fields to validate
        field_name = "v"
        if expect_error:
            # Act / Assert
            with pytest.raises(TypeError) as exc_info:
                dummy._validate_union_type(annotation, field_name)
            assert error_fragment in str(exc_info.value)
        else:
            # Act
            origin, _args, other = dummy._validate_union_type(annotation, field_name)
            # Assert
            assert origin is expected_origin_type
            assert other is not None

    # ---------------------------------------------------------
    # _validate_list_set_type
    # ---------------------------------------------------------
    @pytest.mark.parametrize(
        "annotation, expect_error, error_fragment",
        [
            pytest.param(
                List[int],
                False,
                None,
                id="listset-validation-list-happy",
            ),
            pytest.param(
                Set[str],
                False,
                None,
                id="listset-validation-set-happy",
            ),
            pytest.param(
                list,
                True,
                "必须指定且仅指定一个类型参数",
                id="listset-validation-missing-type-arg",
            ),
            pytest.param(
                List[int | None],
                True,
                "不允许嵌套泛型类型",
                id="listset-validation-nested-generic-inner-union",
            ),
            pytest.param(
                List[List[int]],
                True,
                "不允许嵌套泛型类型",
                id="listset-validation-nested-generic-inner-list",
            ),
            pytest.param(
                List[SimpleClass],
                False,
                None,
                id="listset-validation-list-configbase-element_allow",
            ),
            pytest.param(
                Set[SimpleClass],
                True,
                "ConfigBase is not Hashable",
                id="listset-validation-set-configbase-element_reject",
            ),
        ],
    )
    def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
        """list/set need exactly one non-generic element type; ConfigBase elements are allowed in lists, not sets."""
        # Use a field-less model so instantiation cannot fail in
        # __init__/model_post_init; only _validate_list_set_type is exercised.
        class Dummy(ConfigBase):
            pass

        dummy = Dummy()
        field_name = "items"
        if expect_error:
            # Act / Assert
            with pytest.raises(TypeError) as exc_info:
                dummy._validate_list_set_type(annotation, field_name)
            assert error_fragment in str(exc_info.value)
        else:
            # Act
            dummy._validate_list_set_type(annotation, field_name)

    # ---------------------------------------------------------
    # _validate_dict_type
    # ---------------------------------------------------------
    @pytest.mark.parametrize(
        "annotation, expect_error, error_fragment",
        [
            pytest.param(
                Dict[str, int],
                False,
                None,
                id="dict-validation-happy-atomic",
            ),
            pytest.param(
                Dict[str, Any],
                True,
                "不允许使用 Any 类型注解",
                id="dict-validation-any-value-disallowed",
            ),
            pytest.param(
                Dict[str, Dict[str, int]],
                True,
                "不允许嵌套泛型类型",
                id="dict-validation-nested-dict-value",
            ),
            pytest.param(
                Dict,
                True,
                "必须指定键和值的类型参数",
                id="dict-validation-missing-args",
            ),
            pytest.param(
                Dict[str, SimpleClass],
                False,
                None,
                id="dict-validation-happy-configbase-value",
            ),
        ],
    )
    def test_validate_dict_type(self, annotation, expect_error, error_fragment):
        """dict needs key/value type args; Any values and nested generic values raise TypeError."""
        # Likewise, no field is declared, so model_post_init is not involved and
        # only _validate_dict_type itself is exercised.
        class Dummy(ConfigBase):
            # True makes Any usage raise (see the _discourage_any_usage tests).
            _validate_any: bool = True

        dummy = Dummy()
        field_name = "mapping"
        if expect_error:
            # Act / Assert
            with pytest.raises(TypeError) as exc_info:
                dummy._validate_dict_type(annotation, field_name)
            assert error_fragment in str(exc_info.value)
        else:
            # Act
            dummy._validate_dict_type(annotation, field_name)

    # ---------------------------------------------------------
    # _discourage_any_usage
    # ---------------------------------------------------------
    def test_discourage_any_usage_raises_when_validate_any_true(self, caplog):
        """With _validate_any=True, Any raises TypeError instead of only warning."""

        class Sample(ConfigBase):
            _validate_any: bool = True

        instance = Sample()
        # Act / Assert
        with pytest.raises(TypeError) as exc_info:
            instance._discourage_any_usage("field_x")
        assert "不允许使用 Any 类型注解" in str(exc_info.value)
        assert "建议避免使用" not in caplog.text

    def test_discourage_any_usage_logs_when_validate_any_false(self, caplog):
        """With _validate_any=False, Any only produces a logged warning."""

        class Sample(ConfigBase):
            _validate_any: bool = False

        instance = Sample()
        # Arrange
        caplog.set_level(logging.WARNING, logger="config_base_test_logger")
        # Act
        instance._discourage_any_usage("field_y")
        # Assert
        assert "字段'field_y'中使用了 Any 类型注解" in caplog.text

    def test_discourage_any_usage_suppressed_warning(self, caplog):
        """suppress_any_warning=True silences the Any warning."""

        class Sample(ConfigBase):
            _validate_any: bool = False
            suppress_any_warning: bool = True

        instance = Sample()
        # Arrange
        caplog.set_level(logging.WARNING, logger="config_base_test_logger")
        # Act
        instance._discourage_any_usage("field_z")
        # Assert
        assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text

    # ---------------------------------------------------------
    # model_post_init rule coverage (errors and edge cases)
    # ---------------------------------------------------------
    @pytest.mark.parametrize(
        "field_annotation, expect_error, error_fragment, test_id",
        [
            # pytest calls an ``ids=`` callable once per individual argument
            # value (never with the whole tuple), so ids are attached explicitly
            # through pytest.param. The last element of each case is reused both
            # as the test id and as the dynamic model-class name suffix.
            pytest.param(*case, id=case[3])
            for case in (
                (
                    Tuple[int, int],
                    True,
                    "不允许使用 Tuple 类型注解",
                    "model-post-init-disallow-tuple-typing-tuple",
                ),
                (
                    tuple[int, int],
                    True,
                    "不允许使用 Tuple 类型注解",
                    "model-post-init-disallow-pep585-tuple",
                ),
                (
                    Union[int, str],
                    True,
                    "不允许使用 Union 类型注解",
                    "model-post-init-disallow-union-field",
                ),
                (
                    list,
                    True,
                    "必须指定且仅指定一个类型参数",
                    "model-post-init-list-missing-type-arg",
                ),
                (
                    List[List[int]],
                    True,
                    "不允许嵌套泛型类型",
                    "model-post-init-list-nested-generic",
                ),
                (
                    Dict[str, Any],
                    True,
                    "不允许使用 Any 类型注解",
                    "model-post-init-dict-value-any",
                ),
                (
                    Any,
                    True,
                    "不允许使用 Any 类型注解",
                    "model-post-init-field-any-disallowed",
                ),
                (
                    Set[int],
                    False,
                    None,
                    "model-post-init-allow-set-int",
                ),
                (
                    Dict[str, Optional[int]],
                    False,
                    None,
                    "model-post-init-allow-dict-optional-int",
                ),
            )
        ],
    )
    def test_model_post_init_type_rules(self, field_annotation, expect_error, error_fragment, test_id):
        """Declaring a single field ``f`` with the given annotation is accepted or rejected at instantiation."""
        # Arrange
        attrs = {
            "__annotations__": {"f": field_annotation},
            "f": Field(default=None),
        }
        model_cls = type("DynamicModel" + test_id.replace("-", "_"), (ConfigBase,), attrs)
        if expect_error:
            # Act / Assert
            with pytest.raises(TypeError) as exc_info:
                model_cls()
            assert error_fragment in str(exc_info.value)
        else:
            # Act
            instance = model_cls()
            # Assert
            assert hasattr(instance, "f")

    # ---------------------------------------------------------
    # Nested ConfigBase & unsupported generic origins
    # ---------------------------------------------------------
    def test_model_post_init_allows_configbase_nested_class(self):
        """A field typed as another ConfigBase subclass is allowed."""

        class Child(ConfigBase):
            value: int = 1

        class Parent(ConfigBase):
            child: Child = Field(default_factory=Child)

        # Act
        parent = Parent()
        # Assert
        assert isinstance(parent.child, Child)

    def test_model_post_init_disallow_non_supported_generic_origin(self):
        """A field typed as a plain (non-ConfigBase) BaseModel is rejected."""

        class CustomGeneric(BaseModel):
            pass

        class Sample(ConfigBase):
            f: CustomGeneric = Field(default_factory=CustomGeneric)

        # Arrange / Act / Assert
        with pytest.raises(TypeError) as exc_info:
            Sample()
        assert "仅允许使用list, set, dict三种泛型类型注解" in str(exc_info.value)

    # ---------------------------------------------------------
    # super().model_post_init and AttrDocBase.__post_init__ calls
    # ---------------------------------------------------------
    def test_super_model_post_init_and_attrdoc_post_init_called(self):
        """Instantiating a ConfigBase subclass runs the (patched) AttrDocBase hook."""

        class Sample(ConfigBase):
            value: int = 1

        # Act
        instance = Sample()
        # Assert
        assert getattr(instance, "__post_init_called__", False) is True