mirror of https://github.com/Mai-with-u/MaiBot.git
feat: 添加提供商连接测试接口,支持通过 URL 和名称验证连接状态
parent
3d1f26ae1b
commit
e06a35fe81
|
|
@ -242,3 +242,126 @@ async def get_models_by_url(
|
||||||
"models": models,
|
"models": models,
|
||||||
"count": len(models),
|
"count": len(models),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/test-connection")
|
||||||
|
async def test_provider_connection(
|
||||||
|
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||||
|
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试提供商连接状态
|
||||||
|
|
||||||
|
分两步测试:
|
||||||
|
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
|
||||||
|
2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- network_ok: 网络是否连通
|
||||||
|
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
|
||||||
|
- latency_ms: 响应延迟(毫秒)
|
||||||
|
- error: 错误信息(如果有)
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
base_url = _normalize_url(base_url)
|
||||||
|
if not base_url:
|
||||||
|
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"network_ok": False,
|
||||||
|
"api_key_valid": None,
|
||||||
|
"latency_ms": None,
|
||||||
|
"error": None,
|
||||||
|
"http_status": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 第一步:测试网络连通性
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
|
||||||
|
# 尝试 GET 请求 base_url(不需要 API Key)
|
||||||
|
response = await client.get(base_url)
|
||||||
|
latency = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
result["network_ok"] = True
|
||||||
|
result["latency_ms"] = round(latency, 2)
|
||||||
|
result["http_status"] = response.status_code
|
||||||
|
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
|
||||||
|
return result
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
result["error"] = "连接超时:服务器响应时间过长"
|
||||||
|
return result
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
result["error"] = f"请求错误:{str(e)}"
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = f"未知错误:{str(e)}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 第二步:如果提供了 API Key,验证其有效性
|
||||||
|
if api_key:
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
# 尝试获取模型列表
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result["api_key_valid"] = True
|
||||||
|
elif response.status_code in (401, 403):
|
||||||
|
result["api_key_valid"] = False
|
||||||
|
result["error"] = "API Key 无效或已过期"
|
||||||
|
else:
|
||||||
|
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
|
||||||
|
result["api_key_valid"] = None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# API Key 验证失败不影响网络连通性结果
|
||||||
|
logger.warning(f"API Key 验证失败: {e}")
|
||||||
|
result["api_key_valid"] = None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-connection-by-name")
|
||||||
|
async def test_provider_connection_by_name(
|
||||||
|
provider_name: str = Query(..., description="提供商名称"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
通过提供商名称测试连接(从配置文件读取信息)
|
||||||
|
"""
|
||||||
|
# 读取配置文件
|
||||||
|
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
|
if not os.path.exists(model_config_path):
|
||||||
|
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||||
|
|
||||||
|
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = tomlkit.load(f)
|
||||||
|
|
||||||
|
# 查找提供商
|
||||||
|
providers = config.get("api_providers", [])
|
||||||
|
provider = None
|
||||||
|
for p in providers:
|
||||||
|
if p.get("name") == provider_name:
|
||||||
|
provider = p
|
||||||
|
break
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||||
|
|
||||||
|
base_url = provider.get("base_url", "")
|
||||||
|
api_key = provider.get("api_key", "")
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||||
|
|
||||||
|
# 调用测试接口
|
||||||
|
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue