refactor(core): 梳理模型配置与审计脱敏服务
This commit is contained in:
@@ -6,15 +6,20 @@ from urllib.request import Request, urlopen
|
|||||||
|
|
||||||
|
|
||||||
class LLMConfigurationError(ValueError):
    """Business exception raised when an LLM call is missing required configuration."""
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfigurationError(ValueError):
    """Business exception raised when an embedding call is missing required configuration."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
|
"""
|
||||||
|
统一的模型响应对象。
|
||||||
|
|
||||||
|
Agent Core 的 Orchestrator 只依赖这一个结构,而不直接感知底层供应商差异。
|
||||||
|
"""
|
||||||
content: str = ""
|
content: str = ""
|
||||||
model_name: str = ""
|
model_name: str = ""
|
||||||
success: bool = True
|
success: bool = True
|
||||||
@@ -22,17 +27,18 @@ class LLMResponse:
|
|||||||
|
|
||||||
|
|
||||||
class MockLLMProvider:
|
class MockLLMProvider:
|
||||||
|
"""
|
||||||
|
本地和测试默认使用的 Mock Provider。
|
||||||
|
|
||||||
|
设计目标不是拟真对话,而是提供一个稳定、可断言、可结构化解析的响应,
|
||||||
|
让前后端在未接入真实模型时也能完整演示链路。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "mock-model"):
|
def __init__(self, model_name: str = "mock-model"):
|
||||||
self.model_name = model_name or "mock-model"
|
self.model_name = model_name or "mock-model"
|
||||||
|
|
||||||
def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
|
def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
|
||||||
# Mock Provider 的职责是让页面和测试在未接入真实模型时也能闭环。
|
user_content = _find_last_user_message(messages)
|
||||||
# 因此这里直接返回稳定 JSON,方便后续统一走结构化解析逻辑。
|
|
||||||
user_content = ""
|
|
||||||
for message in reversed(messages):
|
|
||||||
if message.get("role") == "user":
|
|
||||||
user_content = message.get("content", "")
|
|
||||||
break
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=json.dumps(
|
content=json.dumps(
|
||||||
{
|
{
|
||||||
@@ -48,6 +54,8 @@ class MockLLMProvider:
|
|||||||
|
|
||||||
|
|
||||||
class OpenAICompatibleProvider:
|
class OpenAICompatibleProvider:
|
||||||
|
"""调用 OpenAI Chat Completions 兼容接口的 Provider。"""
|
||||||
|
|
||||||
def __init__(self, api_key: str, base_url: str, model_name: str):
|
def __init__(self, api_key: str, base_url: str, model_name: str):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@@ -85,6 +93,8 @@ class OpenAICompatibleProvider:
|
|||||||
|
|
||||||
|
|
||||||
class OpenAICompatibleEmbeddingProvider:
|
class OpenAICompatibleEmbeddingProvider:
|
||||||
|
"""调用 OpenAI Embeddings 兼容接口的 Provider。"""
|
||||||
|
|
||||||
def __init__(self, api_key: str, base_url: str, model_name: str):
|
def __init__(self, api_key: str, base_url: str, model_name: str):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@@ -102,7 +112,78 @@ class OpenAICompatibleEmbeddingProvider:
|
|||||||
return [item.get("embedding", []) for item in data.get("data", [])]
|
return [item.get("embedding", []) for item in data.get("data", [])]
|
||||||
|
|
||||||
|
|
||||||
|
def create_llm_provider(config: dict | None = None):
    """
    Build the LLM provider selected by ``config``.

    The provider name is derived by ``_resolve_provider_name``. The name
    ``"mock"`` yields a ``MockLLMProvider``; every other name yields an
    ``OpenAICompatibleProvider``.

    Args:
        config: Optional settings mapping. Keys read here are ``LLM_MODEL``
            (default ``"mock-model"``), ``LLM_API_KEY`` (default ``""``) and
            ``LLM_BASE_URL`` (default ``"https://api.openai.com/v1"``).
            ``None`` or an empty mapping is treated as ``{}``.

    Returns:
        A ``MockLLMProvider`` or an ``OpenAICompatibleProvider`` instance.
    """
    settings = config if config else {}
    selected_provider = _resolve_provider_name(settings)
    selected_model = settings.get("LLM_MODEL", "mock-model")

    if selected_provider != "mock":
        return OpenAICompatibleProvider(
            api_key=settings.get("LLM_API_KEY", ""),
            base_url=settings.get("LLM_BASE_URL", "https://api.openai.com/v1"),
            model_name=selected_model,
        )
    return MockLLMProvider(model_name=selected_model)
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_provider(config: dict | None = None):
    """
    Build the embedding provider described by ``config``.

    When ``EMBEDDING_API_KEY`` or ``EMBEDDING_BASE_URL`` is absent from the
    mapping, the corresponding ``LLM_API_KEY`` / ``LLM_BASE_URL`` value is
    reused, which keeps the number of settings needed for a demo small.

    Args:
        config: Optional settings mapping; ``None`` or an empty mapping is
            treated as ``{}``.

    Returns:
        An ``OpenAICompatibleEmbeddingProvider`` instance.
    """
    settings = config if config else {}

    # Fallback values are taken from the LLM settings; they are only used
    # when the embedding-specific key is missing from the mapping.
    fallback_key = settings.get("LLM_API_KEY", "")
    fallback_url = settings.get("LLM_BASE_URL", "https://api.openai.com/v1")

    embedding_key = settings.get("EMBEDDING_API_KEY", fallback_key)
    embedding_url = settings.get("EMBEDDING_BASE_URL", fallback_url)
    embedding_model = settings.get("EMBEDDING_MODEL", "text-embedding-3-small")

    return OpenAICompatibleEmbeddingProvider(
        api_key=embedding_key,
        base_url=embedding_url,
        model_name=embedding_model,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_runtime_llm_config(overrides: dict | None = None) -> dict:
|
||||||
|
"""
|
||||||
|
从环境变量读取运行时配置。
|
||||||
|
|
||||||
|
Agent Core 通过这一层读取模型配置,避免直接依赖 Django settings,
|
||||||
|
这样本模块在独立脚本、测试和 Django 环境中都可复用。
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"LLM_PROVIDER": os.environ.get("LLM_PROVIDER", ""),
|
||||||
|
"LLM_API_KEY": os.environ.get("LLM_API_KEY", ""),
|
||||||
|
"LLM_BASE_URL": os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1"),
|
||||||
|
"LLM_MODEL": os.environ.get("LLM_MODEL", "mock-model"),
|
||||||
|
}
|
||||||
|
if overrides:
|
||||||
|
config.update(overrides)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_name(config: dict) -> str:
|
||||||
|
"""统一推导当前应启用的 Provider 名称。"""
|
||||||
|
provider_name = config.get("LLM_PROVIDER")
|
||||||
|
if provider_name:
|
||||||
|
return provider_name
|
||||||
|
return "openai_compatible" if config.get("LLM_API_KEY") else "mock"
|
||||||
|
|
||||||
|
|
||||||
|
def _find_last_user_message(messages: list[dict]) -> str:
|
||||||
|
"""从消息列表中提取最后一条用户输入,用于 Mock Provider 回显。"""
|
||||||
|
for message in reversed(messages):
|
||||||
|
if message.get("role") == "user":
|
||||||
|
return message.get("content", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dict:
|
def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dict:
|
||||||
|
"""向 OpenAI 兼容接口发送 JSON POST 请求并解析响应。"""
|
||||||
url = f"{base_url.rstrip('/')}/{endpoint}"
|
url = f"{base_url.rstrip('/')}/{endpoint}"
|
||||||
request = Request(
|
request = Request(
|
||||||
url,
|
url,
|
||||||
@@ -118,45 +199,3 @@ def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dic
|
|||||||
return json.loads(response.read().decode("utf-8"))
|
return json.loads(response.read().decode("utf-8"))
|
||||||
except URLError as exc:
|
except URLError as exc:
|
||||||
raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc
|
raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
def create_llm_provider(config: dict | None = None):
    """
    Create the LLM provider selected by ``config``.

    Returns a ``MockLLMProvider`` when the resolved provider name is
    ``"mock"`` and an ``OpenAICompatibleProvider`` for any other name.
    ``None`` or an empty ``config`` is treated as ``{}``.
    """
    config = config or {}
    provider_name = config.get("LLM_PROVIDER")
    if not provider_name:
        # No explicit provider: use the OpenAI-compatible one when an API key
        # is present, otherwise fall back to the mock provider.
        provider_name = "openai_compatible" if config.get("LLM_API_KEY") else "mock"
    model_name = config.get("LLM_MODEL", "mock-model")
    if provider_name == "mock":
        return MockLLMProvider(model_name=model_name)
    return OpenAICompatibleProvider(
        api_key=config.get("LLM_API_KEY", ""),
        base_url=config.get("LLM_BASE_URL", "https://api.openai.com/v1"),
        model_name=model_name,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def create_embedding_provider(config: dict | None = None):
    """
    Create an ``OpenAICompatibleEmbeddingProvider`` from ``config``.

    ``EMBEDDING_API_KEY`` and ``EMBEDDING_BASE_URL`` fall back to the
    corresponding ``LLM_*`` entries when they are absent from the mapping.
    ``None`` or an empty ``config`` is treated as ``{}``.
    """
    config = config or {}
    return OpenAICompatibleEmbeddingProvider(
        api_key=config.get("EMBEDDING_API_KEY", config.get("LLM_API_KEY", "")),
        base_url=config.get("EMBEDDING_BASE_URL", config.get("LLM_BASE_URL", "https://api.openai.com/v1")),
        model_name=config.get("EMBEDDING_MODEL", "text-embedding-3-small"),
    )
|
|
||||||
|
|
||||||
|
|
||||||
def get_runtime_llm_config(overrides: dict | None = None) -> dict:
    """
    Read the runtime configuration from environment variables.

    Agent Core reads its model configuration through this layer instead of
    depending on Django settings directly, so the module can be reused in
    standalone scripts, tests and Django.

    Entries in ``overrides`` (when given and non-empty) replace or extend
    the values read from the environment.
    """
    config = {
        "LLM_PROVIDER": os.environ.get("LLM_PROVIDER", ""),
        "LLM_API_KEY": os.environ.get("LLM_API_KEY", ""),
        "LLM_BASE_URL": os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1"),
        "LLM_MODEL": os.environ.get("LLM_MODEL", "mock-model"),
    }
    if overrides:
        config.update(overrides)
    return config
|
|
||||||
|
|||||||
@@ -3,23 +3,20 @@ from agent_core.results import AgentResult
|
|||||||
from .models import AgentAuditLog
|
from .models import AgentAuditLog
|
||||||
|
|
||||||
|
|
||||||
def _mask_sensitive_text(value: str) -> str:
|
|
||||||
masked = value
|
|
||||||
for marker in ("LLM_API_KEY=", "EMBEDDING_API_KEY="):
|
|
||||||
if marker in masked:
|
|
||||||
prefix, _, suffix = masked.partition(marker)
|
|
||||||
secret, separator, rest = suffix.partition(" ")
|
|
||||||
masked_secret = "sk-***" if secret.startswith("sk-") else "***"
|
|
||||||
masked = f"{prefix}{marker}{masked_secret}{separator}{rest}"
|
|
||||||
return masked
|
|
||||||
|
|
||||||
|
|
||||||
def create_audit_log(
|
def create_audit_log(
|
||||||
scenario_id: str,
|
scenario_id: str,
|
||||||
scenario_name: str,
|
scenario_name: str,
|
||||||
user_input: str,
|
user_input: str,
|
||||||
agent_result: AgentResult,
|
agent_result: AgentResult,
|
||||||
) -> AgentAuditLog:
|
) -> AgentAuditLog:
|
||||||
|
"""
|
||||||
|
将一次 Agent 执行结果落库为审计日志。
|
||||||
|
|
||||||
|
设计原则:
|
||||||
|
- 成功与失败都必须记录,方便复盘整条执行链路
|
||||||
|
- 敏感信息在写库前先脱敏,避免误存 API Key
|
||||||
|
- 对前端和 Django Model 统一输出稳定字段
|
||||||
|
"""
|
||||||
return AgentAuditLog.objects.create(
|
return AgentAuditLog.objects.create(
|
||||||
scenario_id=scenario_id,
|
scenario_id=scenario_id,
|
||||||
scenario_name=scenario_name,
|
scenario_name=scenario_name,
|
||||||
@@ -32,5 +29,29 @@ def create_audit_log(
|
|||||||
model_name=agent_result.model_name,
|
model_name=agent_result.model_name,
|
||||||
latency_ms=max(agent_result.latency_ms, 0),
|
latency_ms=max(agent_result.latency_ms, 0),
|
||||||
status=agent_result.status,
|
status=agent_result.status,
|
||||||
error_message=_mask_sensitive_text(agent_result.error),
|
error_message=mask_sensitive_text(agent_result.error),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_sensitive_text(value: str) -> str:
    """
    Redact sensitive configuration values found in an error text.

    Currently covers at least:
    - ``LLM_API_KEY=...``
    - ``EMBEDDING_API_KEY=...``

    Each marker is handed to ``_mask_token_after_marker`` in turn, and the
    result of one pass is the input of the next.
    """
    sensitive_markers = ("LLM_API_KEY=", "EMBEDDING_API_KEY=")
    redacted = value
    for sensitive_marker in sensitive_markers:
        redacted = _mask_token_after_marker(redacted, sensitive_marker)
    return redacted
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_token_after_marker(value: str, marker: str) -> str:
|
||||||
|
"""将 marker 后紧跟的 token 替换为脱敏占位符。"""
|
||||||
|
if marker not in value:
|
||||||
|
return value
|
||||||
|
prefix, _, suffix = value.partition(marker)
|
||||||
|
secret, separator, rest = suffix.partition(" ")
|
||||||
|
masked_secret = "sk-***" if secret.startswith("sk-") else "***"
|
||||||
|
return f"{prefix}{marker}{masked_secret}{separator}{rest}"
|
||||||
|
|||||||
@@ -75,6 +75,19 @@ def test_create_audit_log_masks_api_keys_from_error_message(db):
|
|||||||
assert "sk-***" in log.error_message
|
assert "sk-***" in log.error_message
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_audit_log_masks_embedding_api_keys_from_error_message(db):
    """An embedding API key inside the error text must not reach the stored log."""
    failed_result = AgentResult(
        answer="",
        status="failed",
        error="EMBEDDING_API_KEY=embed-secret 调用失败",
    )

    stored_log = create_audit_log("knowledge_qa", "知识库问答助手", "问题", failed_result)
    stored_message = stored_log.error_message

    assert "embed-secret" not in stored_message
    assert "EMBEDDING_API_KEY=***" in stored_message
|
||||||
|
|
||||||
|
|
||||||
def test_query_demo_records_reads_demo_business_record_table(db):
|
def test_query_demo_records_reads_demo_business_record_table(db):
|
||||||
DemoBusinessRecord.objects.create(
|
DemoBusinessRecord.objects.create(
|
||||||
scenario_id="quality_analysis",
|
scenario_id="quality_analysis",
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from agent_core.llm_provider import (
|
|||||||
LLMConfigurationError,
|
LLMConfigurationError,
|
||||||
create_embedding_provider,
|
create_embedding_provider,
|
||||||
create_llm_provider,
|
create_llm_provider,
|
||||||
|
get_runtime_llm_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -109,3 +110,17 @@ def test_embedding_provider_requires_api_key():
|
|||||||
assert "EMBEDDING_API_KEY" in str(exc)
|
assert "EMBEDDING_API_KEY" in str(exc)
|
||||||
else:
|
else:
|
||||||
raise AssertionError("expected EmbeddingConfigurationError")
|
raise AssertionError("expected EmbeddingConfigurationError")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_runtime_llm_config_uses_environment_and_overrides(monkeypatch):
    """Values come from the environment, and explicit overrides take precedence."""
    env_values = {
        "LLM_PROVIDER": "mock",
        "LLM_API_KEY": "sk-env",
        "LLM_BASE_URL": "https://env.example/v1",
        "LLM_MODEL": "env-model",
    }
    for env_name, env_value in env_values.items():
        monkeypatch.setenv(env_name, env_value)

    config = get_runtime_llm_config({"LLM_MODEL": "override-model"})

    expected = {**env_values, "LLM_MODEL": "override-model"}
    for key, expected_value in expected.items():
        assert config[key] == expected_value
|
||||||
|
|||||||
Reference in New Issue
Block a user