feat(chat): 接入全局知识库上下文
This commit is contained in:
@@ -16,7 +16,7 @@ class LLMRequestError(RuntimeError):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_reply(conversation, user_message: str) -> str:
|
||||
def generate_reply(conversation, user_message: str, knowledge_context: str = "") -> str:
|
||||
"""Calls the SiliconFlow OpenAI-compatible chat endpoint and returns assistant text."""
|
||||
|
||||
if not settings.LLM_API_KEY:
|
||||
@@ -26,7 +26,7 @@ def generate_reply(conversation, user_message: str) -> str:
|
||||
|
||||
payload = {
|
||||
"model": settings.LLM_MODEL,
|
||||
"messages": build_messages(conversation, user_message),
|
||||
"messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
|
||||
"temperature": 0.3,
|
||||
}
|
||||
body = json.dumps(payload).encode("utf-8")
|
||||
@@ -98,7 +98,7 @@ def generate_completion(messages: list[dict[str, str]], *, temperature: float =
|
||||
raise LLMRequestError("模型接口返回格式不符合预期。") from exc
|
||||
|
||||
|
||||
def stream_reply(conversation, user_message: str):
|
||||
def stream_reply(conversation, user_message: str, knowledge_context: str = ""):
|
||||
"""Streams incremental assistant text from the SiliconFlow chat endpoint."""
|
||||
|
||||
if not settings.LLM_API_KEY:
|
||||
@@ -108,7 +108,7 @@ def stream_reply(conversation, user_message: str):
|
||||
|
||||
payload = {
|
||||
"model": settings.LLM_MODEL,
|
||||
"messages": build_messages(conversation, user_message),
|
||||
"messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
|
||||
"temperature": 0.3,
|
||||
"stream": True,
|
||||
}
|
||||
@@ -153,10 +153,21 @@ def stream_reply(conversation, user_message: str):
|
||||
raise LLMRequestError(f"模型接口调用失败:{exc.reason}") from exc
|
||||
|
||||
|
||||
def build_messages(conversation, latest_user_message: str) -> list[dict[str, str]]:
|
||||
def build_messages(conversation, latest_user_message: str, knowledge_context: str = "") -> list[dict[str, str]]:
|
||||
"""Builds system and conversation history messages for the provider call."""
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt()}]
|
||||
if knowledge_context.strip():
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"以下是全局知识库检索到的材料片段。回答用户时优先依据这些片段;"
|
||||
"如果片段不足以支持结论,请明确说明信息不足,不要编造。\n\n"
|
||||
f"{knowledge_context.strip()}"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
for message in conversation.messages.all():
|
||||
messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from django.db.models import Q, QuerySet
|
||||
from django.conf import settings
|
||||
@@ -9,8 +10,10 @@ from django.utils import timezone
|
||||
|
||||
from .file_summary.skills.attachment_reader import AttachmentReaderSkill
|
||||
from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow
|
||||
from .knowledge_base import search_knowledge_base
|
||||
from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply
|
||||
from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, Message
|
||||
from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, KnowledgeBaseDocument, Message
|
||||
from .regulatory_review.services.rag_index import extract_text_from_path
|
||||
from .application_form_fill.workflow import (
|
||||
create_application_form_fill_batch,
|
||||
find_latest_successful_summary_batch as find_latest_successful_form_fill_summary_batch,
|
||||
@@ -104,8 +107,9 @@ def send_message(conversation: Conversation, content: str) -> tuple[Message, Mes
|
||||
"""Stores one user message and one provider-backed assistant reply."""
|
||||
|
||||
user_message = append_user_message(conversation, content)
|
||||
knowledge_context = build_knowledge_context(content)
|
||||
try:
|
||||
reply_content = generate_reply(conversation, content)
|
||||
reply_content = generate_reply(conversation, content, knowledge_context=knowledge_context)
|
||||
except (LLMConfigurationError, LLMRequestError) as exc:
|
||||
reply_content = f"模型调用失败:{exc}"
|
||||
|
||||
@@ -391,8 +395,9 @@ def stream_message(conversation: Conversation, content: str):
|
||||
|
||||
stream_failed = False
|
||||
stream_error = ""
|
||||
knowledge_context = build_knowledge_context(content)
|
||||
try:
|
||||
for chunk in stream_reply(conversation, content):
|
||||
for chunk in stream_reply(conversation, content, knowledge_context=knowledge_context):
|
||||
assistant_parts.append(chunk)
|
||||
yield sse_event("chunk", {"delta": chunk})
|
||||
except (LLMConfigurationError, LLMRequestError) as exc:
|
||||
@@ -412,7 +417,7 @@ def stream_message(conversation: Conversation, content: str):
|
||||
|
||||
if stream_failed:
|
||||
try:
|
||||
fallback_reply = generate_reply(conversation, content)
|
||||
fallback_reply = generate_reply(conversation, content, knowledge_context=knowledge_context)
|
||||
assistant_parts = [fallback_reply]
|
||||
logger.info(
|
||||
"Non-stream fallback reply succeeded",
|
||||
@@ -461,6 +466,118 @@ def build_conversation_title(content: str) -> str:
|
||||
return normalized[:24]
|
||||
|
||||
|
||||
def build_knowledge_context(content: str, *, n_results: int = 5) -> str:
    """Build the knowledge-base context string injected into a normal chat prompt.

    Full-text context from filename-matched documents takes precedence. When
    there is none, the knowledge base is searched and the ranked, relevant hits
    are rendered as numbered snippets.

    Args:
        content: The user's message, used as both filename and search query.
        n_results: Number of hits requested from the search and the maximum
            number of ranked hits considered for rendering.

    Returns:
        The formatted context, or an empty string when the search raises,
        reports an ``error_message``, or yields no usable snippet.
    """
    whole_documents = build_filename_matched_document_context(content)
    if whole_documents:
        return whole_documents

    # Best effort: a failing search must not break the chat request itself.
    try:
        response = search_knowledge_base(content, n_results=n_results)
    except Exception as exc:
        logger.warning("Knowledge-base search failed", extra={"error": str(exc)})
        return ""
    if response.get("error_message"):
        return ""

    ranked_hits = _rank_knowledge_results(content, response.get("results") or [])
    relevant_hits = [hit for hit in ranked_hits if _is_relevant_knowledge_result(content, hit)]

    snippets: list[str] = []
    position = 0
    for hit in relevant_hits[:n_results]:
        # The position counts every considered hit, including ones skipped
        # below for having no text, so labels may contain gaps.
        position += 1
        body = " ".join(str(hit.get("text") or "").split())
        if body:
            origin = str(hit.get("source") or "未知来源")
            raw_score = hit.get("score")
            if isinstance(raw_score, (int, float)):
                suffix = f",score={raw_score:.4f}"
            else:
                suffix = ""
            # Each snippet is capped at 1200 characters.
            snippets.append(f"[{position}] 来源:{origin}{suffix}\n{body[:1200]}")
    return "\n\n".join(snippets)
|
||||
|
||||
|
||||
def build_filename_matched_document_context(query: str, *, max_chars: int = 12000) -> str:
    """Return full-text context for active documents whose names contain the query.

    A document matches when any term from ``_knowledge_query_terms(query)`` is
    a substring of its ``display_name`` or ``original_name``. Documents are
    scanned newest first (``-updated_at``, then ``-id``) and at most the first
    three matches are read.

    Args:
        query: The user's message.
        max_chars: Maximum number of whitespace-normalized characters kept
            from each document.

    Returns:
        A header sentence followed by one section per readable document, or an
        empty string when there are no terms, no filename matches, or none of
        the matched documents yields any text. Returning "" in the last case
        lets the caller fall back to its other context sources instead of
        receiving a header with nothing under it.
    """
    terms = [term for term in _knowledge_query_terms(query) if term]
    if not terms:
        return ""

    documents = KnowledgeBaseDocument.objects.filter(
        status=KnowledgeBaseDocument.Status.ACTIVE,
        is_active=True,
    ).order_by("-updated_at", "-id")
    matches = [
        document
        for document in documents
        if any(term in f"{document.display_name} {document.original_name}" for term in terms)
    ]
    if not matches:
        return ""

    sections: list[str] = []
    for index, document in enumerate(matches[:3], start=1):
        # Collapse whitespace before the emptiness check so whitespace-only
        # extractions are skipped rather than emitted as empty sections.
        normalized_text = " ".join((_extract_managed_document_text(document) or "").split())
        if not normalized_text:
            continue
        sections.append(
            f"[全文材料 {index}] 来源:用户知识库/{document.original_name}\n"
            f"{normalized_text[:max_chars]}"
        )
    if not sections:
        # Every matched document failed extraction or was empty.
        return ""

    header = "以下材料因用户问题中的关键词命中文档名称,已读取全文供回答前比对和总结。"
    return "\n\n".join([header, *sections]).strip()
|
||||
|
||||
|
||||
def _extract_managed_document_text(document: KnowledgeBaseDocument) -> str:
    """Read the text of a managed document, returning "" on any failure.

    The document's ``storage_path`` is passed to ``extract_text_from_path``.
    Any exception raised while building the path or extracting is logged as a
    warning together with the document's primary key and is not propagated.
    """
    try:
        extracted = extract_text_from_path(Path(document.storage_path))
    except Exception as exc:
        logger.warning(
            "Managed document full-text extraction failed",
            extra={"document_id": document.pk, "error": str(exc)},
        )
        extracted = ""
    return extracted
|
||||
|
||||
|
||||
def _rank_knowledge_results(query: str, results: list[dict[str, object]]) -> list[dict[str, object]]:
    """Return ``results`` sorted so that direct term hits come first.

    Hits whose source or text contains any query term form the first tier and
    all other hits the second. Within a tier, hits are ordered by ascending
    numeric ``score``; hits without a numeric score sort last.
    NOTE(review): ascending order presumably treats the score as a distance
    where lower is better; confirm against the search backend.
    """
    needles = [needle for needle in _knowledge_query_terms(query) if needle]

    def priority(hit: dict[str, object]) -> tuple[int, float]:
        searchable = "\n".join((str(hit.get("source") or ""), str(hit.get("text") or "")))
        tier = 1
        for needle in needles:
            if needle in searchable:
                tier = 0
                break
        raw_score = hit.get("score")
        if isinstance(raw_score, (int, float)):
            ordering_score = float(raw_score)
        else:
            # Sentinel that pushes hits without a numeric score to the end of their tier.
            ordering_score = 999999.0
        return (tier, ordering_score)

    return sorted(results, key=priority)
|
||||
|
||||
|
||||
def _is_relevant_knowledge_result(query: str, item: dict[str, object]) -> bool:
    """Decide whether a search hit should be kept for the chat context.

    A hit is kept when any query term occurs in its source or text, or when
    its metadata marks it with ``source_type == "managed_document"``. A query
    that produces no terms keeps nothing.
    """
    needles = _knowledge_query_terms(query)
    if not needles:
        return False

    searchable = "\n".join((str(item.get("source") or ""), str(item.get("text") or "")))
    for needle in needles:
        if needle in searchable:
            return True

    # Hits from managed documents are kept even without a literal term match.
    details = item.get("metadata") or {}
    return details.get("source_type") == "managed_document"
|
||||
|
||||
|
||||
def _knowledge_query_terms(query: str) -> list[str]:
|
||||
normalized = "".join((query or "").split())
|
||||
if not normalized:
|
||||
return []
|
||||
stop_chars = set("是谁什么哪里如何怎么请问一下帮我你能告诉吗??,,。.")
|
||||
compact = "".join(char for char in normalized if char not in stop_chars)
|
||||
terms = [compact] if compact else []
|
||||
if normalized not in terms:
|
||||
terms.append(normalized)
|
||||
return terms
|
||||
|
||||
|
||||
def _select_attachments_for_reader(conversation: Conversation, content: str):
|
||||
attachments = list(
|
||||
FileAttachment.objects.filter(
|
||||
|
||||
Reference in New Issue
Block a user