253 lines
9.9 KiB
Python
253 lines
9.9 KiB
Python
import json
|
|
import time
|
|
|
|
from .governance import load_governance_config
|
|
from .llm_provider import create_llm_provider, get_runtime_llm_config
|
|
from .results import AgentResult
|
|
from .structured_output import (
|
|
build_response_schema_hint,
|
|
extract_answer_from_structured_output,
|
|
parse_structured_output,
|
|
)
|
|
from .tool_registry import run_declared_tools
|
|
from .rag.retriever import retrieve
|
|
|
|
|
|
def run_agent(scenario_config: dict, user_input: str, options: dict | None = None) -> AgentResult:
    """Run the minimal agent loop for the current scenario.

    The processing order follows the design document:
    1. read the scenario config
    2. run RAG retrieval
    3. run the declared tools
    4. build the prompt and call the LLM
    5. parse the structured result
    6. return a unified AgentResult

    Args:
        scenario_config: scenario definition; the keys read here are ``output.type``
            and ``tools`` (the helpers read further keys).
        user_input: raw user query.
        options: optional runtime overrides. Keys read here: ``llm_provider``,
            ``llm_config``, ``conversation_id``, ``batch_id``, ``product_name``;
            ``_collect_references`` also reads ``document_ids`` and ``rag_store_path``.

    Returns:
        An ``AgentResult`` with ``status="failed"`` when the provider reports
        ``success`` as false, otherwise ``status="success"``.
    """
    started_at = time.perf_counter()
    options = options or {}
    output_type = scenario_config.get("output", {}).get("type", "general_answer")

    references = _collect_references(scenario_config=scenario_config, user_input=user_input, options=options)
    tool_calls = run_declared_tools(scenario_config.get("tools", []), user_input)
    messages = build_messages(
        scenario_config=scenario_config,
        user_input=user_input,
        references=references,
        tool_calls=tool_calls,
    )

    # An injected provider takes precedence; otherwise build one from runtime config.
    provider = options.get("llm_provider") or create_llm_provider(get_runtime_llm_config(options.get("llm_config")))
    llm_response = provider.generate(messages, response_format=build_response_schema_hint(output_type))
    # Latency covers retrieval, tools and the model call, i.e. everything above.
    latency_ms = int((time.perf_counter() - started_at) * 1000)

    # Fields that are identical in the failure and success results.
    shared_fields = {
        "references": references,
        "tool_calls": tool_calls,
        "model_name": llm_response.model_name or "unknown-model",
        "latency_ms": latency_ms,
        "conversation_id": str(options.get("conversation_id", "")),
        "batch_id": str(options.get("batch_id", "")),
        "product_name": str(options.get("product_name", "")),
    }

    if not llm_response.success:
        failure_payload = _build_notification_payload(
            {"notify_reason": "task_failed", "owner_roles": []},
            options=options,
            status="failed",
        )
        return AgentResult(
            answer="模型调用失败,请检查配置或稍后重试。",
            structured_output={},
            raw_output="",
            status="failed",
            error=str(llm_response.error or "未知模型错误"),
            notification_payload=failure_payload,
            **shared_fields,
        )

    raw_text = llm_response.content
    structured_output, _ = parse_structured_output(raw_text, output_type)
    final_answer = extract_answer_from_structured_output(structured_output, raw_text)
    node_results = _build_node_results(output_type, structured_output)
    success_payload = _build_notification_payload(structured_output, options=options, status="success")
    return AgentResult(
        answer=final_answer,
        structured_output=structured_output,
        raw_output=raw_text,
        status="success",
        node_results=node_results,
        notification_payload=success_payload,
        **shared_fields,
    )
|
|
|
|
|
|
def build_messages(
    scenario_config: dict,
    user_input: str,
    references: list[dict],
    tool_calls: list[dict],
) -> list[dict]:
    """Combine scenario config, retrieval hits and tool results into a minimal, explainable prompt.

    Returns three chat messages, in order:
    a system message (role, goal, instructions, output type), a context message
    (scenario name, formatted references, formatted tool results), and the user input.
    """
    agent_settings = scenario_config.get("agent", {})
    persona = agent_settings.get("role", "通用业务助手")
    objective = agent_settings.get("goal", "根据输入生成结构化结果")
    numbered_rules = _format_instructions(agent_settings.get("instructions", []))
    expected_output = scenario_config.get("output", {}).get("type", "general_answer")

    system_parts = [
        f"你当前扮演的角色:{persona}",
        f"当前任务目标:{objective}",
        "执行要求:",
        numbered_rules,
        f"输出类型:{expected_output}",
        "请优先输出 JSON 对象,字段必须贴近约定输出结构。",
    ]
    context_parts = [
        f"当前场景:{scenario_config.get('name', '未命名场景')}",
        _format_references(references),
        _format_tool_calls(tool_calls),
    ]

    # NOTE(review): retrieval/tool context is sent as an "assistant" turn placed before
    # the first user turn. Some chat templates require strict user/assistant alternation
    # starting with "user" — confirm the configured provider accepts this ordering.
    return [
        {"role": "system", "content": "\n".join(system_parts)},
        {"role": "assistant", "content": "\n".join(context_parts)},
        {"role": "user", "content": user_input},
    ]
|
|
|
|
|
|
def _collect_references(scenario_config: dict, user_input: str, options: dict) -> list[dict]:
|
|
"""按场景配置执行检索,并保持无 RAG 场景也能正常返回空列表。"""
|
|
rag_config = scenario_config.get("rag", {})
|
|
if not rag_config.get("enabled"):
|
|
return []
|
|
return retrieve(
|
|
scenario_id=scenario_config.get("id", ""),
|
|
query=user_input,
|
|
collection=rag_config.get("collection", scenario_config.get("id", "")),
|
|
top_k=rag_config.get("top_k", 5),
|
|
document_ids=options.get("document_ids"),
|
|
store_path=options.get("rag_store_path"),
|
|
)
|
|
|
|
|
|
def _format_instructions(instructions: list[str]) -> str:
|
|
if not instructions:
|
|
return "1. 结合知识库和工具结果回答。\n2. 信息不足时明确说明。"
|
|
return "\n".join(f"{index}. {item}" for index, item in enumerate(instructions, start=1))
|
|
|
|
|
|
def _format_references(references: list[dict]) -> str:
|
|
if not references:
|
|
return "知识库引用:当前没有检索到可用片段。"
|
|
lines = ["知识库引用:"]
|
|
for index, reference in enumerate(references, start=1):
|
|
lines.append(
|
|
f"{index}. 来源={reference.get('source', '未知来源')} 内容={reference.get('content', '')}"
|
|
)
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _format_tool_calls(tool_calls: list[dict]) -> str:
|
|
if not tool_calls:
|
|
return "工具结果:当前场景未声明工具或无需调用工具。"
|
|
lines = ["工具结果:"]
|
|
for index, tool_call in enumerate(tool_calls, start=1):
|
|
if tool_call.get("success"):
|
|
lines.append(
|
|
f"{index}. 工具={tool_call.get('tool_name')} 结果={json.dumps(tool_call.get('result', {}), ensure_ascii=False)}"
|
|
)
|
|
else:
|
|
lines.append(
|
|
f"{index}. 工具={tool_call.get('tool_name')} 失败={tool_call.get('error', '未知错误')}"
|
|
)
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _build_node_results(output_type: str, structured_output: dict) -> list[dict]:
|
|
if output_type.startswith("registration_") or output_type == "feishu_notification_report":
|
|
return _build_registration_node_results(output_type, structured_output)
|
|
return [
|
|
{
|
|
"code": output_type,
|
|
"label": output_type,
|
|
"status": "已完成",
|
|
"summary": structured_output.get("summary") or structured_output.get("answer", ""),
|
|
}
|
|
]
|
|
|
|
|
|
def _build_notification_payload(structured_output: dict, options: dict, status: str) -> dict:
|
|
notify_reason = structured_output.get("notify_reason") or (
|
|
"task_completed" if status == "success" else "task_failed"
|
|
)
|
|
owners = structured_output.get("owner_roles") or []
|
|
if not owners:
|
|
owners = load_governance_config()["owner_mappings"]
|
|
return {
|
|
"batch_id": str(options.get("batch_id", "")),
|
|
"conversation_id": str(options.get("conversation_id", "")),
|
|
"product_name": str(options.get("product_name", "")),
|
|
"notify_reason": notify_reason,
|
|
"owners": owners,
|
|
"status": status,
|
|
}
|
|
|
|
|
|
def _build_registration_node_results(output_type: str, structured_output: dict) -> list[dict]:
|
|
nodes = [
|
|
{"code": "package_import", "label": "资料包导入", "status": "已完成"},
|
|
{"code": "overview", "label": "目录汇总", "status": "待处理"},
|
|
{"code": "completeness", "label": "法规完整性检查", "status": "待处理"},
|
|
{"code": "field_extraction", "label": "字段抽取", "status": "待处理"},
|
|
{"code": "consistency", "label": "一致性核查", "status": "待处理"},
|
|
{"code": "risk", "label": "风险预警", "status": "待处理"},
|
|
{"code": "word_export", "label": "Word 回填导出", "status": "待处理"},
|
|
{"code": "feishu_notify", "label": "飞书通知", "status": "待处理"},
|
|
]
|
|
progression_map = {
|
|
"registration_overview_report": 1,
|
|
"registration_completeness_report": 2,
|
|
"registration_field_extraction_report": 3,
|
|
"registration_consistency_report": 4,
|
|
"registration_risk_report": 5,
|
|
"registration_word_export_report": 6,
|
|
"feishu_notification_report": 7,
|
|
}
|
|
completed_index = progression_map.get(output_type, 0)
|
|
for index in range(1, completed_index + 1):
|
|
nodes[index]["status"] = "已完成"
|
|
|
|
if output_type == "registration_risk_report":
|
|
pass_status = structured_output.get("pass_status", "")
|
|
if pass_status in {"blocked", "failed"}:
|
|
nodes[5]["status"] = "已阻断"
|
|
elif pass_status in {"review_required", "manual_review"}:
|
|
nodes[5]["status"] = "待复核"
|
|
else:
|
|
nodes[5]["status"] = "已完成"
|
|
return nodes
|
|
|
|
if output_type == "registration_word_export_report":
|
|
export_status = structured_output.get("export_status", "")
|
|
if export_status in {"blocked", "draft_only"}:
|
|
nodes[6]["status"] = "已阻断" if export_status == "blocked" else "待复核"
|
|
else:
|
|
nodes[6]["status"] = "已完成"
|
|
return nodes
|
|
|
|
if output_type == "feishu_notification_report":
|
|
message_status = structured_output.get("message_status", "")
|
|
if message_status in {"failed", "error"}:
|
|
nodes[7]["status"] = "失败"
|
|
elif message_status in {"sent", "success"}:
|
|
nodes[7]["status"] = "已完成"
|
|
else:
|
|
nodes[7]["status"] = "待处理"
|
|
return nodes
|
|
|
|
return nodes
|