From e1c03608c2e3bdba79aa301e7b73bc8808b46b3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 24 Dec 2025 21:57:35 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E6=94=AF=E6=8C=81Union/Optional?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E7=9A=84=E5=AD=97=E6=AE=B5=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=BC=BA=E9=85=8D=E7=BD=AE=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E7=9A=84=E7=81=B5=E6=B4=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/config_base.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/config/config_base.py b/src/config/config_base.py index 5fb39819..715ee944 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, fields, MISSING -from typing import TypeVar, Type, Any, get_origin, get_args, Literal +from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Union +import types T = TypeVar("T", bound="ConfigBase") @@ -108,6 +109,30 @@ class ConfigBase: return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} + # 处理 Union/Optional 类型(包括 float | None 这种 Python 3.10+ 语法) + # 注意: + # - Optional[float] 等价于 Union[float, None],get_origin() 返回 typing.Union + # - float | None 是 types.UnionType,get_origin() 返回 None + is_union_type = ( + field_origin_type is Union # typing.Optional / typing.Union + or isinstance(field_type, types.UnionType) # Python 3.10+ 的 | 语法 + ) + + if is_union_type: + union_args = field_type_args if field_type_args else get_args(field_type) + # 如果值是 None 且 None 在 Union 中,直接返回 + if value is None and type(None) in union_args: + return None + # 尝试转换为非 None 的类型 + for arg in union_args: + if arg is not type(None): + try: + return cls._convert_field(value, arg) + except (ValueError, TypeError): + continue + # 如果所有类型都转换失败,抛出异常 + raise TypeError("Cannot convert value to any type in Union") + # 处理基础类型,例如 int, str 等 if field_origin_type is type(None) and value is None: # 处理Optional类型 return None