Files
DEMO-AGENT/tests/test_agent_core.py

123 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from agent_core.orchestrator import run_agent
from agent_core.rag.ingest import ingest_document
from agent_core.rag.retriever import retrieve
def test_run_agent_returns_structured_mock_result():
    """Run the agent on a RAG-enabled scenario and check the result's shape.

    The scenario declares one tool (``generate_action_items``) and a
    ``general_answer`` output type; the assertions verify that both are
    reflected in the returned result.
    """
    scenario = {
        "id": "knowledge_qa",
        "name": "知识库问答助手",
        "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        "tools": ["generate_action_items"],
        "output": {"type": "general_answer"},
    }

    result = run_agent(scenario, "如何处理异常?")

    assert result.status == "success"
    assert result.answer
    assert result.structured_output["output_type"] == "general_answer"
    assert isinstance(result.references, list)
    # Guard before indexing so an empty list fails as a readable assertion
    # instead of an IndexError.
    assert result.tool_calls
    assert result.tool_calls[0]["tool_name"] == "generate_action_items"
def test_rag_ingest_and_retrieve_filters_by_scenario_and_query(tmp_path):
    """Ingest documents for two scenarios and query only one of them.

    Both documents are written to the same store file under ``tmp_path``;
    the retrieval is issued for the ``quality_analysis`` scenario and every
    returned chunk is expected to belong to it.
    """
    rag_store = tmp_path / "rag_store.json"
    quality_text = "设备点检需要先断电挂牌。质量异常需要记录批次、工位和缺陷现象。"

    quality_ingest = ingest_document(
        scenario_id="quality_analysis",
        source_file="quality.md",
        text=quality_text,
        collection="quality_analysis",
        store_path=rag_store,
    )
    # Second document under a different scenario, sharing the same store file.
    ingest_document(
        scenario_id="risk_audit",
        source_file="risk.md",
        text="报销审核需要检查发票、金额和审批链。",
        collection="risk_audit",
        store_path=rag_store,
    )

    hits = retrieve(
        scenario_id="quality_analysis",
        query="质量异常批次",
        collection="quality_analysis",
        top_k=3,
        store_path=rag_store,
    )

    assert quality_ingest.success is True
    assert quality_ingest.chunks_count >= 1
    assert hits
    top_hit = hits[0]
    assert top_hit["source"] == "quality.md"
    assert "质量异常" in top_hit["content"]
    for hit in hits:
        assert hit["scenario_id"] == "quality_analysis"
def test_rag_reingest_replaces_same_document_and_retrieve_filters_document_ids(tmp_path):
    """Re-ingest a document under the same id and retrieve restricted to that id.

    Document id 1 is ingested twice (``old.md`` then ``new.md``) and id 2
    once; the retrieval passes ``document_ids=[1]`` and the assertions expect
    only chunks from ``new.md``.
    """
    rag_store = tmp_path / "rag_store.json"

    # (document_id, source_file, text) — ingested in exactly this order.
    documents = (
        (1, "old.md", "旧制度要求人工登记。"),
        (1, "new.md", "新制度要求系统自动登记。"),
        (2, "other.md", "系统自动登记后需要生成审计记录。"),
    )
    for doc_id, source_name, body in documents:
        ingest_document(
            document_id=doc_id,
            scenario_id="knowledge_qa",
            source_file=source_name,
            text=body,
            collection="knowledge_qa",
            store_path=rag_store,
        )

    hits = retrieve(
        scenario_id="knowledge_qa",
        query="系统自动登记",
        collection="knowledge_qa",
        top_k=5,
        document_ids=[1],
        store_path=rag_store,
    )

    assert hits
    assert {hit["document_id"] for hit in hits} == {1}
    # hits is non-empty here, so the set comparison matches an all() check.
    assert {hit["source"] for hit in hits} == {"new.md"}
    assert not any("旧制度" in hit["content"] for hit in hits)
def test_run_agent_uses_retrieved_document_chunks(tmp_path):
    """Ingest one document, run the agent against that store, and check references.

    The store path is handed to ``run_agent`` through the ``rag_store_path``
    option, and the first reference is expected to come from the ingested
    ``sop.md`` document.
    """
    store_path = tmp_path / "rag_store.json"
    ingest_document(
        scenario_id="knowledge_qa",
        source_file="sop.md",
        text="异常处理 SOP先隔离现场再通知负责人。",
        collection="knowledge_qa",
        store_path=store_path,
    )
    scenario = {
        "id": "knowledge_qa",
        "name": "知识库问答助手",
        "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        "tools": [],
        "output": {"type": "general_answer"},
    }

    result = run_agent(scenario, "异常处理怎么做?", options={"rag_store_path": store_path})

    # Guard before indexing so an empty reference list fails as a readable
    # assertion instead of an IndexError.
    assert result.references
    first_reference = result.references[0]
    assert first_reference["source"] == "sop.md"
    assert "隔离现场" in first_reference["content"]