diff --git a/src/config/model_configs.py b/src/config/model_configs.py index 374aef59..9665d9c6 100644 --- a/src/config/model_configs.py +++ b/src/config/model_configs.py @@ -5,25 +5,73 @@ from .config_base import ConfigBase, Field class APIProvider(ConfigBase): """API提供商配置类""" - name: str = "" + name: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "tag", + }, + ) """API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)""" - base_url: str = "" + base_url: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "link", + }, + ) """API服务商的BaseURL""" - api_key: str = Field(default_factory=str, repr=False) + api_key: str = Field( + default_factory=str, + repr=False, + json_schema_extra={ + "x-widget": "input", + "x-icon": "key", + }, + ) """API密钥""" - client_type: str = Field(default="openai") + client_type: str = Field( + default="openai", + json_schema_extra={ + "x-widget": "select", + "x-icon": "settings", + }, + ) """客户端类型 (可选: openai/google, 默认为openai)""" - max_retry: int = Field(default=2) + max_retry: int = Field( + default=2, + ge=0, + json_schema_extra={ + "x-widget": "input", + "x-icon": "repeat", + }, + ) """最大重试次数 (单个模型API调用失败, 最多重试的次数)""" - timeout: int = 10 + timeout: int = Field( + default=10, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "clock", + "step": 1, + }, + ) """API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)""" - retry_interval: int = 10 + retry_interval: int = Field( + default=10, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "timer", + "step": 1, + }, + ) """重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)""" def model_post_init(self, context: Any = None): @@ -39,34 +87,93 @@ class APIProvider(ConfigBase): class ModelInfo(ConfigBase): """单个模型信息配置类""" + _validate_any: bool = False suppress_any_warning: bool = True - model_identifier: str = "" + model_identifier: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "package", + }, + ) """模型标识符 (API服务商提供的模型标识符)""" - name: str = "" + name: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "tag", + }, + ) """模型名称 (可随意命名, 在models中需使用这个命名)""" - api_provider: str = "" + api_provider: str = Field( + default="", + json_schema_extra={ + "x-widget": "select", + "x-icon": "link", + }, + ) """API服务商名称 (对应在api_providers中配置的服务商名称)""" - price_in: float = Field(default=0.0) + price_in: float = Field( + default=0.0, + ge=0, + json_schema_extra={ + "x-widget": "input", + "x-icon": "dollar-sign", + "step": 0.001, + }, + ) """输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" - price_out: float = Field(default=0.0) + price_out: float = Field( + default=0.0, + ge=0, + json_schema_extra={ + "x-widget": "input", + "x-icon": "dollar-sign", + "step": 0.001, + }, + ) """输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" - - temperature: float | None = Field(default=None) + + temperature: float | None = Field( + default=None, + json_schema_extra={ + "x-widget": "input", + "x-icon": "thermometer", + }, + ) """模型级别温度(可选),会覆盖任务配置中的温度""" - max_tokens: int | None = Field(default=None) + max_tokens: int | None = Field( + default=None, + json_schema_extra={ + "x-widget": "input", + "x-icon": "layers", + }, + ) """模型级别最大token数(可选),会覆盖任务配置中的max_tokens""" - force_stream_mode: bool = Field(default=False) + force_stream_mode: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "zap", + }, + ) """强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)""" - extra_params: dict[str, Any] = Field(default_factory=dict) + extra_params: dict[str, Any] = Field( + default_factory=dict, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "sliders", + }, + ) """额外参数 (用于API调用时的额外配置)""" def model_post_init(self, context: Any = None): @@ -82,48 +189,139 @@ class ModelInfo(ConfigBase): class TaskConfig(ConfigBase): """任务配置类""" - model_list: list[str] = Field(default_factory=list) + model_list: list[str] = Field( + default_factory=list, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "list", + }, + ) """使用的模型列表, 每个元素对应上面的模型名称(name)""" - max_tokens: int = 1024 + max_tokens: int = Field( + default=1024, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "layers", + "step": 1, + }, + ) """任务最大输出token数""" - temperature: float = 0.3 + temperature: float = Field( + default=0.3, + ge=0, + le=2, + json_schema_extra={ + "x-widget": "slider", + "x-icon": "thermometer", + "step": 0.1, + }, + ) """模型温度""" - - slow_threshold: float = 15.0 + + slow_threshold: float = Field( + default=15.0, + ge=0, + json_schema_extra={ + "x-widget": "input", + "x-icon": "alert-circle", + "step": 0.1, + }, + ) """慢请求阈值(秒),超过此值会输出警告日志""" - selection_strategy: str = Field(default="balance") + selection_strategy: str = Field( + default="balance", + json_schema_extra={ + "x-widget": "select", + "x-icon": "shuffle", + }, + ) """模型选择策略:balance(负载均衡)或 random(随机选择)""" class ModelTaskConfig(ConfigBase): """模型配置类""" - utils: TaskConfig = Field(default_factory=TaskConfig) + utils: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "wrench", + }, + ) """组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型""" - replyer: TaskConfig = Field(default_factory=TaskConfig) + replyer: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "message-square", + }, + ) """首要回复模型配置, 还用于表达器和表达方式学习""" - vlm: TaskConfig = Field(default_factory=TaskConfig) + vlm: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "image", + }, + ) """视觉模型配置""" - voice: TaskConfig = Field(default_factory=TaskConfig) + voice: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "volume-2", + }, + ) """语音识别模型配置""" - tool_use: TaskConfig = Field(default_factory=TaskConfig) + tool_use: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "tools", + }, + ) """工具使用模型配置, 需要使用支持工具调用的模型""" - planner: TaskConfig = Field(default_factory=TaskConfig) + planner: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "map", + }, + ) """规划模型配置""" - embedding: TaskConfig = Field(default_factory=TaskConfig) + embedding: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "database", + }, + ) """嵌入模型配置""" - lpmm_entity_extract: TaskConfig = Field(default_factory=TaskConfig) + lpmm_entity_extract: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "filter", + }, + ) """LPMM实体提取模型配置""" - lpmm_rdf_build: TaskConfig = Field(default_factory=TaskConfig) + lpmm_rdf_build: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "network", + }, + ) """LPMM RDF构建模型配置"""