Files
DEMO-AGENT/agent_core/llm_provider.py

163 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from dataclasses import dataclass
import json
import os
from urllib.error import URLError
from urllib.request import Request, urlopen
class LLMConfigurationError(ValueError):
    """Signals that the chat-model settings are missing or unusable."""
class EmbeddingConfigurationError(ValueError):
    """Signals that the embedding-model settings are missing or unusable."""
@dataclass
class LLMResponse:
content: str = ""
model_name: str = ""
success: bool = True
error: Exception | None = None
class MockLLMProvider:
    """Offline stand-in for a real LLM backend.

    It lets pages and tests run end to end before a real model is connected.
    """

    def __init__(self, model_name: str = "mock-model"):
        # An empty name falls back to the default label.
        self.model_name = model_name if model_name else "mock-model"

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Return a canned JSON answer that echoes the most recent user message.

        Args:
            messages: Chat messages as dicts; the ``role`` and ``content`` keys are read.
            response_format: Accepted for signature parity with the real provider; ignored.

        Returns:
            A successful ``LLMResponse`` whose ``content`` is a JSON string with the
            keys ``answer``, ``confidence`` and ``references``.
        """
        # Walk the history backwards so the newest user turn wins; default to ""
        # when there is no user message at all.
        latest_user_text = next(
            (entry.get("content", "") for entry in reversed(messages) if entry.get("role") == "user"),
            "",
        )

        # The reply is always the same stable JSON shape, so callers can push it
        # through the same structured-parsing path they use for real models.
        reply = {
            "answer": f"模拟回答:{latest_user_text}",
            "confidence": "medium",
            "references": [],
        }
        return LLMResponse(
            content=json.dumps(reply, ensure_ascii=False),
            model_name=self.model_name,
            success=True,
        )
class OpenAICompatibleProvider:
    """Chat-completion client for endpoints that expose the OpenAI-style HTTP API.

    ``generate`` never raises: every failure is returned as an ``LLMResponse``
    with ``success=False`` and the causing exception in ``error``.
    """

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Send ``messages`` to ``chat/completions`` and return the first choice.

        Args:
            messages: Chat messages, forwarded unchanged as the ``messages`` field.
            response_format: Optional value forwarded as ``response_format`` when truthy.

        Returns:
            ``LLMResponse`` with the first choice's message content on success, or
            with ``success=False`` and ``error`` set when the API key is missing,
            the request fails, or the reply carries no choices.
        """
        if not self.api_key:
            return LLMResponse(
                model_name=self.model_name,
                success=False,
                error=LLMConfigurationError("LLM_API_KEY 未配置,无法调用 OpenAI 兼容模型接口"),
            )

        payload = {"model": self.model_name, "messages": messages}
        if response_format:
            payload["response_format"] = response_format

        # This method is the provider boundary: anything that goes wrong below is
        # converted into a failed LLMResponse instead of propagating.
        try:
            data = _post_json(
                base_url=self.base_url,
                endpoint="chat/completions",
                api_key=self.api_key,
                payload=payload,
            )
            choices = data.get("choices")
            if not choices:
                # A reply without choices is not a usable answer. Report it as a
                # failure rather than a "successful" empty string (missing key)
                # or an opaque IndexError (empty list).
                raise RuntimeError("OpenAI 兼容接口响应缺少 choices 字段")
            # ``content`` may be present but null; keep the field a str.
            content = choices[0].get("message", {}).get("content") or ""
            return LLMResponse(
                content=content,
                # A null/empty ``model`` in the reply must not hide the configured name.
                model_name=data.get("model") or self.model_name,
                success=True,
            )
        except Exception as exc:
            return LLMResponse(model_name=self.model_name, success=False, error=exc)
class OpenAICompatibleEmbeddingProvider:
    """Embedding client for endpoints that expose the OpenAI-style HTTP API."""

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per entry of the reply's ``data`` list.

        Args:
            texts: Strings to embed, sent as the ``input`` field of one request.

        Returns:
            The ``embedding`` value of each returned item, in response order
            (an empty list for an item without that key). An empty ``texts``
            yields ``[]`` without contacting the service.

        Raises:
            EmbeddingConfigurationError: If no API key is configured.
            RuntimeError: Propagated from the HTTP helper when the request fails.
        """
        # The key check comes first so a misconfiguration is reported even when
        # there is nothing to embed.
        if not self.api_key:
            raise EmbeddingConfigurationError("EMBEDDING_API_KEY 未配置,无法调用 OpenAI 兼容 Embedding 接口")

        # Nothing to embed: skip the round trip instead of posting an empty input.
        if not texts:
            return []

        data = _post_json(
            base_url=self.base_url,
            endpoint="embeddings",
            api_key=self.api_key,
            payload={"model": self.model_name, "input": texts},
        )
        return [item.get("embedding", []) for item in data.get("data", [])]
def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dict:
url = f"{base_url.rstrip('/')}/{endpoint}"
request = Request(
url,
data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
method="POST",
)
try:
with urlopen(request, timeout=60) as response:
return json.loads(response.read().decode("utf-8"))
except URLError as exc:
raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc
def create_llm_provider(config: dict | None = None):
    """Build the chat provider described by ``config``.

    Keys read: ``LLM_PROVIDER``, ``LLM_API_KEY``, ``LLM_BASE_URL`` and
    ``LLM_MODEL``. A key that is missing, ``None`` or empty is treated as
    unset, so an env-derived config holding ``""`` still gets the defaults.

    Args:
        config: Settings mapping; ``None`` behaves like an empty dict.

    Returns:
        ``MockLLMProvider`` when the provider is ``"mock"`` (or unset with no
        API key); otherwise ``OpenAICompatibleProvider``.
    """
    config = config or {}

    # Tolerate stray whitespace / casing in hand-written settings ("Mock ", "MOCK").
    provider_name = str(config.get("LLM_PROVIDER") or "").strip().lower()
    api_key = config.get("LLM_API_KEY") or ""
    if not provider_name:
        # No explicit choice: use the real API only when a key is available.
        provider_name = "openai_compatible" if api_key else "mock"

    model_name = config.get("LLM_MODEL") or "mock-model"
    if provider_name == "mock":
        return MockLLMProvider(model_name=model_name)

    # Any other provider name is served by the OpenAI-compatible client.
    return OpenAICompatibleProvider(
        api_key=api_key,
        base_url=config.get("LLM_BASE_URL") or "https://api.openai.com/v1",
        model_name=model_name,
    )
def create_embedding_provider(config: dict | None = None):
    """Build the embedding provider described by ``config``.

    ``EMBEDDING_API_KEY`` and ``EMBEDDING_BASE_URL`` fall back to their
    ``LLM_*`` counterparts, then to the built-in defaults. A key that is
    present but empty or ``None`` counts as unset, so it no longer blocks
    the fallback.

    Args:
        config: Settings mapping; ``None`` behaves like an empty dict.

    Returns:
        An ``OpenAICompatibleEmbeddingProvider``.
    """
    config = config or {}
    return OpenAICompatibleEmbeddingProvider(
        api_key=config.get("EMBEDDING_API_KEY") or config.get("LLM_API_KEY") or "",
        base_url=(
            config.get("EMBEDDING_BASE_URL")
            or config.get("LLM_BASE_URL")
            or "https://api.openai.com/v1"
        ),
        model_name=config.get("EMBEDDING_MODEL") or "text-embedding-3-small",
    )
def get_runtime_llm_config(overrides: dict | None = None) -> dict:
"""
从环境变量读取运行时配置。
Agent Core 通过这层读取模型配置,避免直接依赖 Django settings
这样本模块在独立脚本、测试和 Django 中都能复用。
"""
config = {
"LLM_PROVIDER": os.environ.get("LLM_PROVIDER", ""),
"LLM_API_KEY": os.environ.get("LLM_API_KEY", ""),
"LLM_BASE_URL": os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1"),
"LLM_MODEL": os.environ.get("LLM_MODEL", "mock-model"),
}
if overrides:
config.update(overrides)
return config