mirror of https://github.com/Mai-with-u/MaiBot.git
feat: 支持在配额异常时自动切换多枚API密钥
parent
5ee3d7ea43
commit
f4829f166d
|
|
@ -1,4 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional, Set
|
||||||
|
|
||||||
from .config_base import ConfigBase
|
from .config_base import ConfigBase
|
||||||
|
|
||||||
|
|
@ -13,8 +14,11 @@ class APIProvider(ConfigBase):
|
||||||
base_url: str
|
base_url: str
|
||||||
"""API基础URL"""
|
"""API基础URL"""
|
||||||
|
|
||||||
api_key: str = field(default_factory=str, repr=False)
|
api_key: str | List[str] = field(default_factory=str, repr=False)
|
||||||
"""API密钥列表"""
|
"""API密钥(兼容字符串或字符串列表)"""
|
||||||
|
|
||||||
|
api_keys: List[str] = field(default_factory=list, repr=False)
|
||||||
|
"""API密钥优先级列表(可选,覆盖单个api_key设置)"""
|
||||||
|
|
||||||
client_type: str = field(default="openai")
|
client_type: str = field(default="openai")
|
||||||
"""客户端类型(如openai/google等,默认为openai)"""
|
"""客户端类型(如openai/google等,默认为openai)"""
|
||||||
|
|
@ -28,13 +32,69 @@ class APIProvider(ConfigBase):
|
||||||
retry_interval: int = 10
|
retry_interval: int = 10
|
||||||
"""重试间隔(如果API调用失败,重试的间隔时间,单位:秒)"""
|
"""重试间隔(如果API调用失败,重试的间隔时间,单位:秒)"""
|
||||||
|
|
||||||
|
_ordered_keys: List[str] = field(init=False, repr=False, default_factory=list)
|
||||||
|
_key_index: int = field(init=False, repr=False, default=0)
|
||||||
|
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
return self.api_key
|
"""返回当前生效的API Key"""
|
||||||
|
return self._ordered_keys[self._key_index]
|
||||||
|
|
||||||
|
def rotate_api_key(self, exclude: Optional[Set[str]] = None) -> Optional[str]:
|
||||||
|
"""切换到下一枚可用的API Key,返回新Key;若无可切换则返回None"""
|
||||||
|
if len(self._ordered_keys) <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
original_index = self._key_index
|
||||||
|
key_count = len(self._ordered_keys)
|
||||||
|
|
||||||
|
for _ in range(1, key_count):
|
||||||
|
self._key_index = (self._key_index + 1) % key_count
|
||||||
|
candidate = self._ordered_keys[self._key_index]
|
||||||
|
if exclude and candidate in exclude:
|
||||||
|
continue
|
||||||
|
self.api_key = candidate
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# 无可用Key,回退到原位置
|
||||||
|
self._key_index = original_index
|
||||||
|
self.api_key = self._ordered_keys[self._key_index]
|
||||||
|
return None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""确保api_key在repr中不被显示"""
|
"""确保api_key在repr中不被显示"""
|
||||||
if not self.api_key:
|
raw_keys: List[str] = []
|
||||||
|
|
||||||
|
def _collect_keys(value):
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
|
if isinstance(value, str):
|
||||||
|
# 支持逗号或换行分隔
|
||||||
|
parts = [item.strip() for item in value.replace("\n", ",").split(",") if item.strip()]
|
||||||
|
raw_keys.extend(parts)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, str) and item.strip():
|
||||||
|
raw_keys.append(item.strip())
|
||||||
|
|
||||||
|
_collect_keys(self.api_key)
|
||||||
|
_collect_keys(self.api_keys)
|
||||||
|
|
||||||
|
if not raw_keys:
|
||||||
raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
|
raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
|
||||||
|
|
||||||
|
# 按顺序去重
|
||||||
|
ordered_keys: List[str] = []
|
||||||
|
seen: Set[str] = set()
|
||||||
|
for key in raw_keys:
|
||||||
|
if key not in seen:
|
||||||
|
ordered_keys.append(key)
|
||||||
|
seen.add(key)
|
||||||
|
|
||||||
|
self._ordered_keys = ordered_keys
|
||||||
|
self._key_index = 0
|
||||||
|
self.api_keys = ordered_keys
|
||||||
|
self.api_key = ordered_keys[0]
|
||||||
|
|
||||||
if not self.base_url and self.client_type != "gemini":
|
if not self.base_url and self.client_type != "gemini":
|
||||||
raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
|
raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
|
||||||
if not self.name:
|
if not self.name:
|
||||||
|
|
|
||||||
|
|
@ -184,5 +184,9 @@ class ClientRegistry:
|
||||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||||
return self.client_instance_cache[api_provider.name]
|
return self.client_instance_cache[api_provider.name]
|
||||||
|
|
||||||
|
def invalidate_provider(self, provider_name: str) -> None:
|
||||||
|
"""清理指定提供商的客户端缓存"""
|
||||||
|
self.client_instance_cache.pop(provider_name, None)
|
||||||
|
|
||||||
|
|
||||||
client_registry = ClientRegistry()
|
client_registry = ClientRegistry()
|
||||||
|
|
|
||||||
|
|
@ -241,6 +241,7 @@ class LLMRequest:
|
||||||
"""
|
"""
|
||||||
retry_remain = api_provider.max_retry
|
retry_remain = api_provider.max_retry
|
||||||
compressed_messages: Optional[List[Message]] = None
|
compressed_messages: Optional[List[Message]] = None
|
||||||
|
tried_api_keys: Set[str] = {api_provider.get_api_key()}
|
||||||
|
|
||||||
while retry_remain > 0:
|
while retry_remain > 0:
|
||||||
try:
|
try:
|
||||||
|
|
@ -280,6 +281,20 @@ class LLMRequest:
|
||||||
await asyncio.sleep(api_provider.retry_interval)
|
await asyncio.sleep(api_provider.retry_interval)
|
||||||
|
|
||||||
except RespNotOkException as e:
|
except RespNotOkException as e:
|
||||||
|
# 针对鉴权/限流错误,尝试轮换API Key
|
||||||
|
if e.status_code in {401, 403, 429}:
|
||||||
|
rotated_key = api_provider.rotate_api_key(exclude=tried_api_keys)
|
||||||
|
if rotated_key:
|
||||||
|
logger.warning(
|
||||||
|
f"模型 '{model_info.name}' 在提供商 '{api_provider.name}' 上触发 {e.status_code},已切换至新的API Key。"
|
||||||
|
)
|
||||||
|
tried_api_keys.add(rotated_key)
|
||||||
|
client_registry.invalidate_provider(api_provider.name)
|
||||||
|
client = client_registry.get_client_class_instance(api_provider, force_new=True)
|
||||||
|
compressed_messages = None
|
||||||
|
retry_remain = api_provider.max_retry
|
||||||
|
continue
|
||||||
|
|
||||||
# 可重试的HTTP错误
|
# 可重试的HTTP错误
|
||||||
if e.status_code == 429 or e.status_code >= 500:
|
if e.status_code == 429 or e.status_code >= 500:
|
||||||
retry_remain -= 1
|
retry_remain -= 1
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ version = "1.7.4"
|
||||||
[[api_providers]] # API服务提供商(可以配置多个)
|
[[api_providers]] # API服务提供商(可以配置多个)
|
||||||
name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名)
|
name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名)
|
||||||
base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL
|
base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL
|
||||||
api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥)
|
api_key = "your-api-key-here" # API密钥(字符串或以逗号/换行分隔的多个密钥)
|
||||||
|
# api_keys = ["key-1", "key-2"] # 可选:使用数组配置多个密钥,系统会在限流/额度耗尽时自动切换
|
||||||
client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini")
|
client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini")
|
||||||
max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
|
max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
|
||||||
timeout = 120 # API请求超时时间(单位:秒)
|
timeout = 120 # API请求超时时间(单位:秒)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue