220 lines
7.4 KiB
Python
220 lines
7.4 KiB
Python
from __future__ import annotations

import json
import logging
import math
import unicodedata
from dataclasses import dataclass

from .file_summary.workflow_trigger import (
    evaluate_attachment_reader_trigger,
    evaluate_file_summary_trigger,
)
from .llm import LLMConfigurationError, LLMRequestError, generate_completion
from .models import Conversation, FileAttachment
|
||
|
||
|
||
logger = logging.getLogger(__name__)

# Every action a SkillRoute may carry. `_route_with_llm` rejects any LLM-chosen
# action not in this set. Declared as a single literal instead of mutating the
# set at import time; kept as a mutable `set` for backward compatibility.
ROUTE_ACTIONS = {"normal_chat", "attachment_reader", "file_summary", "regulatory_review"}
|
||
|
||
|
||
@dataclass(frozen=True)
class SkillRoute:
    """Immutable routing decision for one user message.

    ``action`` is expected to be one of ``ROUTE_ACTIONS``; it is not validated
    here. The boolean properties are convenience predicates over ``action``.
    """

    # Selected action name, e.g. "normal_chat" or "file_summary".
    action: str
    # Skill to invoke; the routers in this module only set "attachment_reader" or leave it empty.
    skill_name: str = ""
    # Workflow to start ("file_summary" / "regulatory_review"), or empty.
    workflow_type: str = ""
    # Confidence reported by the LLM, or a fixed value from the rule router.
    confidence: float = 0.0
    # Human-readable justification for the decision.
    reason: str = ""
    # Which router produced the decision, e.g. "llm", "llm_missing_attachment", "rule_fallback".
    source: str = "llm"

    @property
    def uses_attachment_reader(self) -> bool:
        """True when the attachment reader skill should handle the message."""
        return "attachment_reader" == self.action

    @property
    def starts_file_summary(self) -> bool:
        """True when the file summary workflow should be started."""
        return "file_summary" == self.action

    @property
    def starts_regulatory_review(self) -> bool:
        """True when the regulatory review workflow should be started."""
        return "regulatory_review" == self.action

    @property
    def is_normal_chat(self) -> bool:
        """True when no skill or workflow is needed."""
        return "normal_chat" == self.action
|
||
|
||
|
||
def route_message_intent(conversation: Conversation, content: str) -> SkillRoute:
    """Pick the skill or workflow that should handle a user message.

    The LLM router is consulted first. If it raises one of the expected
    failures, the decision comes from the rule-based fallback instead. Expected
    failures are LLM configuration/request errors, or a reply that cannot be
    parsed or names an unsupported action.

    Args:
        conversation: Conversation the message belongs to; its active,
            non-deleted attachments are offered to the LLM router.
        content: Raw text of the user's message.

    Returns:
        The chosen ``SkillRoute``.
    """
    active_attachments = list(_active_attachments(conversation))
    try:
        decision = _route_with_llm(conversation, content, active_attachments)
    except (LLMConfigurationError, LLMRequestError, ValueError, json.JSONDecodeError) as err:
        # LLM-side failures degrade to the deterministic rules instead of erroring out.
        logger.warning(
            "LLM skill route failed, fallback to rules",
            extra={"conversation_id": conversation.pk, "error": str(err)},
        )
        return _route_with_rules(conversation, content)

    log_context = {
        "conversation_id": conversation.pk,
        "action": decision.action,
        "skill_name": decision.skill_name,
        "workflow_type": decision.workflow_type,
        "confidence": decision.confidence,
        "route_source": decision.source,
        "reason": decision.reason,
    }
    logger.info("LLM skill route selected", extra=log_context)
    return decision
|
||
|
||
|
||
def _route_with_llm(
    conversation: Conversation,
    content: str,
    attachments: list[FileAttachment],
) -> SkillRoute:
    """Ask the LLM router which action should handle ``content``.

    Args:
        conversation: Owning conversation (not used by the LLM prompt itself).
        content: The user's message text.
        attachments: Active attachments listed in the router prompt.

    Returns:
        A ``SkillRoute`` with ``source="llm"``. If the LLM picks an
        attachment-based action while no attachment is present, the route keeps
        that action but uses ``source="llm_missing_attachment"``.

    Raises:
        ValueError: the LLM selected an action outside ``ROUTE_ACTIONS``.
        json.JSONDecodeError: no JSON object could be parsed from the reply.
        LLMConfigurationError, LLMRequestError: presumably raised by
            ``generate_completion`` (the caller catches them).
    """
    raw = generate_completion(
        [
            {"role": "system", "content": _router_system_prompt()},
            {
                "role": "user",
                "content": _router_user_prompt(
                    user_message=content,
                    attachments=attachments,
                ),
            },
        ],
        temperature=0.0,
    )
    payload = _parse_json_object(raw)

    # A missing, null or blank action means "normal_chat". Previously an explicit
    # null became the string "None" and was rejected. Lower-casing tolerates
    # replies such as "File_Summary".
    action = str(payload.get("action") or "").strip().lower() or "normal_chat"
    if action not in ROUTE_ACTIONS:
        raise ValueError(f"不支持的路由动作:{action}")

    missing_attachment = action in {"attachment_reader", "file_summary"} and not attachments
    default_reason = "LLM 判断需要附件,但当前无附件。" if missing_attachment else ""

    return SkillRoute(
        action=action,
        skill_name="attachment_reader" if action == "attachment_reader" else "",
        workflow_type=action if action in {"file_summary", "regulatory_review"} else "",
        confidence=_float_or_zero(payload.get("confidence")),
        reason=str(payload.get("reason") or default_reason),
        source="llm_missing_attachment" if missing_attachment else "llm",
    )
|
||
|
||
|
||
def _route_with_rules(conversation: Conversation, content: str) -> SkillRoute:
    """Keyword/trigger based router used when the LLM router fails.

    Checks run in order, and the first match wins:
    1. regulatory review keywords (confidence 0.7);
    2. the file-summary trigger;
    3. the attachment-reader trigger;
    4. otherwise normal chat.
    A trigger also counts as matched when its reason is "missing_attachment",
    so the attachment-based action is still returned in that case.
    """
    fallback_source = "rule_fallback"

    def fires(trigger) -> bool:
        # A trigger selects its action when it would start, or when it only lacks an attachment.
        return trigger.should_start or trigger.reason == "missing_attachment"

    if _matches_regulatory_review(content):
        return SkillRoute(
            action="regulatory_review",
            workflow_type="regulatory_review",
            confidence=0.7,
            reason="命中法规核查关键词。",
            source=fallback_source,
        )

    summary_trigger = evaluate_file_summary_trigger(conversation, content)
    if fires(summary_trigger):
        return SkillRoute(
            action="file_summary",
            workflow_type="file_summary",
            confidence=0.5,
            reason=summary_trigger.reason,
            source=fallback_source,
        )

    # Evaluated only after the file-summary trigger declined, matching the original order.
    reader_trigger = evaluate_attachment_reader_trigger(conversation, content)
    if fires(reader_trigger):
        return SkillRoute(
            action="attachment_reader",
            skill_name="attachment_reader",
            confidence=0.5,
            reason=reader_trigger.reason,
            source=fallback_source,
        )

    return SkillRoute(
        action="normal_chat",
        confidence=0.5,
        reason="未匹配到需要调用 Skill 或工作流的意图。",
        source=fallback_source,
    )
|
||
|
||
|
||
def _active_attachments(conversation: Conversation):
    """Return the conversation's active attachments that are not deleted.

    Results are ordered by ``original_name``, then by ``version_no``
    descending. The result is a lazy query object, presumably a Django
    QuerySet; callers materialize it.
    """
    query = FileAttachment.objects.filter(conversation=conversation, is_active=True)
    query = query.exclude(upload_status=FileAttachment.UploadStatus.DELETED)
    return query.order_by("original_name", "-version_no")
|
||
|
||
|
||
def _router_system_prompt() -> str:
|
||
return (
|
||
"你是审核智能体的工具路由器,只判断是否需要调用工具,不直接回答用户。"
|
||
"你必须只输出 JSON 对象,不要输出 Markdown。"
|
||
"可选 action:normal_chat、attachment_reader、file_summary、regulatory_review。"
|
||
"attachment_reader 用于用户要求阅读、提取、分析、总结、查看上传附件内容。"
|
||
"file_summary 用于用户要求自动汇总文件目录、页数、清单或生成目录页数报告。"
|
||
"regulatory_review 用于用户要求法规核查、NMPA核查、完整性核查、章节一致性核查、风险预警或整改建议。"
|
||
"normal_chat 用于不需要读取附件或执行工作流的一般问答。"
|
||
"输出字段:action、confidence、reason。"
|
||
)
|
||
|
||
|
||
def _router_user_prompt(*, user_message: str, attachments: list[FileAttachment]) -> str:
|
||
attachment_lines = [
|
||
f"- id={attachment.pk}, name={attachment.original_name}, active={attachment.is_active}, status={attachment.upload_status}"
|
||
for attachment in attachments
|
||
]
|
||
attachment_text = "\n".join(attachment_lines) if attachment_lines else "无 active 附件"
|
||
return (
|
||
f"用户消息:{user_message}\n\n"
|
||
f"当前 active 附件:\n{attachment_text}\n\n"
|
||
"请判断应调用哪个 action。只输出 JSON。"
|
||
)
|
||
|
||
|
||
def _parse_json_object(raw: str) -> dict:
|
||
text = (raw or "").strip()
|
||
if text.startswith("```"):
|
||
text = text.strip("`").strip()
|
||
if text.lower().startswith("json"):
|
||
text = text[4:].strip()
|
||
start = text.find("{")
|
||
end = text.rfind("}")
|
||
if start == -1 or end == -1 or end < start:
|
||
raise json.JSONDecodeError("未找到 JSON 对象", text, 0)
|
||
return json.loads(text[start : end + 1])
|
||
|
||
|
||
def _float_or_zero(value) -> float:
|
||
try:
|
||
return float(value)
|
||
except (TypeError, ValueError):
|
||
return 0.0
|
||
|
||
|
||
def _matches_regulatory_review(content: str) -> bool:
|
||
normalized = content.lower()
|
||
keywords = [
|
||
"法规核查",
|
||
"nmpa核查",
|
||
"nmpa 核查",
|
||
"完整性核查",
|
||
"风险预警",
|
||
"整改建议",
|
||
"章节核查",
|
||
"一致性核查",
|
||
]
|
||
return any(keyword in normalized for keyword in keywords)
|