feat: Add an API route for fetching model lists

pull/1385/head
墨梓柒 2025-11-26 22:42:46 +08:00
parent 74d563ad38
commit 0a0fb8ee55
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
2 changed files with 245 additions and 0 deletions


@@ -0,0 +1,242 @@
"""
模型列表获取API路由
提供从各个 AI 厂商 API 获取可用模型列表的代理接口
"""

import os
import httpx
from fastapi import APIRouter, HTTPException, Query
from typing import Optional
import tomlkit

from src.common.logger import get_logger
from src.config.config import CONFIG_DIR

logger = get_logger("webui")

router = APIRouter(prefix="/models", tags=["models"])

# Model fetcher configuration
MODEL_FETCHER_CONFIG = {
    # OpenAI-compatible providers
    "openai": {
        "endpoint": "/models",
        "parser": "openai",
    },
    # Gemini format
    "gemini": {
        "endpoint": "/models",
        "parser": "gemini",
    },
}


def _normalize_url(url: str) -> str:
    """Normalize a URL by stripping the trailing slash."""
    if not url:
        return ""
    return url.rstrip("/")


def _parse_openai_response(data: dict) -> list[dict]:
    """
    Parse an OpenAI-style model list response.
    Format: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
    """
    models = []
    if "data" in data and isinstance(data["data"], list):
        for model in data["data"]:
            if isinstance(model, dict) and "id" in model:
                models.append({
                    "id": model["id"],
                    "name": model.get("name") or model["id"],
                    "owned_by": model.get("owned_by", ""),
                })
    return models


def _parse_gemini_response(data: dict) -> list[dict]:
    """
    Parse a Gemini-style model list response.
    Format: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
    """
    models = []
    if "models" in data and isinstance(data["models"], list):
        for model in data["models"]:
            if isinstance(model, dict) and "name" in model:
                # Gemini names look like "models/gemini-pro"; keep only the part after the prefix
                model_id = model["name"]
                if model_id.startswith("models/"):
                    model_id = model_id[7:]  # strip the "models/" prefix
                models.append({
                    "id": model_id,
                    "name": model.get("displayName") or model_id,
                    "owned_by": "google",
                })
    return models


async def _fetch_models_from_provider(
    base_url: str,
    api_key: str,
    endpoint: str,
    parser: str,
    client_type: str = "openai",
) -> list[dict]:
    """
    Fetch the model list from a provider's API.

    Args:
        base_url: The provider's base URL
        api_key: API key
        endpoint: Endpoint that returns the model list
        parser: Response parser type ('openai' | 'gemini')
        client_type: Client type ('openai' | 'gemini')

    Returns:
        The list of models
    """
    url = f"{_normalize_url(base_url)}{endpoint}"

    # Set request headers according to the client type
    headers = {}
    params = {}
    if client_type == "gemini":
        # Gemini passes the API key as a URL query parameter
        params["key"] = api_key
    else:
        # OpenAI-compatible APIs use the Authorization header
        headers["Authorization"] = f"Bearer {api_key}"

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url, headers=headers, params=params)
            response.raise_for_status()
            data = response.json()
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request timed out, please try again later")
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            raise HTTPException(status_code=401, detail="API key is invalid or has expired")
        elif e.response.status_code == 403:
            raise HTTPException(status_code=403, detail="Not authorized to access the model list")
        elif e.response.status_code == 404:
            raise HTTPException(status_code=404, detail="This provider does not support listing models")
        else:
            raise HTTPException(
                status_code=e.response.status_code,
                detail=f"Request failed: {e.response.text}"
            )
    except Exception as e:
        logger.error(f"Failed to fetch model list: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to fetch model list: {str(e)}")

    # Parse the response according to the parser type
    if parser == "openai":
        return _parse_openai_response(data)
    elif parser == "gemini":
        return _parse_gemini_response(data)
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported parser type: {parser}")


def _get_provider_config(provider_name: str) -> Optional[dict]:
    """
    Read the given provider's configuration from model_config.toml.

    Args:
        provider_name: Provider name

    Returns:
        The provider configuration, or None if it is not found
    """
    config_path = os.path.join(CONFIG_DIR, "model_config.toml")
    if not os.path.exists(config_path):
        return None

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)
        providers = config_data.get("api_providers", [])
        for provider in providers:
            if provider.get("name") == provider_name:
                return dict(provider)
        return None
    except Exception as e:
        logger.error(f"Failed to read provider configuration: {e}")
        return None
@router.get("/list")
async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
):
"""
获取指定提供商的可用模型列表
通过提供商名称查找配置然后请求对应的模型列表端点
"""
# 获取提供商配置
provider_config = _get_provider_config(provider_name)
if not provider_config:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider_config.get("base_url")
api_key = provider_config.get("api_key")
client_type = provider_config.get("client_type", "openai")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
if not api_key:
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
# 获取模型列表
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"provider": provider_name,
"count": len(models),
}
@router.get("/list-by-url")
async def get_models_by_url(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: str = Query(..., description="API Key"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
):
"""
通过 URL 直接获取模型列表用于自定义提供商
"""
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"count": len(models),
}
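
For reference, `_get_provider_config` above resolves providers from the `api_providers` entries in `model_config.toml`. A minimal sketch of such an entry follows; the name, URL, and key are placeholders, and only the name, base_url, api_key, and client_type fields are read by the `/models/list` route.

[[api_providers]]
name = "my-openai-provider"             # matched against the provider_name query parameter
base_url = "https://api.openai.com/v1"  # any trailing slash is stripped before use
api_key = "sk-placeholder"              # placeholder value
client_type = "openai"                  # "openai" (default) or "gemini"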


@@ -13,6 +13,7 @@ from .emoji_routes import router as emoji_router
from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
from .model_routes import router as model_router
logger = get_logger("webui.api")
@@ -35,6 +36,8 @@ router.include_router(plugin_router)
router.include_router(get_progress_router())
# Register the system control routes
router.include_router(system_router)
# Register the model list fetching routes
router.include_router(model_router)
class TokenVerifyRequest(BaseModel):
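
Once the router is registered as above, the two endpoints are exposed under the /models prefix of the WebUI API. A minimal usage sketch with httpx follows; the base URL and provider name are assumptions, and the actual path prefix and any auth headers depend on where the parent router is mounted and on the deployment.

# Minimal usage sketch. BASE and the provider name are assumptions; adjust to the
# actual mount point of the WebUI API, and add auth headers if the deployment requires them.
import asyncio

import httpx

BASE = "http://127.0.0.1:8000/api"  # assumed prefix of the WebUI API


async def main():
    async with httpx.AsyncClient(timeout=30.0) as client:
        # List models for a provider configured in model_config.toml
        resp = await client.get(
            f"{BASE}/models/list",
            params={"provider_name": "my-openai-provider"},
        )
        print(resp.json())  # expected keys: success, models, provider, count

        # Query an arbitrary OpenAI-compatible endpoint directly
        resp = await client.get(
            f"{BASE}/models/list-by-url",
            params={
                "base_url": "https://api.openai.com/v1",
                "api_key": "sk-placeholder",
                "parser": "openai",
            },
        )
        print(resp.json())  # expected keys: success, models, count


asyncio.run(main())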