feat(chat): 接入全局知识库上下文

This commit is contained in:
2026-06-08 21:38:12 +08:00
parent 5ecf78c5d6
commit 2244b69d62
5 changed files with 236 additions and 14 deletions

View File

@@ -16,7 +16,7 @@ class LLMRequestError(RuntimeError):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def generate_reply(conversation, user_message: str) -> str: def generate_reply(conversation, user_message: str, knowledge_context: str = "") -> str:
"""Calls the SiliconFlow OpenAI-compatible chat endpoint and returns assistant text.""" """Calls the SiliconFlow OpenAI-compatible chat endpoint and returns assistant text."""
if not settings.LLM_API_KEY: if not settings.LLM_API_KEY:
@@ -26,7 +26,7 @@ def generate_reply(conversation, user_message: str) -> str:
payload = { payload = {
"model": settings.LLM_MODEL, "model": settings.LLM_MODEL,
"messages": build_messages(conversation, user_message), "messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
"temperature": 0.3, "temperature": 0.3,
} }
body = json.dumps(payload).encode("utf-8") body = json.dumps(payload).encode("utf-8")
@@ -98,7 +98,7 @@ def generate_completion(messages: list[dict[str, str]], *, temperature: float =
raise LLMRequestError("模型接口返回格式不符合预期。") from exc raise LLMRequestError("模型接口返回格式不符合预期。") from exc
def stream_reply(conversation, user_message: str): def stream_reply(conversation, user_message: str, knowledge_context: str = ""):
"""Streams incremental assistant text from the SiliconFlow chat endpoint.""" """Streams incremental assistant text from the SiliconFlow chat endpoint."""
if not settings.LLM_API_KEY: if not settings.LLM_API_KEY:
@@ -108,7 +108,7 @@ def stream_reply(conversation, user_message: str):
payload = { payload = {
"model": settings.LLM_MODEL, "model": settings.LLM_MODEL,
"messages": build_messages(conversation, user_message), "messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
"temperature": 0.3, "temperature": 0.3,
"stream": True, "stream": True,
} }
@@ -153,10 +153,21 @@ def stream_reply(conversation, user_message: str):
raise LLMRequestError(f"模型接口调用失败:{exc.reason}") from exc raise LLMRequestError(f"模型接口调用失败:{exc.reason}") from exc
def build_messages(conversation, latest_user_message: str) -> list[dict[str, str]]: def build_messages(conversation, latest_user_message: str, knowledge_context: str = "") -> list[dict[str, str]]:
"""Builds system and conversation history messages for the provider call.""" """Builds system and conversation history messages for the provider call."""
messages = [{"role": "system", "content": system_prompt()}] messages = [{"role": "system", "content": system_prompt()}]
if knowledge_context.strip():
messages.append(
{
"role": "system",
"content": (
"以下是全局知识库检索到的材料片段。回答用户时优先依据这些片段;"
"如果片段不足以支持结论,请明确说明信息不足,不要编造。\n\n"
f"{knowledge_context.strip()}"
),
}
)
for message in conversation.messages.all(): for message in conversation.messages.all():
messages.append({"role": message.role, "content": message.content}) messages.append({"role": message.role, "content": message.content})

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import json import json
import logging import logging
from pathlib import Path
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from django.conf import settings from django.conf import settings
@@ -9,8 +10,10 @@ from django.utils import timezone
from .file_summary.skills.attachment_reader import AttachmentReaderSkill from .file_summary.skills.attachment_reader import AttachmentReaderSkill
from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow
from .knowledge_base import search_knowledge_base
from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply
from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, Message from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, KnowledgeBaseDocument, Message
from .regulatory_review.services.rag_index import extract_text_from_path
from .application_form_fill.workflow import ( from .application_form_fill.workflow import (
create_application_form_fill_batch, create_application_form_fill_batch,
find_latest_successful_summary_batch as find_latest_successful_form_fill_summary_batch, find_latest_successful_summary_batch as find_latest_successful_form_fill_summary_batch,
@@ -104,8 +107,9 @@ def send_message(conversation: Conversation, content: str) -> tuple[Message, Mes
"""Stores one user message and one provider-backed assistant reply.""" """Stores one user message and one provider-backed assistant reply."""
user_message = append_user_message(conversation, content) user_message = append_user_message(conversation, content)
knowledge_context = build_knowledge_context(content)
try: try:
reply_content = generate_reply(conversation, content) reply_content = generate_reply(conversation, content, knowledge_context=knowledge_context)
except (LLMConfigurationError, LLMRequestError) as exc: except (LLMConfigurationError, LLMRequestError) as exc:
reply_content = f"模型调用失败:{exc}" reply_content = f"模型调用失败:{exc}"
@@ -391,8 +395,9 @@ def stream_message(conversation: Conversation, content: str):
stream_failed = False stream_failed = False
stream_error = "" stream_error = ""
knowledge_context = build_knowledge_context(content)
try: try:
for chunk in stream_reply(conversation, content): for chunk in stream_reply(conversation, content, knowledge_context=knowledge_context):
assistant_parts.append(chunk) assistant_parts.append(chunk)
yield sse_event("chunk", {"delta": chunk}) yield sse_event("chunk", {"delta": chunk})
except (LLMConfigurationError, LLMRequestError) as exc: except (LLMConfigurationError, LLMRequestError) as exc:
@@ -412,7 +417,7 @@ def stream_message(conversation: Conversation, content: str):
if stream_failed: if stream_failed:
try: try:
fallback_reply = generate_reply(conversation, content) fallback_reply = generate_reply(conversation, content, knowledge_context=knowledge_context)
assistant_parts = [fallback_reply] assistant_parts = [fallback_reply]
logger.info( logger.info(
"Non-stream fallback reply succeeded", "Non-stream fallback reply succeeded",
@@ -461,6 +466,118 @@ def build_conversation_title(content: str) -> str:
return normalized[:24] return normalized[:24]
def build_knowledge_context(content: str, *, n_results: int = 5) -> str:
"""Formats global knowledge-base search hits for normal chat prompts."""
full_document_context = build_filename_matched_document_context(content)
if full_document_context:
return full_document_context
try:
payload = search_knowledge_base(content, n_results=n_results)
except Exception as exc:
logger.warning("Knowledge-base search failed", extra={"error": str(exc)})
return ""
if payload.get("error_message"):
return ""
results = [
item
for item in _rank_knowledge_results(content, payload.get("results") or [])
if _is_relevant_knowledge_result(content, item)
]
lines: list[str] = []
for index, item in enumerate(results[:n_results], start=1):
text = " ".join(str(item.get("text") or "").split())
if not text:
continue
source = str(item.get("source") or "未知来源")
score = item.get("score")
score_label = f"score={score:.4f}" if isinstance(score, (int, float)) else ""
lines.append(f"[{index}] 来源:{source}{score_label}\n{text[:1200]}")
return "\n\n".join(lines)
def build_filename_matched_document_context(query: str, *, max_chars: int = 12000) -> str:
terms = _knowledge_query_terms(query)
if not terms:
return ""
matches = []
for document in KnowledgeBaseDocument.objects.filter(
status=KnowledgeBaseDocument.Status.ACTIVE,
is_active=True,
).order_by("-updated_at", "-id"):
filename = f"{document.display_name} {document.original_name}"
if any(term and term in filename for term in terms):
matches.append(document)
if not matches:
return ""
lines = [
"以下材料因用户问题中的关键词命中文档名称,已读取全文供回答前比对和总结。"
]
for index, document in enumerate(matches[:3], start=1):
text = _extract_managed_document_text(document)
if not text:
continue
lines.append(
f"[全文材料 {index}] 来源:用户知识库/{document.original_name}\n"
f"{' '.join(text.split())[:max_chars]}"
)
return "\n\n".join(lines).strip()
def _extract_managed_document_text(document: KnowledgeBaseDocument) -> str:
try:
return extract_text_from_path(Path(document.storage_path))
except Exception as exc:
logger.warning(
"Managed document full-text extraction failed",
extra={"document_id": document.pk, "error": str(exc)},
)
return ""
def _rank_knowledge_results(query: str, results: list[dict[str, object]]) -> list[dict[str, object]]:
terms = [term for term in _knowledge_query_terms(query) if term]
def sort_key(item: dict[str, object]) -> tuple[int, float]:
source = str(item.get("source") or "")
text = str(item.get("text") or "")
haystack = f"{source}\n{text}"
direct_hit = any(term in haystack for term in terms)
score = item.get("score")
numeric_score = float(score) if isinstance(score, (int, float)) else 999999.0
return (0 if direct_hit else 1, numeric_score)
return sorted(results, key=sort_key)
def _is_relevant_knowledge_result(query: str, item: dict[str, object]) -> bool:
terms = _knowledge_query_terms(query)
if not terms:
return False
source = str(item.get("source") or "")
text = str(item.get("text") or "")
haystack = f"{source}\n{text}"
if any(term in haystack for term in terms):
return True
metadata = item.get("metadata") or {}
if metadata.get("source_type") == "managed_document":
return True
return False
def _knowledge_query_terms(query: str) -> list[str]:
normalized = "".join((query or "").split())
if not normalized:
return []
stop_chars = set("是谁什么哪里如何怎么请问一下帮我你能告诉吗??,。.")
compact = "".join(char for char in normalized if char not in stop_chars)
terms = [compact] if compact else []
if normalized not in terms:
terms.append(normalized)
return terms
def _select_attachments_for_reader(conversation: Conversation, content: str): def _select_attachments_for_reader(conversation: Conversation, content: str):
attachments = list( attachments = list(
FileAttachment.objects.filter( FileAttachment.objects.filter(

View File

@@ -0,0 +1,59 @@
import pytest
from review_agent.models import KnowledgeBaseDocument
from review_agent.services import build_knowledge_context
pytestmark = pytest.mark.django_db
def test_build_knowledge_context_ignores_irrelevant_rag_chunks(monkeypatch):
monkeypatch.setattr(
"review_agent.services.search_knowledge_base",
lambda query, n_results=5: {
"query": query,
"results": [
{
"source": "附件 4 体外诊断试剂注册申报资料要求及说明.doc",
"text": "预期用途应明确产品用于检测的分析物和功能。",
"score": 7.636,
"metadata": {"source_type": "regulatory_document"},
}
],
"error_message": "",
},
)
context = build_knowledge_context("孙之烨是谁")
assert context == ""
def test_build_knowledge_context_uses_full_document_when_name_matches(settings, tmp_path, monkeypatch, django_user_model):
settings.MEDIA_ROOT = tmp_path
user = django_user_model.objects.create_user(username="owner", password="pass")
document_path = tmp_path / "resume.txt"
document_path.write_text(
"孙之烨,负责审核智能体项目。\n完整经历:曾组织技术分享并带队参加竞赛。",
encoding="utf-8",
)
KnowledgeBaseDocument.objects.create(
user=user,
display_name="孙之烨简历",
original_name="孙之烨-260510.txt",
storage_path=str(document_path),
file_size=document_path.stat().st_size,
status=KnowledgeBaseDocument.Status.ACTIVE,
is_active=True,
indexed_chunk_count=2,
)
monkeypatch.setattr(
"review_agent.services.search_knowledge_base",
lambda query, n_results=5: {"query": query, "results": [], "error_message": ""},
)
context = build_knowledge_context("孙之烨是谁")
assert "全文材料" in context
assert "来源:用户知识库/孙之烨-260510.txt" in context
assert "完整经历:曾组织技术分享并带队参加竞赛" in context

View File

@@ -201,17 +201,36 @@ def test_stream_message_returns_workflow_meta_when_triggered(settings, django_us
def test_stream_message_uses_normal_llm_path_when_not_triggered(monkeypatch, django_user_model): def test_stream_message_uses_normal_llm_path_when_not_triggered(monkeypatch, django_user_model):
user = django_user_model.objects.create_user(username="owner", password="pass") user = django_user_model.objects.create_user(username="owner", password="pass")
conversation = Conversation.objects.create(user=user, title="会话") conversation = Conversation.objects.create(user=user, title="会话")
calls = []
def fake_stream_reply(conversation, content): def fake_stream_reply(conversation, content, knowledge_context=""):
calls.append(knowledge_context)
yield "普通回复" yield "普通回复"
monkeypatch.setattr("review_agent.services.stream_reply", fake_stream_reply) monkeypatch.setattr("review_agent.services.stream_reply", fake_stream_reply)
monkeypatch.setattr(
"review_agent.services.search_knowledge_base",
lambda query, n_results=3: {
"query": query,
"results": [
{
"source": "用户知识库/1/2/孙之烨-260510.pdf",
"text": "孙之烨负责审核智能体项目。",
"score": 0.23,
}
],
"error_message": "",
},
)
frames = list(stream_message(conversation, "你好")) frames = list(stream_message(conversation, "孙之烨是谁"))
joined = "".join(frames) joined = "".join(frames)
assert "普通回复" in joined assert "普通回复" in joined
assert "workflow_started" not in joined assert "workflow_started" not in joined
assert calls
assert "孙之烨负责审核智能体项目" in calls[0]
assert "用户知识库/1/2/孙之烨-260510.pdf" in calls[0]
def test_stream_message_meta_uses_first_prompt_title_for_new_conversation(monkeypatch, django_user_model): def test_stream_message_meta_uses_first_prompt_title_for_new_conversation(monkeypatch, django_user_model):
@@ -257,12 +276,15 @@ def test_stream_message_falls_back_to_non_stream_reply_when_stream_breaks(monkey
user = django_user_model.objects.create_user(username="owner", password="pass") user = django_user_model.objects.create_user(username="owner", password="pass")
conversation = Conversation.objects.create(user=user, title="会话") conversation = Conversation.objects.create(user=user, title="会话")
def broken_stream_reply(conversation, content): def broken_stream_reply(conversation, content, knowledge_context=""):
yield "已生成部分内容" yield "已生成部分内容"
raise RuntimeError("provider connection reset") raise RuntimeError("provider connection reset")
monkeypatch.setattr("review_agent.services.stream_reply", broken_stream_reply) monkeypatch.setattr("review_agent.services.stream_reply", broken_stream_reply)
monkeypatch.setattr("review_agent.services.generate_reply", lambda conversation, content: "非流式完整回复") monkeypatch.setattr(
"review_agent.services.generate_reply",
lambda conversation, content, knowledge_context="": "非流式完整回复",
)
frames = list(stream_message(conversation, "普通问题")) frames = list(stream_message(conversation, "普通问题"))

View File

@@ -3,7 +3,7 @@ from urllib import request
import pytest import pytest
from review_agent.llm import stream_reply from review_agent.llm import build_messages, stream_reply
from review_agent.models import Conversation from review_agent.models import Conversation
@@ -39,3 +39,16 @@ def test_stream_reply_skips_malformed_sse_data(monkeypatch, settings, django_use
chunks = list(stream_reply(conversation, "你好")) chunks = list(stream_reply(conversation, "你好"))
assert chunks == ["A", "B"] assert chunks == ["A", "B"]
def test_build_messages_includes_knowledge_context(django_user_model):
user = django_user_model.objects.create_user(username="owner", password="pass")
conversation = Conversation.objects.create(user=user, title="会话")
messages = build_messages(conversation, "孙之烨是谁", knowledge_context="来源:简历\n孙之烨负责审核智能体项目。")
assert messages[0]["role"] == "system"
assert messages[1]["role"] == "system"
assert "全局知识库" in messages[1]["content"]
assert "孙之烨负责审核智能体项目" in messages[1]["content"]
assert messages[-1] == {"role": "user", "content": "孙之烨是谁"}