# File: DEMO-AGENT/agent_core/llm_provider.py (224 lines, 7.9 KiB, Python)
#
# Note: this file contains Unicode characters (Chinese text in comments and
# string literals) that some viewers flag as "ambiguous" because they can be
# confused with other characters. This is intentional and safe to ignore.
from dataclasses import dataclass
import json
import os
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
class LLMConfigurationError(ValueError):
    """Raised when an LLM call cannot proceed because required configuration (e.g. the API key) is missing."""
class EmbeddingConfigurationError(ValueError):
    """Raised when an embedding call cannot proceed because required configuration (e.g. the API key) is missing."""
@dataclass
class LLMResponse:
"""
统一的模型响应对象。
Agent Core 的 Orchestrator 只依赖这一个结构,而不直接感知底层供应商差异。
"""
content: str = ""
model_name: str = ""
success: bool = True
error: Exception | None = None
class MockLLMProvider:
    """Deterministic stand-in provider used by default locally and in tests.

    It does not try to imitate a real conversation. Instead it always returns
    a stable JSON payload (``answer`` / ``confidence`` / ``references``) that
    can be asserted on and parsed as structured output. This lets the front
    end and back end demonstrate the complete flow before a real model is
    connected.
    """

    def __init__(self, model_name: str = "mock-model"):
        # A falsy name (e.g. "") means "use the default".
        self.model_name = model_name if model_name else "mock-model"

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Echo the most recent user message inside a fixed JSON envelope.

        ``response_format`` is accepted only so the signature matches the real
        providers; it is ignored here.
        """
        echoed = _find_last_user_message(messages)
        body = {
            "answer": f"模拟回答:{echoed}",
            "confidence": "medium",
            "references": [],
        }
        serialized = json.dumps(body, ensure_ascii=False)
        return LLMResponse(
            content=serialized,
            model_name=self.model_name,
            success=True,
        )
class OpenAICompatibleProvider:
    """Provider that calls an OpenAI Chat Completions compatible endpoint.

    ``generate`` never raises. Missing configuration and HTTP/transport
    failures are reported as ``LLMResponse(success=False, error=...)``.
    """

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Send ``messages`` to ``{base_url}/chat/completions`` and wrap the reply.

        Args:
            messages: Chat messages, forwarded verbatim as the ``messages`` field.
            response_format: Optional structured-output spec. If the endpoint
                answers HTTP 400, the request is retried once without it.

        Returns:
            ``LLMResponse(success=True)`` holding the first choice's message
            content (``""`` when absent or null), or ``LLMResponse(success=False)``
            with the caught exception in ``error``.
        """
        if not self.api_key:
            return LLMResponse(
                model_name=self.model_name,
                success=False,
                error=LLMConfigurationError("LLM_API_KEY 未配置,无法调用 OpenAI 兼容模型接口"),
            )

        # The plain payload doubles as the fallback when response_format is rejected.
        base_payload = {"model": self.model_name, "messages": messages}
        payload = dict(base_payload)
        if response_format:
            payload["response_format"] = response_format

        try:
            try:
                data = self._post_chat_completion(payload)
            except RuntimeError as exc:
                # Some OpenAI-compatible vendors/models do not support
                # response_format. Structured output is preferred, but on a 400
                # we retry as a plain chat request so the demo flow is not
                # blocked by that capability gap.
                if not response_format or not self._is_bad_request(exc):
                    raise
                data = self._post_chat_completion(base_payload)

            choice = data.get("choices", [{}])[0]
            # The API allows "content": null; keep LLMResponse.content a str.
            content = (choice.get("message") or {}).get("content") or ""
            return LLMResponse(
                content=content,
                # Fall back to the configured name if the reply omits/nulls "model".
                model_name=data.get("model") or self.model_name,
                success=True,
            )
        except Exception as exc:
            # Deliberate boundary: callers receive failures as data, not exceptions.
            return LLMResponse(model_name=self.model_name, success=False, error=exc)

    def _post_chat_completion(self, payload: dict) -> dict:
        """POST ``payload`` to this provider's ``chat/completions`` endpoint."""
        return _post_json(
            base_url=self.base_url,
            endpoint="chat/completions",
            api_key=self.api_key,
            payload=payload,
        )

    @staticmethod
    def _is_bad_request(exc: Exception) -> bool:
        """Return True if ``exc`` represents an HTTP 400 response.

        The chained ``HTTPError`` status code (``raise ... from exc``) is
        checked first. Errors without such a cause fall back to matching the
        ``"HTTP Error 400"`` text in the message.
        """
        cause = exc.__cause__
        if isinstance(cause, HTTPError):
            return cause.code == 400
        return "HTTP Error 400" in str(exc)
class OpenAICompatibleEmbeddingProvider:
    """Provider that calls an OpenAI Embeddings compatible endpoint.

    Unlike the chat provider, failures are raised to the caller. A missing key
    raises ``EmbeddingConfigurationError``, and request failures propagate
    from ``_post_json``.
    """

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text, in input order.

        Raises:
            EmbeddingConfigurationError: if no API key is configured.
        """
        if not self.api_key:
            raise EmbeddingConfigurationError("EMBEDDING_API_KEY 未配置,无法调用 OpenAI 兼容 Embedding 接口")
        if not texts:
            # Nothing to embed; an empty ``input`` is not a useful request and
            # OpenAI-compatible servers may reject it, so skip the network call.
            return []
        data = _post_json(
            base_url=self.base_url,
            endpoint="embeddings",
            api_key=self.api_key,
            payload={"model": self.model_name, "input": texts},
        )
        return self._ordered_embeddings(data)

    @staticmethod
    def _ordered_embeddings(data: dict) -> list[list[float]]:
        """Extract the vectors from an embeddings response, ordered by each item's ``index``.

        Items without an integer ``index`` keep their response position, and
        items without an ``embedding`` yield ``[]``.
        """
        items = data.get("data") or []

        def sort_key(pair):
            position, item = pair
            index = item.get("index")
            return index if isinstance(index, int) else position

        # sorted() is stable, so ties keep the order the server returned.
        ordered = sorted(enumerate(items), key=sort_key)
        return [item.get("embedding", []) for _, item in ordered]
def create_llm_provider(config: dict | None = None):
    """
    Create the LLM provider described by ``config``.

    The provider name comes from ``_resolve_provider_name``:
    - ``LLM_PROVIDER=mock`` selects ``MockLLMProvider``;
    - any other explicit ``LLM_PROVIDER`` selects ``OpenAICompatibleProvider``;
    - when ``LLM_PROVIDER`` is unset, a present ``LLM_API_KEY`` selects the
      OpenAI-compatible provider, otherwise it falls back to the mock so the
      page still works end to end.

    ``LLM_MODEL`` and ``LLM_BASE_URL`` values that are missing *or empty*
    fall back to ``"mock-model"`` and ``"https://api.openai.com/v1"``.
    """
    config = config or {}
    provider_name = _resolve_provider_name(config)
    # `or` instead of a .get() default: env-derived configs often carry "" for
    # unset values, which would otherwise become an empty model name / base URL.
    model_name = config.get("LLM_MODEL") or "mock-model"
    if provider_name == "mock":
        return MockLLMProvider(model_name=model_name)
    return OpenAICompatibleProvider(
        api_key=config.get("LLM_API_KEY") or "",
        base_url=config.get("LLM_BASE_URL") or "https://api.openai.com/v1",
        model_name=model_name,
    )
def create_embedding_provider(config: dict | None = None):
    """
    Create the embedding provider described by ``config``.

    When no embedding-specific key or base URL is configured, the LLM settings
    are reused. This keeps the number of environment variables needed for the
    demo small. A key that is present but empty (``""``/``None``) counts as
    "not configured", so ``EMBEDDING_API_KEY=""`` still falls back to
    ``LLM_API_KEY``.

    Resolution order:
    - api key: ``EMBEDDING_API_KEY`` -> ``LLM_API_KEY`` -> ``""``
    - base URL: ``EMBEDDING_BASE_URL`` -> ``LLM_BASE_URL`` -> ``https://api.openai.com/v1``
    - model: ``EMBEDDING_MODEL`` -> ``text-embedding-3-small``
    """
    config = config or {}
    # `or` chains (rather than nested .get() defaults) so an empty value falls
    # through to the next source instead of silently winning.
    api_key = config.get("EMBEDDING_API_KEY") or config.get("LLM_API_KEY") or ""
    base_url = (
        config.get("EMBEDDING_BASE_URL")
        or config.get("LLM_BASE_URL")
        or "https://api.openai.com/v1"
    )
    model_name = config.get("EMBEDDING_MODEL") or "text-embedding-3-small"
    return OpenAICompatibleEmbeddingProvider(
        api_key=api_key,
        base_url=base_url,
        model_name=model_name,
    )
def get_runtime_llm_config(overrides: dict | None = None) -> dict:
"""
从环境变量读取运行时配置。
Agent Core 通过这一层读取模型配置,避免直接依赖 Django settings
这样本模块在独立脚本、测试和 Django 环境中都可复用。
"""
config = {
"LLM_PROVIDER": os.environ.get("LLM_PROVIDER", ""),
"LLM_API_KEY": os.environ.get("LLM_API_KEY", ""),
"LLM_BASE_URL": os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1"),
"LLM_MODEL": os.environ.get("LLM_MODEL", "mock-model"),
}
if overrides:
config.update(overrides)
return config
def _resolve_provider_name(config: dict) -> str:
"""统一推导当前应启用的 Provider 名称。"""
provider_name = config.get("LLM_PROVIDER")
if provider_name:
return provider_name
return "openai_compatible" if config.get("LLM_API_KEY") else "mock"
def _find_last_user_message(messages: list[dict]) -> str:
"""从消息列表中提取最后一条用户输入,用于 Mock Provider 回显。"""
for message in reversed(messages):
if message.get("role") == "user":
return message.get("content", "")
return ""
def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dict:
"""向 OpenAI 兼容接口发送 JSON POST 请求并解析响应。"""
url = f"{base_url.rstrip('/')}/{endpoint}"
request = Request(
url,
data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
method="POST",
)
try:
with urlopen(request, timeout=60) as response:
return json.loads(response.read().decode("utf-8"))
except HTTPError as exc:
error_body = exc.read().decode("utf-8", errors="ignore")
error_detail = f"{exc}"
if error_body:
error_detail = f"{error_detail} {error_body}"
raise RuntimeError(f"OpenAI 兼容接口调用失败:{error_detail}") from exc
except URLError as exc:
raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc