# Listing metadata: 308 lines, 10 KiB, Python
from agent_core.llm_provider import LLMResponse
from agent_core.orchestrator import build_messages, run_agent
from agent_core.rag.ingest import _split_text, ingest_document
from agent_core.rag.retriever import retrieve
from agent_core.schemas.outputs import SUPPORTED_OUTPUT_TYPES


def test_run_agent_returns_structured_result_from_llm_output():
    """Check that ``run_agent`` exposes a JSON provider reply as structured output.

    A stub provider returns a JSON document carrying ``answer``,
    ``confidence`` and ``references``. The test asserts that those values,
    the scenario's output type, the configured tool call and the provider's
    model name are all present on the returned result.
    """
    scenario = {
        "id": "knowledge_qa",
        "name": "知识库问答助手",
        "agent": {
            "role": "知识库助手",
            "goal": "基于资料回答问题",
            "instructions": ["仅根据证据回答"],
        },
        "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        "tools": ["generate_action_items"],
        "output": {"type": "general_answer"},
    }
    provider_response = """
    {
        "answer": "请先隔离异常现场,再通知负责人。",
        "confidence": "high",
        "references": [
            {"source": "sop.md", "excerpt": "异常处理 SOP:先隔离现场"}
        ]
    }
    """

    class FakeProvider:
        """Stub LLM provider that always replies with ``provider_response``."""

        def generate(self, messages, response_format=None):
            return LLMResponse(
                content=provider_response,
                model_name="demo-model",
                success=True,
            )

    result = run_agent(
        scenario,
        "如何处理异常?",
        options={"llm_provider": FakeProvider()},
    )

    assert result.status == "success"
    assert result.answer == "请先隔离异常现场,再通知负责人。"
    assert result.structured_output["output_type"] == "general_answer"
    assert result.structured_output["confidence"] == "high"
    assert isinstance(result.references, list)
    # Guard before indexing so a missing tool call fails as a plain assertion
    # instead of an IndexError.
    assert result.tool_calls
    assert result.tool_calls[0]["tool_name"] == "generate_action_items"
    assert result.model_name == "demo-model"


def test_run_agent_falls_back_when_llm_returns_non_json():
    """Check the fallback path when the provider reply is not JSON.

    The stub provider returns plain text. The test asserts that the run is
    still reported as a success, that the raw text is used both as the answer
    and as the structured ``summary``, and that ``parse_mode`` is
    ``"fallback"``.
    """
    scenario = {
        "id": "document_review",
        "name": "文档审核助手",
        "agent": {
            "role": "审核助手",
            "goal": "总结审核意见",
            "instructions": ["输出重点问题"],
        },
        "rag": {"enabled": False},
        "tools": [],
        "output": {"type": "document_review_report"},
    }

    class FakeProvider:
        """Stub LLM provider that replies with non-JSON plain text."""

        def generate(self, messages, response_format=None):
            return LLMResponse(
                content="这是非 JSON 的普通回答",
                model_name="demo-model",
                success=True,
            )

    result = run_agent(
        scenario,
        "请检查合同风险",
        options={"llm_provider": FakeProvider()},
    )

    assert result.status == "success"
    assert result.answer == "这是非 JSON 的普通回答"
    assert result.structured_output["output_type"] == "document_review_report"
    assert result.structured_output["summary"] == "这是非 JSON 的普通回答"
    assert result.structured_output["parse_mode"] == "fallback"


def test_build_messages_contains_role_goal_references_and_tool_results():
    """Check the layout of the prompt produced by ``build_messages``.

    The test asserts that the first message is a system message containing
    the agent role, goal and output type, that the second message contains the
    reference text and the tool result, and that the third message contains
    the user input.
    """
    scenario = {
        "name": "质量异常分析助手",
        "agent": {
            "role": "质量管理专家",
            "goal": "生成结构化质量分析报告",
            "instructions": ["必须引用知识库", "缺失信息要说明"],
        },
        "output": {"type": "quality_report"},
    }

    messages = build_messages(
        scenario_config=scenario,
        user_input="分析 A 线异常",
        references=[{"source": "sop.md", "content": "先隔离现场"}],
        tool_calls=[
            {
                "tool_name": "query_demo_records",
                "success": True,
                "result": {"records": [{"title": "A线缺陷"}]},
                "error": "",
            }
        ],
    )

    # Guard before indexing so a short message list fails as a plain
    # assertion instead of an IndexError.
    assert len(messages) >= 3
    assert messages[0]["role"] == "system"
    assert "质量管理专家" in messages[0]["content"]
    assert "生成结构化质量分析报告" in messages[0]["content"]
    assert "quality_report" in messages[0]["content"]
    assert "先隔离现场" in messages[1]["content"]
    assert "A线缺陷" in messages[1]["content"]
    assert "分析 A 线异常" in messages[2]["content"]


def test_rag_ingest_and_retrieve_filters_by_scenario_and_query(tmp_path):
    """Check that retrieval only returns chunks of the requested scenario.

    Two documents are ingested into the same store file under different
    scenario ids and collections. Retrieving for ``quality_analysis`` must
    yield chunks from ``quality.md`` only, with the query text present in the
    top hit.
    """
    store_path = tmp_path / "rag_store.json"
    text = "设备点检需要先断电挂牌。质量异常需要记录批次、工位和缺陷现象。"

    result = ingest_document(
        scenario_id="quality_analysis",
        source_file="quality.md",
        text=text,
        collection="quality_analysis",
        store_path=store_path,
    )
    # Second document belongs to another scenario and must not be retrieved.
    ingest_document(
        scenario_id="risk_audit",
        source_file="risk.md",
        text="报销审核需要检查发票、金额和审批链。",
        collection="risk_audit",
        store_path=store_path,
    )

    chunks = retrieve(
        scenario_id="quality_analysis",
        query="质量异常批次",
        collection="quality_analysis",
        top_k=3,
        store_path=store_path,
    )

    assert result.success is True
    assert result.chunks_count >= 1
    assert chunks
    assert chunks[0]["source"] == "quality.md"
    assert "质量异常" in chunks[0]["content"]
    assert all(chunk["scenario_id"] == "quality_analysis" for chunk in chunks)


def test_rag_reingest_replaces_same_document_and_retrieve_filters_document_ids(tmp_path):
    """Check re-ingestion by ``document_id`` and the ``document_ids`` filter.

    Document 1 is ingested twice with different source files and text, and
    document 2 is ingested once. Retrieving with ``document_ids=[1]`` must
    return only chunks of document 1, all from the second ingestion
    (``new.md``), with none of the first ingestion's text remaining.
    """
    store_path = tmp_path / "rag_store.json"

    ingest_document(
        document_id=1,
        scenario_id="knowledge_qa",
        source_file="old.md",
        text="旧制度要求人工登记。",
        collection="knowledge_qa",
        store_path=store_path,
    )
    # Same document_id again: the assertions below expect this to supersede
    # the earlier content.
    ingest_document(
        document_id=1,
        scenario_id="knowledge_qa",
        source_file="new.md",
        text="新制度要求系统自动登记。",
        collection="knowledge_qa",
        store_path=store_path,
    )
    # Different document that also matches the query; it must be excluded by
    # the document_ids filter.
    ingest_document(
        document_id=2,
        scenario_id="knowledge_qa",
        source_file="other.md",
        text="系统自动登记后需要生成审计记录。",
        collection="knowledge_qa",
        store_path=store_path,
    )

    chunks = retrieve(
        scenario_id="knowledge_qa",
        query="系统自动登记",
        collection="knowledge_qa",
        top_k=5,
        document_ids=[1],
        store_path=store_path,
    )

    assert chunks
    assert {chunk["document_id"] for chunk in chunks} == {1}
    assert all(chunk["source"] == "new.md" for chunk in chunks)
    assert all("旧制度" not in chunk["content"] for chunk in chunks)


def test_run_agent_uses_retrieved_document_chunks(tmp_path):
    """Check that ``run_agent`` returns chunks retrieved from the RAG store.

    A single document is ingested into a temporary store, and the store path
    is handed to ``run_agent`` through ``options["rag_store_path"]``. The
    first reference on the result must come from ``sop.md`` and contain the
    ingested text.
    """
    store_path = tmp_path / "rag_store.json"
    ingest_document(
        scenario_id="knowledge_qa",
        source_file="sop.md",
        text="异常处理 SOP:先隔离现场,再通知负责人。",
        collection="knowledge_qa",
        store_path=store_path,
    )
    scenario = {
        "id": "knowledge_qa",
        "name": "知识库问答助手",
        "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        "tools": [],
        "output": {"type": "general_answer"},
    }

    # NOTE(review): no "llm_provider" option is passed here, so run_agent
    # uses whatever provider it defaults to. Confirm that the default does not
    # need network access or credentials in CI.
    result = run_agent(
        scenario,
        "异常处理怎么做?",
        options={"rag_store_path": store_path},
    )

    # Guard before indexing so an empty reference list fails as a plain
    # assertion instead of an IndexError.
    assert result.references
    assert result.references[0]["source"] == "sop.md"
    assert "隔离现场" in result.references[0]["content"]


def test_rag_split_text_keeps_overlap_and_non_empty_chunks():
    """Check chunk count and lengths produced by ``_split_text``.

    Splitting 20 characters with ``chunk_size=8`` and ``overlap=3`` must give
    three 8-character chunks followed by one 5-character chunk.
    """
    # NOTE(review): the input is one repeated character, so this pins the
    # number and length of chunks but cannot distinguish their exact offsets.
    chunks = _split_text("A" * 20, chunk_size=8, overlap=3)

    assert chunks == ["AAAAAAAA", "AAAAAAAA", "AAAAAAAA", "AAAAA"]


def test_retrieve_returns_empty_when_query_has_no_overlap(tmp_path):
    """Check that ``retrieve`` returns an empty list for an unrelated query.

    The store holds one document about expense approval; the query is about
    equipment inspection, so the expected result is ``[]``.
    """
    store_path = tmp_path / "rag_store.json"
    ingest_document(
        scenario_id="knowledge_qa",
        source_file="rules.md",
        text="这里描述的是报销流程和审批链。",
        collection="knowledge_qa",
        store_path=store_path,
    )

    chunks = retrieve(
        scenario_id="knowledge_qa",
        query="设备点检",
        collection="knowledge_qa",
        top_k=3,
        store_path=store_path,
    )

    assert chunks == []


def test_registration_risk_result_includes_owner_fields_and_notification_payload():
    """Check that a registration risk report populates the notification payload.

    The stub provider returns a JSON risk report with one entry in
    ``owner_roles`` and a ``notify_reason``. The test asserts that the
    structured output carries the scenario's output type and that the owner
    fields and notify reason appear on ``result.notification_payload``.
    """
    scenario = {
        "id": "document_review",
        "name": "注册审核智能体",
        "agent": {
            "role": "注册审核助手",
            "goal": "输出风险结果",
            "instructions": ["输出结构化风险结果"],
        },
        "rag": {"enabled": False},
        "tools": [],
        "output": {"type": "registration_risk_report"},
    }
    provider_response = """
    {
        "summary": "存在高风险项,需人工复核。",
        "highest_risk_level": "high",
        "pass_status": "blocked",
        "owner_roles": [
            {
                "owner_role": "注册资料负责人",
                "owner_name": "张三",
                "department": "注册事务部",
                "chapter_scope": "CH1",
                "risk_scope": "字段冲突",
                "feishu_user_id": "ou_demo_1",
                "feishu_open_id": "on_demo_1",
                "feishu_name": "张三",
                "notify_enabled": true
            }
        ],
        "notify_reason": "task_completed"
    }
    """

    class FakeProvider:
        """Stub LLM provider that always replies with ``provider_response``."""

        def generate(self, messages, response_format=None):
            return LLMResponse(
                content=provider_response,
                model_name="demo-model",
                success=True,
            )

    result = run_agent(
        scenario,
        "请输出风险结果",
        options={"llm_provider": FakeProvider()},
    )

    owners = result.notification_payload["owners"]
    # Guard before indexing so an empty owner list fails as a plain assertion
    # instead of an IndexError.
    assert owners
    owner = owners[0]
    assert result.structured_output["output_type"] == "registration_risk_report"
    assert owner["owner_role"] == "注册资料负责人"
    assert owner["feishu_user_id"] == "ou_demo_1"
    assert owner["feishu_open_id"] == "on_demo_1"
    assert result.notification_payload["notify_reason"] == "task_completed"


def test_supported_output_types_include_word_export_and_feishu_notification():
    """Check that both report types are present in ``SUPPORTED_OUTPUT_TYPES``."""
    assert "registration_word_export_report" in SUPPORTED_OUTPUT_TYPES
    assert "feishu_notification_report" in SUPPORTED_OUTPUT_TYPES