280 lines
9.9 KiB
Python
280 lines
9.9 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
|
||
from django.db.models import Q, QuerySet
|
||
from django.conf import settings
|
||
from django.utils import timezone
|
||
|
||
from .file_summary.skills.attachment_reader import AttachmentReaderSkill
|
||
from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow
|
||
from .file_summary.workflow_trigger import (
|
||
evaluate_attachment_reader_trigger,
|
||
evaluate_file_summary_trigger,
|
||
)
|
||
from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply
|
||
from .models import Conversation, FileAttachment, Message
|
||
|
||
|
||
def list_conversations(user, search: str = "") -> QuerySet[Conversation]:
    """Return the conversations owned by ``user``, narrowed by ``search`` if given.

    A non-empty ``search`` keeps conversations whose title, or the content of
    any of their messages, contains it case-insensitively. ``distinct()`` is
    applied because matching on related messages can return the same
    conversation more than once.
    """
    owned = Conversation.objects.filter(user=user)
    if search:
        title_hit = Q(title__icontains=search)
        content_hit = Q(messages__content__icontains=search)
        return owned.filter(title_hit | content_hit).distinct()
    return owned
|
||
|
||
|
||
def get_conversation_for_user(user, conversation_id: int | None) -> Conversation | None:
    """Fetch one of ``user``'s conversations by primary key.

    Returns ``None`` when ``conversation_id`` is falsy or when no conversation
    with that key belongs to ``user``.
    """
    if conversation_id:
        return Conversation.objects.filter(user=user, pk=conversation_id).first()
    return None
|
||
|
||
|
||
def create_conversation(user) -> Conversation:
    """Create and return a new conversation for ``user`` with a placeholder title.

    The title is "新对话" followed by the current local time formatted as
    month-day hour:minute.
    """
    stamp = timezone.localtime().strftime("%m-%d %H:%M")
    return Conversation.objects.create(user=user, title=f"新对话 {stamp}")
|
||
|
||
|
||
def append_user_message(conversation: Conversation, content: str) -> Message:
    """Store ``content`` (stripped) as a user message on ``conversation``.

    When the conversation then holds exactly one user message, its title is
    rebuilt from ``content`` and saved.
    """
    stored = Message.objects.create(
        conversation=conversation,
        role=Message.Role.USER,
        content=content.strip(),
    )

    user_messages = conversation.messages.filter(role=Message.Role.USER)
    is_first_prompt = user_messages.count() == 1
    if is_first_prompt:
        conversation.title = build_conversation_title(content)
        conversation.save(update_fields=["title", "updated_at"])

    return stored
|
||
|
||
|
||
def append_assistant_message(conversation: Conversation, content: str) -> Message:
    """Store ``content`` unchanged as an assistant message on ``conversation``."""
    fields = {
        "conversation": conversation,
        "role": Message.Role.ASSISTANT,
        "content": content,
    }
    return Message.objects.create(**fields)
|
||
|
||
|
||
def send_message(conversation: Conversation, content: str) -> tuple[Message, Message]:
    """Record a user message and the assistant reply generated for it.

    If ``generate_reply`` raises ``LLMConfigurationError`` or
    ``LLMRequestError``, the error text is stored as the assistant reply
    instead of propagating. A title that still starts with "新对话" is replaced
    by one built from ``content``.

    Returns the ``(user_message, assistant_message)`` pair.
    """
    user_message = append_user_message(conversation, content)

    try:
        reply_text = generate_reply(conversation, content)
    except (LLMConfigurationError, LLMRequestError) as error:
        reply_text = f"模型调用失败:{error}"
    assistant_message = append_assistant_message(conversation, reply_text)

    still_placeholder = conversation.title.startswith("新对话")
    if still_placeholder:
        conversation.title = build_conversation_title(content)
        conversation.save(update_fields=["title", "updated_at"])

    return user_message, assistant_message
|
||
|
||
|
||
def stream_message(conversation: Conversation, content: str):
    """Yields SSE events while collecting a streamed assistant reply.

    The user message is stored first and announced with a "meta" event. The
    prompt is then routed, in this order:

    1. The file-summary trigger reports a missing attachment: fixed hint reply.
    2. The attachment-reader trigger reports a missing attachment: fixed hint
       reply.
    3. The attachment-reader trigger fires: the selected attachments are read
       and the formatted result is the reply.
    4. The file-summary trigger fires: a batch is created, the workflow is
       started, and a "workflow_started" event precedes the confirmation reply.
    5. Otherwise the reply is streamed from the model as "chunk" events.

    Every route stores exactly one assistant message and finishes with a
    "done" event carrying that message's id.

    Args:
        conversation: Conversation the exchange belongs to.
        content: Raw user prompt.

    Yields:
        SSE frames (strings) produced by ``sse_event``.
    """
    user_message = append_user_message(conversation, content)
    trigger = evaluate_file_summary_trigger(conversation, content)
    attachment_reader_trigger = evaluate_attachment_reader_trigger(conversation, content)

    yield sse_event(
        "meta",
        {
            "conversation_id": conversation.pk,
            "title": conversation.title or build_conversation_title(content),
            "user_message_id": user_message.pk,
            "user_message": user_message.content,
        },
    )

    if trigger.reason == "missing_attachment":
        yield from _canned_reply_events(
            conversation,
            "请先在当前对话右侧上传需要汇总的文件或压缩包,然后再发送自动汇总指令。",
        )
        return

    if attachment_reader_trigger.reason == "missing_attachment":
        yield from _canned_reply_events(
            conversation,
            "请先在当前对话右侧上传需要阅读的附件,然后再发送解析或阅读附件指令。",
        )
        return

    if attachment_reader_trigger.should_start:
        attachments = _select_attachments_for_reader(conversation, content)
        result = AttachmentReaderSkill().run_for_attachments(attachments)
        reply_content = _format_attachment_reader_reply(
            result.data.get("attachments", []),
            result.message,
        )
        yield from _canned_reply_events(conversation, reply_content)
        return

    if trigger.should_start:
        batch = create_file_summary_batch(
            conversation=conversation,
            user=conversation.user,
            trigger_message=user_message,
        )
        start_file_summary_workflow(
            batch,
            # Runs asynchronously unless settings.FILE_SUMMARY_ASYNC says otherwise.
            async_run=getattr(settings, "FILE_SUMMARY_ASYNC", True),
        )
        reply_content = f"已启动文件目录与页数自动汇总工作流,批次号:{batch.batch_no}。"
        # The confirmation is stored before any workflow event is emitted, so
        # this route spells out the sequence instead of using the helper.
        assistant_message = append_assistant_message(conversation, reply_content)
        yield sse_event(
            "workflow_started",
            {
                "workflow_type": "file_summary",
                "batch_id": batch.pk,
                "batch_no": batch.batch_no,
            },
        )
        yield sse_event("chunk", {"delta": reply_content})
        yield _done_event(conversation, assistant_message)
        return

    assistant_parts: list[str] = []
    try:
        for chunk in stream_reply(conversation, content):
            assistant_parts.append(chunk)
            yield sse_event("chunk", {"delta": chunk})
    except (LLMConfigurationError, LLMRequestError) as exc:
        fallback = f"模型调用失败:{exc}"
        # The stored reply becomes the error text alone; chunks already sent
        # to the client are not persisted.
        assistant_parts = [fallback]
        yield sse_event("error", {"message": fallback})

    assistant_message = append_assistant_message(conversation, "".join(assistant_parts).strip())

    if conversation.title.startswith("新对话"):
        conversation.title = build_conversation_title(content)
        conversation.save(update_fields=["title", "updated_at"])

    yield _done_event(conversation, assistant_message)


def _done_event(conversation: Conversation, assistant_message: Message) -> str:
    """Builds the closing "done" frame for a stored assistant message."""
    return sse_event(
        "done",
        {
            "assistant_message_id": assistant_message.pk,
            "conversation_id": conversation.pk,
            "title": conversation.title,
        },
    )


def _canned_reply_events(conversation: Conversation, reply_content: str):
    """Stores a complete assistant reply, then yields its "chunk" and "done" frames."""
    assistant_message = append_assistant_message(conversation, reply_content)
    yield sse_event("chunk", {"delta": reply_content})
    yield _done_event(conversation, assistant_message)
|
||
|
||
|
||
def build_conversation_title(content: str) -> str:
    """Derive a short conversation title from ``content``.

    Runs of whitespace are collapsed to single spaces and the result is cut to
    24 characters; blank input yields the placeholder "新对话".
    """
    words = content.split()
    if words:
        return " ".join(words)[:24]
    return "新对话"
|
||
|
||
|
||
def _select_attachments_for_reader(conversation: Conversation, content: str):
    """Picks the attachments of ``conversation`` that the reader should process.

    Candidates are the conversation's active attachments whose upload status is
    not DELETED, ordered by ``original_name`` and then by descending
    ``version_no``. Attachments whose original file name appears verbatim in
    ``content`` are preferred; if none is named, all candidates are returned.

    Returns:
        A list of ``FileAttachment`` objects (possibly empty).
    """
    attachments = list(
        FileAttachment.objects.filter(
            conversation=conversation,
            is_active=True,
        )
        .exclude(upload_status=FileAttachment.UploadStatus.DELETED)
        .order_by("original_name", "-version_no")
    )
    # A blank name must never count as "mentioned": "" is a substring of every
    # string, and None would raise TypeError in the membership test.
    matched = [
        attachment
        for attachment in attachments
        if attachment.original_name and attachment.original_name in content
    ]
    return matched or attachments
|
||
|
||
|
||
def _format_attachment_reader_reply(attachments: list[dict[str, object]], message: str) -> str:
|
||
if not attachments:
|
||
return message or "当前对话没有可读取的附件。"
|
||
|
||
lines = ["## 附件解析结果"]
|
||
for item in attachments:
|
||
status = item.get("status", "")
|
||
filename = item.get("filename", "")
|
||
file_type = item.get("file_type", "")
|
||
lines.extend(
|
||
[
|
||
"",
|
||
f"### {filename}",
|
||
f"- 类型:{file_type or '未知'}",
|
||
f"- 状态:{status}",
|
||
]
|
||
)
|
||
if item.get("error_message"):
|
||
lines.append(f"- 错误:{item['error_message']}")
|
||
continue
|
||
|
||
preview = str(item.get("preview_text") or "").strip()
|
||
if preview:
|
||
lines.extend(["", "摘要预览:", "```text", preview, "```"])
|
||
|
||
sections = item.get("sections") or []
|
||
if sections:
|
||
lines.append("")
|
||
lines.append("结构详情:")
|
||
for section in sections[:8]:
|
||
if not isinstance(section, dict):
|
||
continue
|
||
section_type = section.get("type", "section")
|
||
name = section.get("name", "")
|
||
extra = ""
|
||
if "row_count" in section:
|
||
extra = f",{section['row_count']} 行"
|
||
if "column_count" in section:
|
||
extra += f",{section['column_count']} 列"
|
||
lines.append(f"- {name}({section_type}{extra})")
|
||
return "\n".join(lines).strip()
|
||
|
||
|
||
def sse_event(event_name: str, payload: dict[str, object]) -> str:
    """Serialize ``payload`` as JSON and wrap it in one server-sent event frame.

    The frame is an ``event:`` line, a ``data:`` line holding the JSON (non-ASCII
    characters are kept as-is), and a terminating blank line.
    """
    data = json.dumps(payload, ensure_ascii=False)
    frame_lines = [f"event: {event_name}", f"data: {data}", "", ""]
    return "\n".join(frame_lines)
|