feat(chat): 接入全局知识库上下文

This commit is contained in:
2026-06-08 21:38:12 +08:00
parent 5ecf78c5d6
commit 2244b69d62
5 changed files with 236 additions and 14 deletions

View File

@@ -16,7 +16,7 @@ class LLMRequestError(RuntimeError):
logger = logging.getLogger(__name__)
def generate_reply(conversation, user_message: str) -> str:
def generate_reply(conversation, user_message: str, knowledge_context: str = "") -> str:
"""Calls the SiliconFlow OpenAI-compatible chat endpoint and returns assistant text."""
if not settings.LLM_API_KEY:
@@ -26,7 +26,7 @@ def generate_reply(conversation, user_message: str) -> str:
payload = {
"model": settings.LLM_MODEL,
"messages": build_messages(conversation, user_message),
"messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
"temperature": 0.3,
}
body = json.dumps(payload).encode("utf-8")
@@ -98,7 +98,7 @@ def generate_completion(messages: list[dict[str, str]], *, temperature: float =
raise LLMRequestError("模型接口返回格式不符合预期。") from exc
def stream_reply(conversation, user_message: str):
def stream_reply(conversation, user_message: str, knowledge_context: str = ""):
"""Streams incremental assistant text from the SiliconFlow chat endpoint."""
if not settings.LLM_API_KEY:
@@ -108,7 +108,7 @@ def stream_reply(conversation, user_message: str):
payload = {
"model": settings.LLM_MODEL,
"messages": build_messages(conversation, user_message),
"messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
"temperature": 0.3,
"stream": True,
}
@@ -153,10 +153,21 @@ def stream_reply(conversation, user_message: str):
raise LLMRequestError(f"模型接口调用失败:{exc.reason}") from exc
def build_messages(conversation, latest_user_message: str) -> list[dict[str, str]]:
def build_messages(conversation, latest_user_message: str, knowledge_context: str = "") -> list[dict[str, str]]:
"""Builds system and conversation history messages for the provider call."""
messages = [{"role": "system", "content": system_prompt()}]
if knowledge_context.strip():
messages.append(
{
"role": "system",
"content": (
"以下是全局知识库检索到的材料片段。回答用户时优先依据这些片段;"
"如果片段不足以支持结论,请明确说明信息不足,不要编造。\n\n"
f"{knowledge_context.strip()}"
),
}
)
for message in conversation.messages.all():
messages.append({"role": message.role, "content": message.content})

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import json
import logging
from pathlib import Path
from django.db.models import Q, QuerySet
from django.conf import settings
@@ -9,8 +10,10 @@ from django.utils import timezone
from .file_summary.skills.attachment_reader import AttachmentReaderSkill
from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow
from .knowledge_base import search_knowledge_base
from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply
from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, Message
from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, KnowledgeBaseDocument, Message
from .regulatory_review.services.rag_index import extract_text_from_path
from .application_form_fill.workflow import (
create_application_form_fill_batch,
find_latest_successful_summary_batch as find_latest_successful_form_fill_summary_batch,
@@ -104,8 +107,9 @@ def send_message(conversation: Conversation, content: str) -> tuple[Message, Mes
"""Stores one user message and one provider-backed assistant reply."""
user_message = append_user_message(conversation, content)
knowledge_context = build_knowledge_context(content)
try:
reply_content = generate_reply(conversation, content)
reply_content = generate_reply(conversation, content, knowledge_context=knowledge_context)
except (LLMConfigurationError, LLMRequestError) as exc:
reply_content = f"模型调用失败:{exc}"
@@ -391,8 +395,9 @@ def stream_message(conversation: Conversation, content: str):
stream_failed = False
stream_error = ""
knowledge_context = build_knowledge_context(content)
try:
for chunk in stream_reply(conversation, content):
for chunk in stream_reply(conversation, content, knowledge_context=knowledge_context):
assistant_parts.append(chunk)
yield sse_event("chunk", {"delta": chunk})
except (LLMConfigurationError, LLMRequestError) as exc:
@@ -412,7 +417,7 @@ def stream_message(conversation: Conversation, content: str):
if stream_failed:
try:
fallback_reply = generate_reply(conversation, content)
fallback_reply = generate_reply(conversation, content, knowledge_context=knowledge_context)
assistant_parts = [fallback_reply]
logger.info(
"Non-stream fallback reply succeeded",
@@ -461,6 +466,118 @@ def build_conversation_title(content: str) -> str:
return normalized[:24]
def build_knowledge_context(content: str, *, n_results: int = 5) -> str:
    """Formats global knowledge-base search hits for normal chat prompts.

    A full-document context built from filename matches takes precedence; the
    knowledge-base search is only consulted when that context is empty.

    Args:
        content: The user's message, used both as the search query and as the
            source of relevance terms.
        n_results: Number of hits requested from the search and the maximum
            number of snippets included in the returned context.

    Returns:
        Numbered snippets separated by blank lines, or an empty string when
        nothing usable was found or the search failed.
    """
    full_document_context = build_filename_matched_document_context(content)
    if full_document_context:
        return full_document_context

    try:
        payload = search_knowledge_base(content, n_results=n_results)
    except Exception as exc:
        # Best effort: a failing search must not break the chat reply, so the
        # caller simply gets no knowledge context.
        logger.warning("Knowledge-base search failed", extra={"error": str(exc)})
        return ""

    error_message = payload.get("error_message")
    if error_message:
        # Log the reported error instead of dropping it silently.
        logger.warning(
            "Knowledge-base search returned an error",
            extra={"error": str(error_message)},
        )
        return ""

    results = [
        item
        for item in _rank_knowledge_results(content, payload.get("results") or [])
        if _is_relevant_knowledge_result(content, item)
    ]

    lines: list[str] = []
    for item in results[:n_results]:
        # Collapse all whitespace runs so each snippet is a single line.
        text = " ".join(str(item.get("text") or "").split())
        if not text:
            continue
        source = str(item.get("source") or "未知来源")
        score = item.get("score")
        # The leading space keeps the score from being glued onto the source name.
        score_label = f" score={score:.4f}" if isinstance(score, (int, float)) else ""
        # Number by emitted snippets so skipped empty hits leave no gaps.
        lines.append(f"[{len(lines) + 1}] 来源:{source}{score_label}\n{text[:1200]}")
    return "\n\n".join(lines)
def build_filename_matched_document_context(query: str, *, max_chars: int = 12000) -> str:
    """Builds a full-text context from documents whose names contain a query term.

    Active knowledge-base documents are scanned newest first. A document matches
    when any query term is a substring of its display name or original name. At
    most three matching documents are read.

    Args:
        query: The user's message; its terms come from `_knowledge_query_terms`.
        max_chars: Maximum number of characters of whitespace-normalized text
            kept per document.

    Returns:
        A header sentence followed by one numbered section per document, or an
        empty string when no document matched or no matched document yielded
        any text.
    """
    terms = [term for term in _knowledge_query_terms(query) if term]
    if not terms:
        return ""

    max_documents = 3
    matches = []
    documents = KnowledgeBaseDocument.objects.filter(
        status=KnowledgeBaseDocument.Status.ACTIVE,
        is_active=True,
    ).order_by("-updated_at", "-id")
    for document in documents:
        filename = f"{document.display_name} {document.original_name}"
        if any(term in filename for term in terms):
            matches.append(document)
            if len(matches) == max_documents:
                # Only the first few matches are ever read, so stop scanning.
                break

    sections: list[str] = []
    for document in matches:
        text = " ".join(_extract_managed_document_text(document).split())
        if not text:
            # Extraction failed or produced only whitespace; skip this document.
            continue
        # Number by emitted sections so skipped documents leave no gaps.
        sections.append(
            f"[全文材料 {len(sections) + 1}] 来源:用户知识库/{document.original_name}\n"
            f"{text[:max_chars]}"
        )

    if not sections:
        # Returning the header alone would be truthy and make the caller skip
        # its search fallback even though no material was actually read.
        return ""

    header = "以下材料因用户问题中的关键词命中文档名称,已读取全文供回答前比对和总结。"
    return "\n\n".join([header, *sections]).strip()
def _extract_managed_document_text(document: KnowledgeBaseDocument) -> str:
    """Returns the text extracted from the document's stored file, or "" on failure.

    Any exception raised while building the path or extracting the text is
    logged as a warning together with the document's primary key.
    """
    extracted = ""
    try:
        extracted = extract_text_from_path(Path(document.storage_path))
    except Exception as exc:
        failure_details = {"document_id": document.pk, "error": str(exc)}
        logger.warning(
            "Managed document full-text extraction failed",
            extra=failure_details,
        )
        extracted = ""
    return extracted
def _rank_knowledge_results(query: str, results: list[dict[str, object]]) -> list[dict[str, object]]:
    """Returns a new list of hits ordered for prompt inclusion.

    Hits whose source or text literally contains a query term come first.
    Within each group, hits are ordered by ascending numeric score, and hits
    without a numeric score sort last. The input list is not modified.
    """
    needles = [needle for needle in _knowledge_query_terms(query) if needle]

    def rank(hit: dict[str, object]) -> tuple[int, float]:
        searchable = "\n".join((str(hit.get("source") or ""), str(hit.get("text") or "")))
        group = 1
        for needle in needles:
            if needle in searchable:
                group = 0
                break
        raw_score = hit.get("score")
        if isinstance(raw_score, (int, float)):
            ordering_score = float(raw_score)
        else:
            # Placeholder that pushes unscored hits to the end of their group.
            ordering_score = 999999.0
        return (group, ordering_score)

    ranked = list(results)
    ranked.sort(key=rank)
    return ranked
def _is_relevant_knowledge_result(query: str, item: dict[str, object]) -> bool:
    """Decides whether a search hit should be included in the chat context.

    A hit is kept when its source or text contains any query term, or when its
    metadata marks it with source_type "managed_document". A query that yields
    no terms rejects every hit.
    """
    needles = _knowledge_query_terms(query)
    if not needles:
        return False

    searchable = "\n".join((str(item.get("source") or ""), str(item.get("text") or "")))
    for needle in needles:
        if needle in searchable:
            return True

    # No literal match: fall back to the source-type marker in the metadata.
    details = item.get("metadata") or {}
    return details.get("source_type") == "managed_document"
def _knowledge_query_terms(query: str) -> list[str]:
normalized = "".join((query or "").split())
if not normalized:
return []
stop_chars = set("是谁什么哪里如何怎么请问一下帮我你能告诉吗??,。.")
compact = "".join(char for char in normalized if char not in stop_chars)
terms = [compact] if compact else []
if normalized not in terms:
terms.append(normalized)
return terms
def _select_attachments_for_reader(conversation: Conversation, content: str):
attachments = list(
FileAttachment.objects.filter(