From f4829f166d0efb837e260c871822dd7131035d7c Mon Sep 17 00:00:00 2001 From: magisk317 Date: Thu, 23 Oct 2025 17:37:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=9C=A8=E9=85=8D?= =?UTF-8?q?=E9=A2=9D=E5=BC=82=E5=B8=B8=E6=97=B6=E8=87=AA=E5=8A=A8=E5=88=87?= =?UTF-8?q?=E6=8D=A2=E5=A4=9A=E6=9E=9AAPI=E5=AF=86=E9=92=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/api_ada_configs.py | 68 ++++++++++++++++++++-- src/llm_models/model_client/base_client.py | 4 ++ src/llm_models/utils_model.py | 15 +++++ template/model_config_template.toml | 3 +- 4 files changed, 85 insertions(+), 5 deletions(-) diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 3fc9c878..15c58e9b 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import List, Optional, Set from .config_base import ConfigBase @@ -13,8 +14,11 @@ class APIProvider(ConfigBase): base_url: str """API基础URL""" - api_key: str = field(default_factory=str, repr=False) - """API密钥列表""" + api_key: str | List[str] = field(default_factory=str, repr=False) + """API密钥(兼容字符串或字符串列表)""" + + api_keys: List[str] = field(default_factory=list, repr=False) + """API密钥优先级列表(可选,覆盖单个api_key设置)""" client_type: str = field(default="openai") """客户端类型(如openai/google等,默认为openai)""" @@ -28,13 +32,69 @@ class APIProvider(ConfigBase): retry_interval: int = 10 """重试间隔(如果API调用失败,重试的间隔时间,单位:秒)""" + _ordered_keys: List[str] = field(init=False, repr=False, default_factory=list) + _key_index: int = field(init=False, repr=False, default=0) + def get_api_key(self) -> str: - return self.api_key + """返回当前生效的API Key""" + return self._ordered_keys[self._key_index] + + def rotate_api_key(self, exclude: Optional[Set[str]] = None) -> Optional[str]: + """切换到下一枚可用的API Key,返回新Key;若无可切换则返回None""" + if len(self._ordered_keys) <= 1: + return None + + original_index = self._key_index + key_count = len(self._ordered_keys) + + for _ in range(1, key_count): + self._key_index = (self._key_index + 1) % key_count + candidate = self._ordered_keys[self._key_index] + if exclude and candidate in exclude: + continue + self.api_key = candidate + return candidate + + # 无可用Key,回退到原位置 + self._key_index = original_index + self.api_key = self._ordered_keys[self._key_index] + return None def __post_init__(self): """确保api_key在repr中不被显示""" - if not self.api_key: + raw_keys: List[str] = [] + + def _collect_keys(value): + if not value: + return + if isinstance(value, str): + # 支持逗号或换行分隔 + parts = [item.strip() for item in value.replace("\n", ",").split(",") if item.strip()] + raw_keys.extend(parts) + elif isinstance(value, list): + for item in value: + if isinstance(item, str) and item.strip(): + raw_keys.append(item.strip()) + + _collect_keys(self.api_key) + _collect_keys(self.api_keys) + + if not raw_keys: raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。") + + # 按顺序去重 + ordered_keys: List[str] = [] + seen: Set[str] = set() + for key in raw_keys: + if key not in seen: + ordered_keys.append(key) + seen.add(key) + + self._ordered_keys = ordered_keys + self._key_index = 0 + self.api_keys = ordered_keys + self.api_key = ordered_keys[0] + if not self.base_url and self.client_type != "gemini": raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。") if not self.name: diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index dcb710fe..9c07ac8a 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -184,5 +184,9 @@ class ClientRegistry: raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") return self.client_instance_cache[api_provider.name] + def invalidate_provider(self, provider_name: str) -> None: + """清理指定提供商的客户端缓存""" + self.client_instance_cache.pop(provider_name, None) + client_registry = ClientRegistry() diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 4d7865d9..b6118090 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -241,6 +241,7 @@ class LLMRequest: """ retry_remain = api_provider.max_retry compressed_messages: Optional[List[Message]] = None + tried_api_keys: Set[str] = {api_provider.get_api_key()} while retry_remain > 0: try: @@ -280,6 +281,20 @@ class LLMRequest: await asyncio.sleep(api_provider.retry_interval) except RespNotOkException as e: + # 针对鉴权/限流错误,尝试轮换API Key + if e.status_code in {401, 403, 429}: + rotated_key = api_provider.rotate_api_key(exclude=tried_api_keys) + if rotated_key: + logger.warning( + f"模型 '{model_info.name}' 在提供商 '{api_provider.name}' 上触发 {e.status_code},已切换至新的API Key。" + ) + tried_api_keys.add(rotated_key) + client_registry.invalidate_provider(api_provider.name) + client = client_registry.get_client_class_instance(api_provider, force_new=True) + compressed_messages = None + retry_remain = api_provider.max_retry + continue + # 可重试的HTTP错误 if e.status_code == 429 or e.status_code >= 500: retry_remain -= 1 diff --git a/template/model_config_template.toml b/template/model_config_template.toml index c97a6579..925e378e 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -6,7 +6,8 @@ version = "1.7.4" [[api_providers]] # API服务提供商(可以配置多个) name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名) base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL -api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥) +api_key = "your-api-key-here" # API密钥(字符串或以逗号/换行分隔的多个密钥) +# api_keys = ["key-1", "key-2"] # 可选:使用数组配置多个密钥,系统会在限流/额度耗尽时自动切换 client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini") max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) timeout = 120 # API请求超时时间(单位:秒)