224 lines
7.9 KiB
Python
224 lines
7.9 KiB
Python
from dataclasses import dataclass
|
||
import json
|
||
import os
|
||
from urllib.error import HTTPError, URLError
|
||
from urllib.request import Request, urlopen
|
||
|
||
|
||
class LLMConfigurationError(ValueError):
    """Business error for an LLM call that lacks required configuration (e.g. the API key)."""
|
||
|
||
|
||
class EmbeddingConfigurationError(ValueError):
    """Business error for an embedding call that lacks required configuration (e.g. the API key)."""
|
||
|
||
|
||
@dataclass
|
||
class LLMResponse:
|
||
"""
|
||
统一的模型响应对象。
|
||
|
||
Agent Core 的 Orchestrator 只依赖这一个结构,而不直接感知底层供应商差异。
|
||
"""
|
||
content: str = ""
|
||
model_name: str = ""
|
||
success: bool = True
|
||
error: Exception | None = None
|
||
|
||
|
||
class MockLLMProvider:
    """Deterministic provider used by default for local runs and tests.

    It does not try to imitate a real conversation. It returns a stable,
    assertable, JSON-parseable payload so the frontend and backend can
    demonstrate the whole flow before a real model is connected.
    """

    def __init__(self, model_name: str = "mock-model"):
        # An empty or None name falls back to the default identifier.
        self.model_name = "mock-model" if not model_name else model_name

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Echo the latest user message inside a fixed JSON answer.

        ``response_format`` is accepted for interface parity with the real
        provider and is ignored here.
        """
        latest = _find_last_user_message(messages)
        body = {
            "answer": f"模拟回答:{latest}",
            "confidence": "medium",
            "references": [],
        }
        serialized = json.dumps(body, ensure_ascii=False)
        return LLMResponse(content=serialized, model_name=self.model_name, success=True)
|
||
|
||
|
||
class OpenAICompatibleProvider:
    """Provider that calls an OpenAI-compatible Chat Completions endpoint.

    ``generate`` does not raise for request or response problems. Every failure
    is reported as ``LLMResponse(success=False, error=...)``.
    """

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Send *messages* to ``{base_url}/chat/completions`` and wrap the reply.

        Args:
            messages: Chat messages forwarded verbatim as the ``messages`` field.
            response_format: Optional ``response_format`` payload (e.g. JSON mode).
                If the endpoint answers HTTP 400, the request is retried once
                without it.

        Returns:
            ``LLMResponse`` with ``success=True`` and the first choice's content,
            or ``success=False`` carrying the captured exception (an
            ``LLMConfigurationError`` when no API key is configured).
        """
        if not self.api_key:
            return LLMResponse(
                model_name=self.model_name,
                success=False,
                error=LLMConfigurationError("LLM_API_KEY 未配置,无法调用 OpenAI 兼容模型接口"),
            )
        base_payload = {"model": self.model_name, "messages": messages}
        try:
            if response_format:
                try:
                    data = self._post_chat({**base_payload, "response_format": response_format})
                except RuntimeError as exc:
                    # Some OpenAI-compatible vendors/models do not support response_format.
                    # Structured output is preferred, but on HTTP 400 fall back to a plain
                    # chat request so capability differences do not block the demo flow.
                    if not self._is_http_400(exc):
                        raise
                    data = self._post_chat(base_payload)
            else:
                data = self._post_chat(base_payload)
            return LLMResponse(
                content=self._extract_content(data),
                # A missing or null "model" in the reply falls back to the configured name.
                model_name=data.get("model") or self.model_name,
                success=True,
            )
        except Exception as exc:  # boundary: failures are reported through LLMResponse
            return LLMResponse(model_name=self.model_name, success=False, error=exc)

    def _post_chat(self, payload: dict) -> dict:
        """POST *payload* to this provider's ``chat/completions`` endpoint."""
        return _post_json(
            base_url=self.base_url,
            endpoint="chat/completions",
            api_key=self.api_key,
            payload=payload,
        )

    @staticmethod
    def _is_http_400(exc: BaseException) -> bool:
        """Return True if *exc* stems from an HTTP 400 response.

        Prefer the chained ``HTTPError`` (``_post_json`` raises ``... from exc``).
        For errors raised without an HTTPError cause, fall back to the message text.
        """
        cause = exc.__cause__
        if isinstance(cause, HTTPError):
            return cause.code == 400
        return "HTTP Error 400" in str(exc)

    @staticmethod
    def _extract_content(data: dict) -> str:
        """Return the first choice's message content, always as a ``str``.

        Raises:
            ValueError: if the body has no ``choices`` (for example, a body that
                only carries an ``error`` object). This makes it a reported failure
                instead of an empty "successful" answer.
        """
        choices = data.get("choices") or []
        if not choices:
            raise ValueError(f"OpenAI 兼容接口响应缺少 choices:{data}")
        message = choices[0].get("message") or {}
        # ``content`` may be JSON null (e.g. for tool-call replies); normalize to "".
        return message.get("content") or ""
|
||
|
||
|
||
class OpenAICompatibleEmbeddingProvider:
    """Provider that calls an OpenAI-compatible Embeddings endpoint."""

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, returning one vector per input in input order.

        Raises:
            EmbeddingConfigurationError: if no API key is configured.
            RuntimeError: propagated from ``_post_json`` on HTTP/network failure.
        """
        if not self.api_key:
            raise EmbeddingConfigurationError("EMBEDDING_API_KEY 未配置,无法调用 OpenAI 兼容 Embedding 接口")
        if not texts:
            # Nothing to embed: skip the round trip (OpenAI-style APIs generally
            # reject an empty ``input``).
            return []
        data = _post_json(
            base_url=self.base_url,
            endpoint="embeddings",
            api_key=self.api_key,
            payload={"model": self.model_name, "input": texts},
        )
        return self._ordered_embeddings(data.get("data") or [])

    @staticmethod
    def _ordered_embeddings(items: list[dict]) -> list[list[float]]:
        """Return each item's ``embedding`` ordered by its integer ``index`` field.

        Items without a usable ``index`` keep their position in the response,
        so well-ordered responses are returned unchanged.
        """

        def sort_key(pair: tuple[int, dict]) -> int:
            position, item = pair
            index = item.get("index")
            return index if isinstance(index, int) else position

        ranked = sorted(enumerate(items), key=sort_key)
        return [item.get("embedding", []) for _, item in ranked]
|
||
|
||
|
||
def create_llm_provider(config: dict | None = None):
    """Build the LLM provider selected by *config*.

    Selection policy:
    - ``LLM_PROVIDER=mock`` set explicitly -> Mock provider.
    - ``LLM_PROVIDER`` unset but ``LLM_API_KEY`` present -> OpenAI-compatible provider.
    - otherwise -> Mock, so the page still works end to end.
    Any other explicit ``LLM_PROVIDER`` value also selects the OpenAI-compatible
    provider.

    ``LLM_MODEL`` and ``LLM_BASE_URL`` treat empty strings (e.g. an environment
    variable set to "") as unset and fall back to their defaults. An empty base
    URL would otherwise produce the unusable URL "/chat/completions".
    """
    config = config or {}
    provider_name = _resolve_provider_name(config)
    model_name = config.get("LLM_MODEL") or "mock-model"
    if provider_name == "mock":
        return MockLLMProvider(model_name=model_name)
    return OpenAICompatibleProvider(
        api_key=config.get("LLM_API_KEY") or "",
        base_url=config.get("LLM_BASE_URL") or "https://api.openai.com/v1",
        model_name=model_name,
    )
|
||
|
||
|
||
def create_embedding_provider(config: dict | None = None):
    """Build the embedding provider from *config*.

    When no dedicated embedding key or base URL is configured (missing **or**
    empty), the LLM settings are reused to reduce environment-variable setup
    for demos. An empty ``EMBEDDING_MODEL`` falls back to
    ``text-embedding-3-small``.
    """
    config = config or {}
    return OpenAICompatibleEmbeddingProvider(
        api_key=config.get("EMBEDDING_API_KEY") or config.get("LLM_API_KEY") or "",
        base_url=(
            config.get("EMBEDDING_BASE_URL")
            or config.get("LLM_BASE_URL")
            or "https://api.openai.com/v1"
        ),
        model_name=config.get("EMBEDDING_MODEL") or "text-embedding-3-small",
    )
|
||
|
||
|
||
def get_runtime_llm_config(overrides: dict | None = None) -> dict:
|
||
"""
|
||
从环境变量读取运行时配置。
|
||
|
||
Agent Core 通过这一层读取模型配置,避免直接依赖 Django settings,
|
||
这样本模块在独立脚本、测试和 Django 环境中都可复用。
|
||
"""
|
||
config = {
|
||
"LLM_PROVIDER": os.environ.get("LLM_PROVIDER", ""),
|
||
"LLM_API_KEY": os.environ.get("LLM_API_KEY", ""),
|
||
"LLM_BASE_URL": os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1"),
|
||
"LLM_MODEL": os.environ.get("LLM_MODEL", "mock-model"),
|
||
}
|
||
if overrides:
|
||
config.update(overrides)
|
||
return config
|
||
|
||
|
||
def _resolve_provider_name(config: dict) -> str:
|
||
"""统一推导当前应启用的 Provider 名称。"""
|
||
provider_name = config.get("LLM_PROVIDER")
|
||
if provider_name:
|
||
return provider_name
|
||
return "openai_compatible" if config.get("LLM_API_KEY") else "mock"
|
||
|
||
|
||
def _find_last_user_message(messages: list[dict]) -> str:
|
||
"""从消息列表中提取最后一条用户输入,用于 Mock Provider 回显。"""
|
||
for message in reversed(messages):
|
||
if message.get("role") == "user":
|
||
return message.get("content", "")
|
||
return ""
|
||
|
||
|
||
def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dict:
|
||
"""向 OpenAI 兼容接口发送 JSON POST 请求并解析响应。"""
|
||
url = f"{base_url.rstrip('/')}/{endpoint}"
|
||
request = Request(
|
||
url,
|
||
data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
|
||
headers={
|
||
"Authorization": f"Bearer {api_key}",
|
||
"Content-Type": "application/json",
|
||
},
|
||
method="POST",
|
||
)
|
||
try:
|
||
with urlopen(request, timeout=60) as response:
|
||
return json.loads(response.read().decode("utf-8"))
|
||
except HTTPError as exc:
|
||
error_body = exc.read().decode("utf-8", errors="ignore")
|
||
error_detail = f"{exc}"
|
||
if error_body:
|
||
error_detail = f"{error_detail} {error_body}"
|
||
raise RuntimeError(f"OpenAI 兼容接口调用失败:{error_detail}") from exc
|
||
except URLError as exc:
|
||
raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc
|