Files
DEMO-AGENT/tests/test_agent_core.py

224 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from agent_core.orchestrator import build_messages, run_agent
from agent_core.rag.ingest import ingest_document
from agent_core.rag.retriever import retrieve
def test_run_agent_returns_structured_result_from_llm_output(tmp_path):
    """run_agent should lift a JSON LLM reply into the structured result.

    The fake provider returns a well-formed JSON payload. The test checks:
    - the answer and confidence are copied into the result;
    - the configured output type is reported;
    - the configured tool was invoked;
    - the provider's model name is propagated.

    The scenario enables RAG, so the test hands run_agent an isolated,
    pre-seeded store via the ``rag_store_path`` option. Without it, retrieval
    would presumably read run_agent's default store (TODO confirm), which
    couples the test to shared on-disk state.
    """
    # Seed a throwaway store so retrieval is deterministic and isolated.
    store_path = tmp_path / "rag_store.json"
    ingest_document(
        scenario_id="knowledge_qa",
        source_file="sop.md",
        text="异常处理 SOP先隔离现场再通知负责人。",
        collection="knowledge_qa",
        store_path=store_path,
    )

    scenario = {
        "id": "knowledge_qa",
        "name": "知识库问答助手",
        "agent": {
            "role": "知识库助手",
            "goal": "基于资料回答问题",
            "instructions": ["仅根据证据回答"],
        },
        "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        "tools": ["generate_action_items"],
        "output": {"type": "general_answer"},
    }
    provider_response = """
{
    "answer": "请先隔离异常现场,再通知负责人。",
    "confidence": "high",
    "references": [
        {"source": "sop.md", "excerpt": "异常处理 SOP先隔离现场"}
    ]
}
"""

    class FakeProvider:
        """Stub LLM provider that always returns ``provider_response``."""

        def generate(self, messages, response_format=None):
            # Return the canned JSON reply regardless of the prompt.
            from agent_core.llm_provider import LLMResponse

            return LLMResponse(
                content=provider_response,
                model_name="demo-model",
                success=True,
            )

    result = run_agent(
        scenario,
        "如何处理异常?",
        options={"llm_provider": FakeProvider(), "rag_store_path": store_path},
    )

    assert result.status == "success"
    assert result.answer == "请先隔离异常现场,再通知负责人。"
    assert result.structured_output["output_type"] == "general_answer"
    assert result.structured_output["confidence"] == "high"
    assert isinstance(result.references, list)
    assert result.tool_calls[0]["tool_name"] == "generate_action_items"
    assert result.model_name == "demo-model"
def test_run_agent_falls_back_when_llm_returns_non_json():
    """A plain-text (non-JSON) LLM reply must still produce a successful result.

    run_agent is expected to fall back to using the raw reply text as both the
    answer and the report summary, and to mark this with
    ``parse_mode == "fallback"``.
    """
    plain_reply = "这是非 JSON 的普通回答"

    class FakeProvider:
        """Stub LLM provider that always answers with ``plain_reply``."""

        def generate(self, messages, response_format=None):
            from agent_core.llm_provider import LLMResponse

            return LLMResponse(
                content=plain_reply,
                model_name="demo-model",
                success=True,
            )

    scenario = {
        "id": "document_review",
        "name": "文档审核助手",
        "agent": {
            "role": "审核助手",
            "goal": "总结审核意见",
            "instructions": ["输出重点问题"],
        },
        "rag": {"enabled": False},
        "tools": [],
        "output": {"type": "document_review_report"},
    }

    result = run_agent(
        scenario,
        "请检查合同风险",
        options={"llm_provider": FakeProvider()},
    )

    assert result.status == "success"
    assert result.answer == plain_reply

    structured = result.structured_output
    assert structured["output_type"] == "document_review_report"
    assert structured["summary"] == plain_reply
    assert structured["parse_mode"] == "fallback"
def test_build_messages_contains_role_goal_references_and_tool_results():
    """build_messages should spread the prompt context across three messages.

    - messages[0] is the system message and mentions the agent role, the goal
      and the output type;
    - messages[1] includes the retrieved reference text and the tool result;
    - messages[2] includes the user's request.
    """
    record_lookup = {
        "tool_name": "query_demo_records",
        "success": True,
        "result": {"records": [{"title": "A线缺陷"}]},
        "error": "",
    }
    scenario = {
        "name": "质量异常分析助手",
        "agent": {
            "role": "质量管理专家",
            "goal": "生成结构化质量分析报告",
            "instructions": ["必须引用知识库", "缺失信息要说明"],
        },
        "output": {"type": "quality_report"},
    }

    messages = build_messages(
        scenario_config=scenario,
        user_input="分析 A 线异常",
        references=[{"source": "sop.md", "content": "先隔离现场"}],
        tool_calls=[record_lookup],
    )

    system_message = messages[0]
    assert system_message["role"] == "system"
    for expected in ("质量管理专家", "生成结构化质量分析报告", "quality_report"):
        assert expected in system_message["content"]

    context_text = messages[1]["content"]
    for expected in ("先隔离现场", "A线缺陷"):
        assert expected in context_text

    assert "分析 A 线异常" in messages[2]["content"]
def test_rag_ingest_and_retrieve_filters_by_scenario_and_query(tmp_path):
    """Retrieval should return only matching chunks of the requested scenario.

    Two documents from different scenarios (and collections) are ingested
    into one temporary store. A quality_analysis query must:
    - return a non-empty result;
    - have quality.md first;
    - contain only chunks tagged with the quality_analysis scenario.
    """
    store_path = tmp_path / "rag_store.json"
    quality_scenario = "quality_analysis"

    ingest_result = ingest_document(
        scenario_id=quality_scenario,
        source_file="quality.md",
        text="设备点检需要先断电挂牌。质量异常需要记录批次、工位和缺陷现象。",
        collection=quality_scenario,
        store_path=store_path,
    )
    # A document from another scenario that must not show up in the results.
    ingest_document(
        scenario_id="risk_audit",
        source_file="risk.md",
        text="报销审核需要检查发票、金额和审批链。",
        collection="risk_audit",
        store_path=store_path,
    )

    chunks = retrieve(
        scenario_id=quality_scenario,
        query="质量异常批次",
        collection=quality_scenario,
        top_k=3,
        store_path=store_path,
    )

    assert ingest_result.success is True
    assert ingest_result.chunks_count >= 1
    assert chunks

    first_chunk = chunks[0]
    assert first_chunk["source"] == "quality.md"
    assert "质量异常" in first_chunk["content"]
    assert all(chunk["scenario_id"] == quality_scenario for chunk in chunks)
def test_rag_reingest_replaces_same_document_and_retrieve_filters_document_ids(tmp_path):
    """Re-ingesting a document_id replaces its chunks; document_ids filters retrieval.

    Document 1 is ingested twice (old.md, then new.md) and document 2 once.
    Retrieving with ``document_ids=[1]`` must return only document 1's chunks,
    all sourced from new.md and free of the superseded text.
    """
    store_path = tmp_path / "rag_store.json"

    # (document_id, source_file, text), ingested in this order.
    uploads = [
        (1, "old.md", "旧制度要求人工登记。"),
        (1, "new.md", "新制度要求系统自动登记。"),
        (2, "other.md", "系统自动登记后需要生成审计记录。"),
    ]
    for doc_id, source_name, body in uploads:
        ingest_document(
            document_id=doc_id,
            scenario_id="knowledge_qa",
            source_file=source_name,
            text=body,
            collection="knowledge_qa",
            store_path=store_path,
        )

    chunks = retrieve(
        scenario_id="knowledge_qa",
        query="系统自动登记",
        collection="knowledge_qa",
        top_k=5,
        document_ids=[1],
        store_path=store_path,
    )

    assert chunks
    assert {chunk["document_id"] for chunk in chunks} == {1}
    for chunk in chunks:
        assert chunk["source"] == "new.md"
        assert "旧制度" not in chunk["content"]
def test_run_agent_uses_retrieved_document_chunks(tmp_path):
    """run_agent should surface chunks retrieved from the configured RAG store.

    A single SOP document is ingested into a temporary store. That store is
    passed to run_agent through the ``rag_store_path`` option, and the first
    reference in the result must come from that document.
    """
    store_path = tmp_path / "rag_store.json"
    ingest_document(
        scenario_id="knowledge_qa",
        source_file="sop.md",
        text="异常处理 SOP先隔离现场再通知负责人。",
        collection="knowledge_qa",
        store_path=store_path,
    )
    scenario = {
        "id": "knowledge_qa",
        "name": "知识库问答助手",
        "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        "tools": [],
        "output": {"type": "general_answer"},
    }

    # NOTE(review): no "llm_provider" option is passed, so run_agent falls back
    # to its default provider here. Confirm that default needs no network
    # access or credentials when tests run in CI.
    result = run_agent(scenario, "异常处理怎么做?", options={"rag_store_path": store_path})

    # Guard before indexing so an empty retrieval fails with a clear message
    # rather than an IndexError.
    assert result.references, "expected run_agent to return retrieved references"
    first_reference = result.references[0]
    assert first_reference["source"] == "sop.md"
    assert "隔离现场" in first_reference["content"]