413 lines
15 KiB
Python
413 lines
15 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
|
||
from django.db.models import Q, QuerySet
|
||
from django.conf import settings
|
||
from django.utils import timezone
|
||
|
||
from .file_summary.skills.attachment_reader import AttachmentReaderSkill
|
||
from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow
|
||
from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply
|
||
from .models import Conversation, FileAttachment, Message
|
||
from .regulatory_review.workflow import (
|
||
create_regulatory_review_batch,
|
||
find_latest_successful_summary_batch,
|
||
start_regulatory_review_workflow,
|
||
)
|
||
from .skill_router import route_message_intent
|
||
|
||
|
||
# Module-level logger. Log calls in this module attach structured context
# (conversation/message ids, lengths, route info) through the ``extra`` mapping.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def list_conversations(user, search: str = "") -> QuerySet[Conversation]:
    """Return the conversations owned by ``user``.

    When ``search`` is non-empty, only conversations whose title, or the
    content of any of their messages, contains it (case-insensitive) are kept.
    ``distinct()`` collapses duplicate rows introduced by the join through
    ``messages``.
    """
    owned = Conversation.objects.filter(user=user)
    if search:
        title_or_body_match = Q(title__icontains=search) | Q(messages__content__icontains=search)
        return owned.filter(title_or_body_match).distinct()
    return owned
|
||
|
||
|
||
def get_conversation_for_user(user, conversation_id: int | None) -> Conversation | None:
    """Fetch a conversation by primary key, scoped to ``user``.

    Returns ``None`` when ``conversation_id`` is falsy (``None`` or ``0``).
    Also returns ``None`` when no conversation with that id belongs to ``user``.
    """
    if conversation_id:
        owned_match = Conversation.objects.filter(pk=conversation_id, user=user)
        return owned_match.first()
    return None
|
||
|
||
|
||
def create_conversation(user) -> Conversation:
    """Create an empty conversation for ``user``.

    The placeholder title is the default-title prefix followed by the local
    time as ``MM-DD HH:MM``. ``append_user_message`` replaces it once the
    conversation's first user message is stored.
    """
    created_at_label = timezone.localtime().strftime("%m-%d %H:%M")
    return Conversation.objects.create(user=user, title=f"新对话 {created_at_label}")
|
||
|
||
|
||
def append_user_message(conversation: Conversation, content: str) -> Message:
    """Persist a user message (whitespace-stripped) on ``conversation``.

    If this turns out to be the conversation's only user message, the
    conversation title is regenerated from ``content`` and saved.

    Returns:
        The created ``Message``.
    """
    user_message = Message.objects.create(
        conversation=conversation,
        role=Message.Role.USER,
        content=content.strip(),
    )
    log_context = {
        "conversation_id": conversation.pk,
        "message_id": user_message.pk,
        "content_length": len(user_message.content),
    }
    logger.info("User message appended", extra=log_context)

    user_message_count = conversation.messages.filter(role=Message.Role.USER).count()
    if user_message_count != 1:
        return user_message

    # First user turn: swap the placeholder title for one derived from the prompt.
    conversation.title = build_conversation_title(content)
    conversation.save(update_fields=["title", "updated_at"])
    return user_message
|
||
|
||
|
||
def append_assistant_message(conversation: Conversation, content: str) -> Message:
    """Persist an assistant message on ``conversation`` and log it.

    ``content`` is stored exactly as given (no stripping). Callers in this
    module pass LLM output, fallback/error text, or fixed workflow replies.

    Returns:
        The created ``Message``.
    """
    created = Message.objects.create(
        role=Message.Role.ASSISTANT,
        conversation=conversation,
        content=content,
    )
    # Guard against a falsy/None content when reporting its length.
    reply_length = len(content) if content else 0
    logger.info(
        "Assistant message appended",
        extra={
            "conversation_id": conversation.pk,
            "message_id": created.pk,
            "content_length": reply_length,
        },
    )
    return created
|
||
|
||
|
||
def send_message(conversation: Conversation, content: str) -> tuple[Message, Message]:
    """Store one user message and one provider-backed assistant reply.

    This is the non-streaming counterpart of ``stream_message``'s plain LLM
    path. The user message is persisted first, then ``generate_reply`` is
    called. If the LLM layer raises ``LLMConfigurationError`` or
    ``LLMRequestError``, the failure is logged and its text is stored as the
    assistant reply instead of propagating.

    Returns:
        ``(user_message, assistant_message)``.
    """
    user_message = append_user_message(conversation, content)
    try:
        reply_content = generate_reply(conversation, content)
    except (LLMConfigurationError, LLMRequestError) as exc:
        # Log like stream_message does. Otherwise the provider failure is only
        # visible inside the stored reply text.
        logger.warning(
            "LLM reply failed",
            extra={"conversation_id": conversation.pk, "error": str(exc)},
        )
        reply_content = f"模型调用失败:{exc}"

    assistant_message = append_assistant_message(conversation, reply_content)

    # Defensive re-title: covers a conversation still carrying the default title prefix.
    if conversation.title.startswith("新对话"):
        conversation.title = build_conversation_title(content)
        conversation.save(update_fields=["title", "updated_at"])

    return user_message, assistant_message
|
||
|
||
|
||
def stream_message(conversation: Conversation, content: str):
    """Handle one user turn as a stream of Server-Sent Events frames.

    The user message is stored first and ``route_message_intent`` picks a path.
    A ``meta`` frame is always yielded, then exactly one of these paths runs:

    * Guard replies with fixed text:
      - a file summary is requested but there are no active attachments;
      - attachment reading is requested but there are no active attachments;
      - a regulatory review is requested but there is no successful summary batch.
    * The attachment-reader path formats ``AttachmentReaderSkill`` output.
    * The file-summary and regulatory-review paths create and start a workflow
      batch and yield a ``workflow_started`` frame.
    * Otherwise the LLM reply is streamed as ``chunk`` frames. On failure it
      falls back to ``generate_reply`` (``replace`` frame), or yields an
      ``error`` frame if that also fails.

    Every path that completes normally stores one assistant message and ends
    with a ``done`` frame.

    Args:
        conversation: Conversation receiving the turn.
        content: Raw user input.

    Yields:
        SSE frames as strings, formatted by ``sse_event``.
    """
    user_message = append_user_message(conversation, content)
    route = route_message_intent(conversation, content)
    logger.info(
        "Stream message started",
        extra={
            "conversation_id": conversation.pk,
            "user_message_id": user_message.pk,
            "route_action": route.action,
            "route_source": route.source,
            "route_confidence": route.confidence,
            "route_reason": route.reason,
        },
    )

    yield sse_event(
        "meta",
        {
            "conversation_id": conversation.pk,
            "title": conversation.title or build_conversation_title(content),
            "user_message_id": user_message.pk,
            "user_message": user_message.content,
        },
    )

    if route.starts_file_summary and not _has_active_attachments(conversation):
        yield from _stream_fixed_reply(
            conversation,
            "请先在当前对话右侧上传需要汇总的文件或压缩包,然后再发送自动汇总指令。",
        )
        return

    if route.uses_attachment_reader and not _has_active_attachments(conversation):
        yield from _stream_fixed_reply(
            conversation,
            "请先在当前对话右侧上传需要阅读的附件,然后再发送解析或阅读附件指令。",
        )
        return

    if route.uses_attachment_reader:
        yield from _stream_attachment_reader_reply(conversation, content)
        return

    if route.starts_file_summary:
        yield from _stream_file_summary_start(conversation, user_message)
        return

    if route.starts_regulatory_review:
        yield from _stream_regulatory_review_start(conversation, user_message)
        return

    yield from _stream_llm_reply(conversation, content)


def _stream_done_event(conversation: Conversation, assistant_message: Message) -> str:
    """Build the terminal ``done`` frame for a stored assistant message."""
    return sse_event(
        "done",
        {
            "assistant_message_id": assistant_message.pk,
            "conversation_id": conversation.pk,
            "title": conversation.title,
        },
    )


def _stream_fixed_reply(
    conversation: Conversation,
    reply_content: str,
    workflow_payload: dict[str, object] | None = None,
):
    """Store a non-streamed reply, then yield optional workflow_started, chunk and done frames."""
    assistant_message = append_assistant_message(conversation, reply_content)
    if workflow_payload is not None:
        yield sse_event("workflow_started", workflow_payload)
    yield sse_event("chunk", {"delta": reply_content})
    yield _stream_done_event(conversation, assistant_message)


def _stream_attachment_reader_reply(conversation: Conversation, content: str):
    """Run AttachmentReaderSkill on the selected attachments and stream its formatted result."""
    attachments = _select_attachments_for_reader(conversation, content)
    logger.info(
        "Attachment reader path selected",
        extra={
            "conversation_id": conversation.pk,
            "attachment_count": len(attachments),
            "attachment_ids": [attachment.pk for attachment in attachments],
        },
    )
    result = AttachmentReaderSkill().run_for_attachments(attachments)
    reply_content = _format_attachment_reader_reply(result.data.get("attachments", []), result.message)
    yield from _stream_fixed_reply(conversation, reply_content)


def _stream_file_summary_start(conversation: Conversation, user_message: Message):
    """Create and start a file-summary workflow batch, then stream the confirmation."""
    batch = create_file_summary_batch(
        conversation=conversation,
        user=conversation.user,
        trigger_message=user_message,
    )
    start_file_summary_workflow(
        batch,
        async_run=getattr(settings, "FILE_SUMMARY_ASYNC", True),
    )
    yield from _stream_fixed_reply(
        conversation,
        f"已启动文件目录与页数自动汇总工作流,批次号:{batch.batch_no}。",
        workflow_payload={
            "workflow_type": "file_summary",
            "batch_id": batch.pk,
            "batch_no": batch.batch_no,
        },
    )


def _stream_regulatory_review_start(conversation: Conversation, user_message: Message):
    """Start a regulatory-review batch from the latest successful summary, or explain why not."""
    source_summary_batch = find_latest_successful_summary_batch(conversation)
    if not source_summary_batch:
        yield from _stream_fixed_reply(
            conversation,
            "请先执行自动汇总,生成成功的文件汇总批次后再启动法规核查。",
        )
        return

    batch = create_regulatory_review_batch(
        conversation=conversation,
        user=conversation.user,
        trigger_message=user_message,
        source_summary_batch=source_summary_batch,
    )
    start_regulatory_review_workflow(
        batch,
        async_run=getattr(settings, "REGULATORY_REVIEW_ASYNC", True),
    )
    yield from _stream_fixed_reply(
        conversation,
        f"已启动 NMPA 注册资料法规核查工作流,批次号:{batch.batch_no}。",
        workflow_payload={
            "workflow_type": "regulatory_review",
            "batch_id": batch.pk,
            "batch_no": batch.batch_no,
        },
    )


def _stream_llm_reply(conversation: Conversation, content: str):
    """Stream an LLM reply, fall back to a non-streamed reply on failure, then store it."""
    assistant_parts: list[str] = []
    stream_failed = False
    stream_error = ""
    try:
        for chunk in stream_reply(conversation, content):
            assistant_parts.append(chunk)
            yield sse_event("chunk", {"delta": chunk})
    except (LLMConfigurationError, LLMRequestError) as exc:
        stream_failed = True
        stream_error = str(exc)
        logger.warning(
            "LLM stream failed",
            extra={"conversation_id": conversation.pk, "error": str(exc)},
        )
    except Exception as exc:
        # Boundary catch: any stream failure should still end in a stored reply
        # and a done frame rather than a dropped connection.
        stream_failed = True
        stream_error = str(exc)
        logger.exception(
            "Unexpected stream failure",
            extra={"conversation_id": conversation.pk, "error": str(exc)},
        )

    if stream_failed:
        try:
            fallback_reply = generate_reply(conversation, content)
            assistant_parts = [fallback_reply]
            logger.info(
                "Non-stream fallback reply succeeded",
                extra={"conversation_id": conversation.pk, "content_length": len(fallback_reply)},
            )
            yield sse_event("replace", {"content": fallback_reply})
        except (LLMConfigurationError, LLMRequestError) as exc:
            fallback = f"模型调用失败:{exc}"
            assistant_parts = [fallback]
            logger.warning(
                "Non-stream fallback reply failed",
                extra={"conversation_id": conversation.pk, "error": str(exc), "stream_error": stream_error},
            )
            yield sse_event("error", {"message": fallback})
        except Exception as exc:
            fallback = f"回复生成中断:{stream_error or exc}"
            # Unlike the branches above, append to the parts collected so far
            # (normally the partially streamed text) instead of replacing them.
            assistant_parts.append("\n\n" + fallback)
            logger.exception(
                "Non-stream fallback crashed",
                extra={"conversation_id": conversation.pk, "error": str(exc), "stream_error": stream_error},
            )
            yield sse_event("error", {"message": fallback})

    assistant_message = append_assistant_message(conversation, "".join(assistant_parts).strip())

    # Defensive re-title for a conversation still carrying the default title prefix.
    if conversation.title.startswith("新对话"):
        conversation.title = build_conversation_title(content)
        conversation.save(update_fields=["title", "updated_at"])

    yield _stream_done_event(conversation, assistant_message)
|
||
|
||
|
||
def build_conversation_title(content: str) -> str:
    """Derive a short conversation title from a user prompt.

    Every whitespace run (including newlines and tabs) collapses to a single
    space, leading and trailing whitespace is dropped, and the result is cut to
    its first 24 characters. Blank input yields the default title.
    """
    # str.split() with no separator already discards leading/trailing whitespace.
    collapsed = " ".join(content.split())
    return collapsed[:24] if collapsed else "新对话"
|
||
|
||
|
||
def _select_attachments_for_reader(conversation: Conversation, content: str):
    """Choose which of the conversation's attachments the reader skill should parse.

    Candidates are the active, non-deleted attachments, ordered by
    ``original_name`` and then by descending ``version_no``. If any candidate's
    ``original_name`` occurs as a substring of ``content``, only those
    candidates are returned. Otherwise all candidates are returned.

    Returns:
        A list of ``FileAttachment`` instances (possibly empty).
    """
    candidate_query = (
        FileAttachment.objects.filter(conversation=conversation, is_active=True)
        .exclude(upload_status=FileAttachment.UploadStatus.DELETED)
        .order_by("original_name", "-version_no")
    )
    candidates = list(candidate_query)
    # NOTE(review): plain substring match, so an empty original_name matches any prompt.
    mentioned = [candidate for candidate in candidates if candidate.original_name in content]
    if mentioned:
        return mentioned
    return candidates
|
||
|
||
|
||
def _has_active_attachments(conversation: Conversation) -> bool:
    """Return True if ``conversation`` has at least one active, non-deleted attachment."""
    active = FileAttachment.objects.filter(conversation=conversation, is_active=True)
    not_deleted = active.exclude(upload_status=FileAttachment.UploadStatus.DELETED)
    return not_deleted.exists()
|
||
|
||
|
||
def _format_attachment_reader_reply(
    attachments: list[dict[str, object]],
    message: str,
    *,
    max_sections: int = 8,
) -> str:
    """Render attachment-reader results as a Markdown reply.

    Args:
        attachments: Per-attachment result dicts. Keys read here are
            ``filename``, ``file_type``, ``status``, ``error_message``,
            ``preview_text`` and ``sections``. ``sections`` is a list; its
            dict entries may carry ``type``, ``name``, ``row_count`` and
            ``column_count``, and non-dict entries are skipped.
        message: Fallback text returned verbatim when ``attachments`` is empty.
        max_sections: At most this many section entries are listed per
            attachment (default 8).

    Returns:
        Markdown text with a heading and a type/status list per attachment.
        An attachment with ``error_message`` lists the error and nothing else.
        Otherwise a fenced preview and a section list are added when present.
    """
    if not attachments:
        return message or "当前对话没有可读取的附件。"

    lines = ["## 附件解析结果"]
    for item in attachments:
        status = item.get("status", "")
        filename = item.get("filename", "")
        file_type = item.get("file_type", "")
        lines.extend(
            [
                "",
                f"### {filename}",
                f"- 类型:{file_type or '未知'}",
                f"- 状态:{status}",
            ]
        )

        if item.get("error_message"):
            lines.append(f"- 错误:{item['error_message']}")
            continue

        preview = str(item.get("preview_text") or "").strip()
        if preview:
            # The preview is raw attachment text. Per CommonMark, a fence longer
            # than any backtick run inside it keeps an embedded ``` from closing
            # the code block early.
            fence = _markdown_fence_for(preview)
            lines.extend(["", "摘要预览:", f"{fence}text", preview, fence])

        sections = item.get("sections") or []
        if sections:
            lines.append("")
            lines.append("结构详情:")
            for section in sections[:max_sections]:
                if not isinstance(section, dict):
                    continue
                section_type = section.get("type", "section")
                name = section.get("name", "")
                extra = ""
                if "row_count" in section:
                    extra = f",{section['row_count']} 行"
                if "column_count" in section:
                    extra += f",{section['column_count']} 列"
                lines.append(f"- {name}({section_type}{extra})")
    return "\n".join(lines).strip()


def _markdown_fence_for(text: str) -> str:
    """Return a backtick fence (at least three long) longer than any backtick run in ``text``."""
    longest_run = current_run = 0
    for char in text:
        current_run = current_run + 1 if char == "`" else 0
        longest_run = max(longest_run, current_run)
    return "`" * max(3, longest_run + 1)
|
||
|
||
|
||
def sse_event(event_name: str, payload: dict[str, object]) -> str:
    """Serialize one Server-Sent Events frame.

    The frame is an ``event:`` line, then a single ``data:`` line holding
    ``payload`` as compact JSON (non-ASCII characters kept literally), then the
    blank line that terminates the frame. ``json.dumps`` escapes newlines inside
    strings, so the data never spans multiple lines.
    """
    data = json.dumps(payload, ensure_ascii=False)
    return "event: %s\ndata: %s\n\n" % (event_name, data)
|