feat(agent-core): 增加智能编排与模型工具基础

This commit is contained in:
2026-05-30 00:08:27 +08:00
parent 35b80929b0
commit 7a6c110103
16 changed files with 806 additions and 0 deletions

122
tests/test_agent_core.py Normal file
View File

@@ -0,0 +1,122 @@
from agent_core.orchestrator import run_agent
from agent_core.rag.ingest import ingest_document
from agent_core.rag.retriever import retrieve
def test_run_agent_returns_structured_mock_result():
scenario = {
"id": "knowledge_qa",
"name": "知识库问答助手",
"rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
"tools": ["generate_action_items"],
"output": {"type": "general_answer"},
}
result = run_agent(scenario, "如何处理异常?")
assert result.status == "success"
assert result.answer
assert result.structured_output["output_type"] == "general_answer"
assert isinstance(result.references, list)
assert result.tool_calls[0]["tool_name"] == "generate_action_items"
def test_rag_ingest_and_retrieve_filters_by_scenario_and_query(tmp_path):
store_path = tmp_path / "rag_store.json"
text = "设备点检需要先断电挂牌。质量异常需要记录批次、工位和缺陷现象。"
result = ingest_document(
scenario_id="quality_analysis",
source_file="quality.md",
text=text,
collection="quality_analysis",
store_path=store_path,
)
ingest_document(
scenario_id="risk_audit",
source_file="risk.md",
text="报销审核需要检查发票、金额和审批链。",
collection="risk_audit",
store_path=store_path,
)
chunks = retrieve(
scenario_id="quality_analysis",
query="质量异常批次",
collection="quality_analysis",
top_k=3,
store_path=store_path,
)
assert result.success is True
assert result.chunks_count >= 1
assert chunks
assert chunks[0]["source"] == "quality.md"
assert "质量异常" in chunks[0]["content"]
assert all(chunk["scenario_id"] == "quality_analysis" for chunk in chunks)
def test_rag_reingest_replaces_same_document_and_retrieve_filters_document_ids(tmp_path):
store_path = tmp_path / "rag_store.json"
ingest_document(
document_id=1,
scenario_id="knowledge_qa",
source_file="old.md",
text="旧制度要求人工登记。",
collection="knowledge_qa",
store_path=store_path,
)
ingest_document(
document_id=1,
scenario_id="knowledge_qa",
source_file="new.md",
text="新制度要求系统自动登记。",
collection="knowledge_qa",
store_path=store_path,
)
ingest_document(
document_id=2,
scenario_id="knowledge_qa",
source_file="other.md",
text="系统自动登记后需要生成审计记录。",
collection="knowledge_qa",
store_path=store_path,
)
chunks = retrieve(
scenario_id="knowledge_qa",
query="系统自动登记",
collection="knowledge_qa",
top_k=5,
document_ids=[1],
store_path=store_path,
)
assert chunks
assert {chunk["document_id"] for chunk in chunks} == {1}
assert all(chunk["source"] == "new.md" for chunk in chunks)
assert all("旧制度" not in chunk["content"] for chunk in chunks)
def test_run_agent_uses_retrieved_document_chunks(tmp_path):
store_path = tmp_path / "rag_store.json"
ingest_document(
scenario_id="knowledge_qa",
source_file="sop.md",
text="异常处理 SOP先隔离现场再通知负责人。",
collection="knowledge_qa",
store_path=store_path,
)
scenario = {
"id": "knowledge_qa",
"name": "知识库问答助手",
"rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
"tools": [],
"output": {"type": "general_answer"},
}
result = run_agent(scenario, "异常处理怎么做?", options={"rag_store_path": store_path})
assert result.references[0]["source"] == "sop.md"
assert "隔离现场" in result.references[0]["content"]

111
tests/test_llm_provider.py Normal file
View File

@@ -0,0 +1,111 @@
from agent_core.llm_provider import (
EmbeddingConfigurationError,
LLMConfigurationError,
create_embedding_provider,
create_llm_provider,
)
def test_create_llm_provider_requires_api_key_for_openai_compatible():
provider = create_llm_provider(
{
"LLM_API_KEY": "",
"LLM_BASE_URL": "https://api.openai.com/v1",
"LLM_MODEL": "gpt-4.1-mini",
"LLM_PROVIDER": "openai_compatible",
}
)
response = provider.generate([{"role": "user", "content": "hello"}])
assert response.success is False
assert isinstance(response.error, LLMConfigurationError)
assert "LLM_API_KEY" in str(response.error)
def test_mock_provider_returns_deterministic_content():
provider = create_llm_provider({"LLM_PROVIDER": "mock", "LLM_MODEL": "demo-model"})
response = provider.generate([{"role": "user", "content": "hello"}])
assert response.success is True
assert response.model_name == "demo-model"
assert "hello" in response.content
def test_openai_compatible_provider_posts_chat_completion(monkeypatch):
captured = {}
class FakeResponse:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, traceback):
return False
def read(self):
return b'{"choices":[{"message":{"content":"ok"}}],"model":"demo-model"}'
def fake_urlopen(request, timeout):
captured["url"] = request.full_url
captured["headers"] = dict(request.header_items())
captured["body"] = request.data.decode("utf-8")
return FakeResponse()
monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen)
provider = create_llm_provider(
{
"LLM_PROVIDER": "openai_compatible",
"LLM_API_KEY": "sk-test",
"LLM_BASE_URL": "https://api.siliconflow.cn/v1",
"LLM_MODEL": "demo-model",
}
)
response = provider.generate([{"role": "user", "content": "hello"}])
assert response.success is True
assert response.content == "ok"
assert captured["url"] == "https://api.siliconflow.cn/v1/chat/completions"
assert '"model": "demo-model"' in captured["body"]
assert captured["headers"]["Authorization"] == "Bearer sk-test"
def test_embedding_provider_uses_openai_compatible_embeddings(monkeypatch):
class FakeResponse:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, traceback):
return False
def read(self):
return b'{"data":[{"embedding":[0.1,0.2]},{"embedding":[0.3,0.4]}]}'
monkeypatch.setattr("agent_core.llm_provider.urlopen", lambda request, timeout: FakeResponse())
provider = create_embedding_provider(
{
"EMBEDDING_API_KEY": "sk-test",
"EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1",
"EMBEDDING_MODEL": "demo-embedding",
}
)
assert provider.embed_texts(["a", "b"]) == [[0.1, 0.2], [0.3, 0.4]]
def test_embedding_provider_requires_api_key():
provider = create_embedding_provider(
{
"EMBEDDING_API_KEY": "",
"EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1",
"EMBEDDING_MODEL": "demo-embedding",
}
)
try:
provider.embed_texts(["a"])
except EmbeddingConfigurationError as exc:
assert "EMBEDDING_API_KEY" in str(exc)
else:
raise AssertionError("expected EmbeddingConfigurationError")