# Source: DEMO-AGENT/review_agent/skill_router.py (241 lines, 8.4 KiB, Python)
#
# Note: this file contains Unicode characters that might be confused with
# other characters (for example, the full-width punctuation used in the
# Chinese string literals). If this is intentional, the warning can be ignored.
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from .file_summary.workflow_trigger import (
evaluate_attachment_reader_trigger,
evaluate_file_summary_trigger,
)
from .application_form_fill.constants import FORM_FILL_TRIGGER_KEYWORDS, WORKFLOW_TYPE as FORM_FILL_WORKFLOW_TYPE
from .llm import LLMConfigurationError, LLMRequestError, generate_completion
from .models import Conversation, FileAttachment
logger = logging.getLogger(__name__)

# Every action name the router accepts; `_route_with_llm` rejects any LLM
# answer whose action is not in this set.
ROUTE_ACTIONS = {
    "normal_chat",
    "attachment_reader",
    "file_summary",
    "regulatory_review",
    FORM_FILL_WORKFLOW_TYPE,
}
@dataclass(frozen=True)
class SkillRoute:
    """Immutable outcome of routing one user message.

    The boolean properties are convenience checks on ``action`` so callers do
    not need to compare against the raw action strings themselves.
    """

    # Chosen action name (one of the module's ROUTE_ACTIONS values).
    action: str
    # Set to "attachment_reader" by the routers when that skill is chosen.
    skill_name: str = ""
    # Set to the workflow's name by the routers when a workflow is chosen.
    workflow_type: str = ""
    # Confidence reported for this decision; 0.0 when none was supplied.
    confidence: float = 0.0
    # Free-text explanation of the decision.
    reason: str = ""
    # Which router produced the decision (defaults to the LLM router).
    source: str = "llm"

    @property
    def is_normal_chat(self) -> bool:
        """True when no skill or workflow is needed."""
        return self.action == "normal_chat"

    @property
    def uses_attachment_reader(self) -> bool:
        """True when the attachment reader skill was selected."""
        return self.action == "attachment_reader"

    @property
    def starts_file_summary(self) -> bool:
        """True when the file summary workflow was selected."""
        return self.action == "file_summary"

    @property
    def starts_regulatory_review(self) -> bool:
        """True when the regulatory review workflow was selected."""
        return self.action == "regulatory_review"

    @property
    def starts_application_form_fill(self) -> bool:
        """True when the application form fill workflow was selected."""
        return self.action == FORM_FILL_WORKFLOW_TYPE
def route_message_intent(conversation: Conversation, content: str) -> SkillRoute:
    """Decide which skill or workflow should handle a user message.

    The LLM router is tried first. If the LLM is misconfigured, the request
    fails, or its answer cannot be parsed or validated, the keyword-based
    rules are used instead.

    Args:
        conversation: Conversation the message belongs to; its active
            attachments are passed to the LLM router.
        content: Text of the user's message.

    Returns:
        The LLM's route on success, otherwise the rule-based route.
    """
    active_files = list(_active_attachments(conversation))
    try:
        decision = _route_with_llm(conversation, content, active_files)
        log_fields = {
            "conversation_id": conversation.pk,
            "action": decision.action,
            "skill_name": decision.skill_name,
            "workflow_type": decision.workflow_type,
            "confidence": decision.confidence,
            "route_source": decision.source,
            "reason": decision.reason,
        }
        logger.info("LLM skill route selected", extra=log_fields)
        return decision
    except (LLMConfigurationError, LLMRequestError, ValueError, json.JSONDecodeError) as exc:
        failure_fields = {"conversation_id": conversation.pk, "error": str(exc)}
        logger.warning("LLM skill route failed, fallback to rules", extra=failure_fields)
        return _route_with_rules(conversation, content)
def _route_with_llm(
    conversation: Conversation,
    content: str,
    attachments: list[FileAttachment],
) -> SkillRoute:
    """Ask the LLM to choose a route and turn its JSON answer into a SkillRoute.

    Args:
        conversation: Not used by this function.
        content: Text of the user's message.
        attachments: Active attachments listed in the prompt.

    Returns:
        A route with source "llm", or "llm_missing_attachment" when the model
        chose an attachment-based action but there are no attachments.

    Raises:
        ValueError: If the model returns an action outside ROUTE_ACTIONS.
        json.JSONDecodeError: If no JSON object can be parsed from the reply.
    """
    messages = [
        {"role": "system", "content": _router_system_prompt()},
        {
            "role": "user",
            "content": _router_user_prompt(user_message=content, attachments=attachments),
        },
    ]
    payload = _parse_json_object(generate_completion(messages, temperature=0.0))

    action = str(payload.get("action", "normal_chat")).strip()
    if action not in ROUTE_ACTIONS:
        raise ValueError(f"不支持的路由动作:{action}")

    # An attachment-based action with nothing uploaded is still returned, but
    # under a distinct source and with a default explanation.
    needs_attachment = action in {"attachment_reader", "file_summary"}
    if needs_attachment and not attachments:
        fallback_reason = "LLM 判断需要附件,但当前无附件。"
        origin = "llm_missing_attachment"
    else:
        fallback_reason = ""
        origin = "llm"

    workflow_actions = {"file_summary", "regulatory_review", FORM_FILL_WORKFLOW_TYPE}
    return SkillRoute(
        action=action,
        skill_name="attachment_reader" if action == "attachment_reader" else "",
        workflow_type=action if action in workflow_actions else "",
        confidence=_float_or_zero(payload.get("confidence")),
        reason=str(payload.get("reason") or fallback_reason),
        source=origin,
    )
def _route_with_rules(conversation: Conversation, content: str) -> SkillRoute:
    """Route a message with keyword and trigger rules, without the LLM.

    Checks run in this order and the first match wins: application form fill,
    regulatory review, file summary, attachment reader. If none match, the
    message is routed to normal chat. For the two trigger-based checks, a
    trigger whose reason is "missing_attachment" also counts as a match.

    Args:
        conversation: Conversation passed to the trigger evaluators.
        content: Text of the user's message.

    Returns:
        A route whose source is always "rule_fallback".
    """

    def fallback_route(action: str, **fields) -> SkillRoute:
        # Every route built here carries the rule-fallback source tag.
        return SkillRoute(action=action, source="rule_fallback", **fields)

    if _matches_application_form_fill(content):
        return fallback_route(
            FORM_FILL_WORKFLOW_TYPE,
            workflow_type=FORM_FILL_WORKFLOW_TYPE,
            confidence=0.7,
            reason="命中申报文件自动填表关键词。",
        )

    if _matches_regulatory_review(content):
        return fallback_route(
            "regulatory_review",
            workflow_type="regulatory_review",
            confidence=0.7,
            reason="命中法规核查关键词。",
        )

    summary_trigger = evaluate_file_summary_trigger(conversation, content)
    if summary_trigger.should_start or summary_trigger.reason == "missing_attachment":
        return fallback_route(
            "file_summary",
            workflow_type="file_summary",
            confidence=0.5,
            reason=summary_trigger.reason,
        )

    # Evaluated only after the file summary trigger has declined.
    reader_trigger = evaluate_attachment_reader_trigger(conversation, content)
    if reader_trigger.should_start or reader_trigger.reason == "missing_attachment":
        return fallback_route(
            "attachment_reader",
            skill_name="attachment_reader",
            confidence=0.5,
            reason=reader_trigger.reason,
        )

    return fallback_route(
        "normal_chat",
        confidence=0.5,
        reason="未匹配到需要调用 Skill 或工作流的意图。",
    )
def _active_attachments(conversation: Conversation):
    """Query the conversation's attachments that are active and not deleted.

    Results are ordered by "original_name", then "-version_no" (presumably
    newest version first within each name; confirm against the model).
    """
    active = FileAttachment.objects.filter(conversation=conversation, is_active=True)
    not_deleted = active.exclude(upload_status=FileAttachment.UploadStatus.DELETED)
    return not_deleted.order_by("original_name", "-version_no")
def _router_system_prompt() -> str:
    """Return the system prompt that makes the LLM act as a tool router.

    The prompt tells the model to output only a JSON object, lists the
    available actions with a one-line description of each, and names the
    expected output fields (action, confidence, reason).
    """
    # NOTE(review): the form-fill action is spelled out here as
    # "application_form_fill"; confirm it equals FORM_FILL_WORKFLOW_TYPE.
    return (
        "你是审核智能体的工具路由器,只判断是否需要调用工具,不直接回答用户。"
        "你必须只输出 JSON 对象,不要输出 Markdown。"
        # A separator after the label keeps "action" from fusing with the
        # first action name ("actionnormal_chat").
        "可选 action:normal_chat、attachment_reader、file_summary、regulatory_review、application_form_fill。"
        "attachment_reader 用于用户要求阅读、提取、分析、总结、查看上传附件内容。"
        "file_summary 用于用户要求自动汇总文件目录、页数、清单或生成目录页数报告。"
        "regulatory_review 用于用户要求法规核查、NMPA核查、完整性核查、章节一致性核查、风险预警或整改建议。"
        "application_form_fill 用于用户要求填注册证、生成申报模板、填写对应表格、安全和性能基本原则清单或自动填表。"
        "normal_chat 用于不需要读取附件或执行工作流的一般问答。"
        "输出字段:action、confidence、reason。"
    )
def _router_user_prompt(*, user_message: str, attachments: list[FileAttachment]) -> str:
    """Build the user-turn prompt: the message plus a list of active attachments.

    Args:
        user_message: Text of the user's message, inserted verbatim.
        attachments: Attachments to list, one per line with id, name, active
            flag and upload status.

    Returns:
        The prompt text; a placeholder line is used when there are no
        attachments.
    """
    described = []
    for item in attachments:
        described.append(
            f"- id={item.pk}, name={item.original_name}, active={item.is_active}, status={item.upload_status}"
        )
    # Each line is non-empty, so the joined text is empty only when there are
    # no attachments.
    listing = "\n".join(described) or "无 active 附件"
    header = f"用户消息:{user_message}\n\n"
    files_section = f"当前 active 附件:\n{listing}\n\n"
    return header + files_section + "请判断应调用哪个 action。只输出 JSON。"
def _parse_json_object(raw: str) -> dict:
    """Extract the first JSON object from an LLM reply.

    Text before the first "{" (Markdown fences, a "json" tag, a preamble) and
    any text after the object's closing brace are ignored. Decoding starts at
    the first "{" and stops where that object ends, so trailing commentary
    containing braces no longer breaks parsing.

    Args:
        raw: The model's reply; None or empty is treated as empty text.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If the reply has no "{" or the text starting
            there is not a valid JSON object.
    """
    text = (raw or "").strip()
    start = text.find("{")
    if start == -1:
        raise json.JSONDecodeError("未找到 JSON 对象", text, 0)
    # raw_decode parses one JSON value from `start` and reports where it
    # ended; a value starting with "{" is always an object.
    payload, _end = json.JSONDecoder().raw_decode(text, start)
    return payload
def _float_or_zero(value) -> float:
    """Convert ``value`` to float, returning 0.0 when it cannot be converted.

    Args:
        value: Any object, typically the "confidence" field of the LLM reply.

    Returns:
        ``float(value)``, or 0.0 if the value is missing, not numeric, or an
        integer too large for a float.
    """
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: JSON allows arbitrarily large integers, and float()
        # raises on those instead of returning inf.
        return 0.0
def _matches_regulatory_review(content: str) -> bool:
    """Report whether the message contains a regulatory review keyword.

    The message is lower-cased before matching, so the ASCII part of keywords
    such as "nmpa核查" matches regardless of case.
    """
    review_keywords = (
        "法规核查",
        "nmpa核查",
        "nmpa 核查",
        "完整性核查",
        "风险预警",
        "整改建议",
        "章节核查",
        "一致性核查",
    )
    lowered = content.lower()
    for phrase in review_keywords:
        if phrase in lowered:
            return True
    return False
def _matches_application_form_fill(content: str) -> bool:
    """Report whether the message contains a form-fill trigger keyword.

    Both the message and each keyword are lower-cased before comparison.
    """
    lowered = content.lower()
    for trigger in FORM_FILL_TRIGGER_KEYWORDS:
        if trigger.lower() in lowered:
            return True
    return False