mirror of https://github.com/Mai-with-u/MaiBot.git
feat: 支持在配额异常时自动切换多枚API密钥
parent
5ee3d7ea43
commit
f4829f166d
|
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from .config_base import ConfigBase
|
||||
|
||||
|
|
@ -13,8 +14,11 @@ class APIProvider(ConfigBase):
|
|||
base_url: str
|
||||
"""API基础URL"""
|
||||
|
||||
api_key: str = field(default_factory=str, repr=False)
|
||||
"""API密钥列表"""
|
||||
api_key: str | List[str] = field(default_factory=str, repr=False)
|
||||
"""API密钥(兼容字符串或字符串列表)"""
|
||||
|
||||
api_keys: List[str] = field(default_factory=list, repr=False)
|
||||
"""API密钥优先级列表(可选,覆盖单个api_key设置)"""
|
||||
|
||||
client_type: str = field(default="openai")
|
||||
"""客户端类型(如openai/google等,默认为openai)"""
|
||||
|
|
@ -28,13 +32,69 @@ class APIProvider(ConfigBase):
|
|||
retry_interval: int = 10
|
||||
"""重试间隔(如果API调用失败,重试的间隔时间,单位:秒)"""
|
||||
|
||||
_ordered_keys: List[str] = field(init=False, repr=False, default_factory=list)
|
||||
_key_index: int = field(init=False, repr=False, default=0)
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.api_key
|
||||
"""返回当前生效的API Key"""
|
||||
return self._ordered_keys[self._key_index]
|
||||
|
||||
def rotate_api_key(self, exclude: Optional[Set[str]] = None) -> Optional[str]:
|
||||
"""切换到下一枚可用的API Key,返回新Key;若无可切换则返回None"""
|
||||
if len(self._ordered_keys) <= 1:
|
||||
return None
|
||||
|
||||
original_index = self._key_index
|
||||
key_count = len(self._ordered_keys)
|
||||
|
||||
for _ in range(1, key_count):
|
||||
self._key_index = (self._key_index + 1) % key_count
|
||||
candidate = self._ordered_keys[self._key_index]
|
||||
if exclude and candidate in exclude:
|
||||
continue
|
||||
self.api_key = candidate
|
||||
return candidate
|
||||
|
||||
# 无可用Key,回退到原位置
|
||||
self._key_index = original_index
|
||||
self.api_key = self._ordered_keys[self._key_index]
|
||||
return None
|
||||
|
||||
def __post_init__(self):
|
||||
"""确保api_key在repr中不被显示"""
|
||||
if not self.api_key:
|
||||
raw_keys: List[str] = []
|
||||
|
||||
def _collect_keys(value):
|
||||
if not value:
|
||||
return
|
||||
if isinstance(value, str):
|
||||
# 支持逗号或换行分隔
|
||||
parts = [item.strip() for item in value.replace("\n", ",").split(",") if item.strip()]
|
||||
raw_keys.extend(parts)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, str) and item.strip():
|
||||
raw_keys.append(item.strip())
|
||||
|
||||
_collect_keys(self.api_key)
|
||||
_collect_keys(self.api_keys)
|
||||
|
||||
if not raw_keys:
|
||||
raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
|
||||
|
||||
# 按顺序去重
|
||||
ordered_keys: List[str] = []
|
||||
seen: Set[str] = set()
|
||||
for key in raw_keys:
|
||||
if key not in seen:
|
||||
ordered_keys.append(key)
|
||||
seen.add(key)
|
||||
|
||||
self._ordered_keys = ordered_keys
|
||||
self._key_index = 0
|
||||
self.api_keys = ordered_keys
|
||||
self.api_key = ordered_keys[0]
|
||||
|
||||
if not self.base_url and self.client_type != "gemini":
|
||||
raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
|
||||
if not self.name:
|
||||
|
|
|
|||
|
|
@ -184,5 +184,9 @@ class ClientRegistry:
|
|||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
def invalidate_provider(self, provider_name: str) -> None:
|
||||
"""清理指定提供商的客户端缓存"""
|
||||
self.client_instance_cache.pop(provider_name, None)
|
||||
|
||||
|
||||
client_registry = ClientRegistry()
|
||||
|
|
|
|||
|
|
@ -241,6 +241,7 @@ class LLMRequest:
|
|||
"""
|
||||
retry_remain = api_provider.max_retry
|
||||
compressed_messages: Optional[List[Message]] = None
|
||||
tried_api_keys: Set[str] = {api_provider.get_api_key()}
|
||||
|
||||
while retry_remain > 0:
|
||||
try:
|
||||
|
|
@ -280,6 +281,20 @@ class LLMRequest:
|
|||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except RespNotOkException as e:
|
||||
# 针对鉴权/限流错误,尝试轮换API Key
|
||||
if e.status_code in {401, 403, 429}:
|
||||
rotated_key = api_provider.rotate_api_key(exclude=tried_api_keys)
|
||||
if rotated_key:
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 在提供商 '{api_provider.name}' 上触发 {e.status_code},已切换至新的API Key。"
|
||||
)
|
||||
tried_api_keys.add(rotated_key)
|
||||
client_registry.invalidate_provider(api_provider.name)
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=True)
|
||||
compressed_messages = None
|
||||
retry_remain = api_provider.max_retry
|
||||
continue
|
||||
|
||||
# 可重试的HTTP错误
|
||||
if e.status_code == 429 or e.status_code >= 500:
|
||||
retry_remain -= 1
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ version = "1.7.4"
|
|||
[[api_providers]] # API服务提供商(可以配置多个)
|
||||
name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名)
|
||||
base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL
|
||||
api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥)
|
||||
api_key = "your-api-key-here" # API密钥(字符串或以逗号/换行分隔的多个密钥)
|
||||
# api_keys = ["key-1", "key-2"] # 可选:使用数组配置多个密钥,系统会在限流/额度耗尽时自动切换
|
||||
client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini")
|
||||
max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
|
||||
timeout = 120 # API请求超时时间(单位:秒)
|
||||
|
|
|
|||
Loading…
Reference in New Issue