feat(chat): 接入全局知识库上下文
This commit is contained in:
@@ -16,7 +16,7 @@ class LLMRequestError(RuntimeError):
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def generate_reply(conversation, user_message: str) -> str:
|
def generate_reply(conversation, user_message: str, knowledge_context: str = "") -> str:
|
||||||
"""Calls the SiliconFlow OpenAI-compatible chat endpoint and returns assistant text."""
|
"""Calls the SiliconFlow OpenAI-compatible chat endpoint and returns assistant text."""
|
||||||
|
|
||||||
if not settings.LLM_API_KEY:
|
if not settings.LLM_API_KEY:
|
||||||
@@ -26,7 +26,7 @@ def generate_reply(conversation, user_message: str) -> str:
|
|||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": settings.LLM_MODEL,
|
"model": settings.LLM_MODEL,
|
||||||
"messages": build_messages(conversation, user_message),
|
"messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
}
|
}
|
||||||
body = json.dumps(payload).encode("utf-8")
|
body = json.dumps(payload).encode("utf-8")
|
||||||
@@ -98,7 +98,7 @@ def generate_completion(messages: list[dict[str, str]], *, temperature: float =
|
|||||||
raise LLMRequestError("模型接口返回格式不符合预期。") from exc
|
raise LLMRequestError("模型接口返回格式不符合预期。") from exc
|
||||||
|
|
||||||
|
|
||||||
def stream_reply(conversation, user_message: str):
|
def stream_reply(conversation, user_message: str, knowledge_context: str = ""):
|
||||||
"""Streams incremental assistant text from the SiliconFlow chat endpoint."""
|
"""Streams incremental assistant text from the SiliconFlow chat endpoint."""
|
||||||
|
|
||||||
if not settings.LLM_API_KEY:
|
if not settings.LLM_API_KEY:
|
||||||
@@ -108,7 +108,7 @@ def stream_reply(conversation, user_message: str):
|
|||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": settings.LLM_MODEL,
|
"model": settings.LLM_MODEL,
|
||||||
"messages": build_messages(conversation, user_message),
|
"messages": build_messages(conversation, user_message, knowledge_context=knowledge_context),
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
@@ -153,10 +153,21 @@ def stream_reply(conversation, user_message: str):
|
|||||||
raise LLMRequestError(f"模型接口调用失败:{exc.reason}") from exc
|
raise LLMRequestError(f"模型接口调用失败:{exc.reason}") from exc
|
||||||
|
|
||||||
|
|
||||||
def build_messages(conversation, latest_user_message: str) -> list[dict[str, str]]:
|
def build_messages(conversation, latest_user_message: str, knowledge_context: str = "") -> list[dict[str, str]]:
|
||||||
"""Builds system and conversation history messages for the provider call."""
|
"""Builds system and conversation history messages for the provider call."""
|
||||||
|
|
||||||
messages = [{"role": "system", "content": system_prompt()}]
|
messages = [{"role": "system", "content": system_prompt()}]
|
||||||
|
if knowledge_context.strip():
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"以下是全局知识库检索到的材料片段。回答用户时优先依据这些片段;"
|
||||||
|
"如果片段不足以支持结论,请明确说明信息不足,不要编造。\n\n"
|
||||||
|
f"{knowledge_context.strip()}"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
for message in conversation.messages.all():
|
for message in conversation.messages.all():
|
||||||
messages.append({"role": message.role, "content": message.content})
|
messages.append({"role": message.role, "content": message.content})
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from django.db.models import Q, QuerySet
|
from django.db.models import Q, QuerySet
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
@@ -9,8 +10,10 @@ from django.utils import timezone
|
|||||||
|
|
||||||
from .file_summary.skills.attachment_reader import AttachmentReaderSkill
|
from .file_summary.skills.attachment_reader import AttachmentReaderSkill
|
||||||
from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow
|
from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow
|
||||||
|
from .knowledge_base import search_knowledge_base
|
||||||
from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply
|
from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply
|
||||||
from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, Message
|
from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, KnowledgeBaseDocument, Message
|
||||||
|
from .regulatory_review.services.rag_index import extract_text_from_path
|
||||||
from .application_form_fill.workflow import (
|
from .application_form_fill.workflow import (
|
||||||
create_application_form_fill_batch,
|
create_application_form_fill_batch,
|
||||||
find_latest_successful_summary_batch as find_latest_successful_form_fill_summary_batch,
|
find_latest_successful_summary_batch as find_latest_successful_form_fill_summary_batch,
|
||||||
@@ -104,8 +107,9 @@ def send_message(conversation: Conversation, content: str) -> tuple[Message, Mes
|
|||||||
"""Stores one user message and one provider-backed assistant reply."""
|
"""Stores one user message and one provider-backed assistant reply."""
|
||||||
|
|
||||||
user_message = append_user_message(conversation, content)
|
user_message = append_user_message(conversation, content)
|
||||||
|
knowledge_context = build_knowledge_context(content)
|
||||||
try:
|
try:
|
||||||
reply_content = generate_reply(conversation, content)
|
reply_content = generate_reply(conversation, content, knowledge_context=knowledge_context)
|
||||||
except (LLMConfigurationError, LLMRequestError) as exc:
|
except (LLMConfigurationError, LLMRequestError) as exc:
|
||||||
reply_content = f"模型调用失败:{exc}"
|
reply_content = f"模型调用失败:{exc}"
|
||||||
|
|
||||||
@@ -391,8 +395,9 @@ def stream_message(conversation: Conversation, content: str):
|
|||||||
|
|
||||||
stream_failed = False
|
stream_failed = False
|
||||||
stream_error = ""
|
stream_error = ""
|
||||||
|
knowledge_context = build_knowledge_context(content)
|
||||||
try:
|
try:
|
||||||
for chunk in stream_reply(conversation, content):
|
for chunk in stream_reply(conversation, content, knowledge_context=knowledge_context):
|
||||||
assistant_parts.append(chunk)
|
assistant_parts.append(chunk)
|
||||||
yield sse_event("chunk", {"delta": chunk})
|
yield sse_event("chunk", {"delta": chunk})
|
||||||
except (LLMConfigurationError, LLMRequestError) as exc:
|
except (LLMConfigurationError, LLMRequestError) as exc:
|
||||||
@@ -412,7 +417,7 @@ def stream_message(conversation: Conversation, content: str):
|
|||||||
|
|
||||||
if stream_failed:
|
if stream_failed:
|
||||||
try:
|
try:
|
||||||
fallback_reply = generate_reply(conversation, content)
|
fallback_reply = generate_reply(conversation, content, knowledge_context=knowledge_context)
|
||||||
assistant_parts = [fallback_reply]
|
assistant_parts = [fallback_reply]
|
||||||
logger.info(
|
logger.info(
|
||||||
"Non-stream fallback reply succeeded",
|
"Non-stream fallback reply succeeded",
|
||||||
@@ -461,6 +466,118 @@ def build_conversation_title(content: str) -> str:
|
|||||||
return normalized[:24]
|
return normalized[:24]
|
||||||
|
|
||||||
|
|
||||||
|
def build_knowledge_context(content: str, *, n_results: int = 5) -> str:
|
||||||
|
"""Formats global knowledge-base search hits for normal chat prompts."""
|
||||||
|
|
||||||
|
full_document_context = build_filename_matched_document_context(content)
|
||||||
|
if full_document_context:
|
||||||
|
return full_document_context
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = search_knowledge_base(content, n_results=n_results)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Knowledge-base search failed", extra={"error": str(exc)})
|
||||||
|
return ""
|
||||||
|
if payload.get("error_message"):
|
||||||
|
return ""
|
||||||
|
results = [
|
||||||
|
item
|
||||||
|
for item in _rank_knowledge_results(content, payload.get("results") or [])
|
||||||
|
if _is_relevant_knowledge_result(content, item)
|
||||||
|
]
|
||||||
|
lines: list[str] = []
|
||||||
|
for index, item in enumerate(results[:n_results], start=1):
|
||||||
|
text = " ".join(str(item.get("text") or "").split())
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
source = str(item.get("source") or "未知来源")
|
||||||
|
score = item.get("score")
|
||||||
|
score_label = f",score={score:.4f}" if isinstance(score, (int, float)) else ""
|
||||||
|
lines.append(f"[{index}] 来源:{source}{score_label}\n{text[:1200]}")
|
||||||
|
return "\n\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def build_filename_matched_document_context(query: str, *, max_chars: int = 12000) -> str:
|
||||||
|
terms = _knowledge_query_terms(query)
|
||||||
|
if not terms:
|
||||||
|
return ""
|
||||||
|
matches = []
|
||||||
|
for document in KnowledgeBaseDocument.objects.filter(
|
||||||
|
status=KnowledgeBaseDocument.Status.ACTIVE,
|
||||||
|
is_active=True,
|
||||||
|
).order_by("-updated_at", "-id"):
|
||||||
|
filename = f"{document.display_name} {document.original_name}"
|
||||||
|
if any(term and term in filename for term in terms):
|
||||||
|
matches.append(document)
|
||||||
|
if not matches:
|
||||||
|
return ""
|
||||||
|
lines = [
|
||||||
|
"以下材料因用户问题中的关键词命中文档名称,已读取全文供回答前比对和总结。"
|
||||||
|
]
|
||||||
|
for index, document in enumerate(matches[:3], start=1):
|
||||||
|
text = _extract_managed_document_text(document)
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
lines.append(
|
||||||
|
f"[全文材料 {index}] 来源:用户知识库/{document.original_name}\n"
|
||||||
|
f"{' '.join(text.split())[:max_chars]}"
|
||||||
|
)
|
||||||
|
return "\n\n".join(lines).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_managed_document_text(document: KnowledgeBaseDocument) -> str:
|
||||||
|
try:
|
||||||
|
return extract_text_from_path(Path(document.storage_path))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Managed document full-text extraction failed",
|
||||||
|
extra={"document_id": document.pk, "error": str(exc)},
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _rank_knowledge_results(query: str, results: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||||
|
terms = [term for term in _knowledge_query_terms(query) if term]
|
||||||
|
|
||||||
|
def sort_key(item: dict[str, object]) -> tuple[int, float]:
|
||||||
|
source = str(item.get("source") or "")
|
||||||
|
text = str(item.get("text") or "")
|
||||||
|
haystack = f"{source}\n{text}"
|
||||||
|
direct_hit = any(term in haystack for term in terms)
|
||||||
|
score = item.get("score")
|
||||||
|
numeric_score = float(score) if isinstance(score, (int, float)) else 999999.0
|
||||||
|
return (0 if direct_hit else 1, numeric_score)
|
||||||
|
|
||||||
|
return sorted(results, key=sort_key)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_relevant_knowledge_result(query: str, item: dict[str, object]) -> bool:
|
||||||
|
terms = _knowledge_query_terms(query)
|
||||||
|
if not terms:
|
||||||
|
return False
|
||||||
|
source = str(item.get("source") or "")
|
||||||
|
text = str(item.get("text") or "")
|
||||||
|
haystack = f"{source}\n{text}"
|
||||||
|
if any(term in haystack for term in terms):
|
||||||
|
return True
|
||||||
|
metadata = item.get("metadata") or {}
|
||||||
|
if metadata.get("source_type") == "managed_document":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _knowledge_query_terms(query: str) -> list[str]:
|
||||||
|
normalized = "".join((query or "").split())
|
||||||
|
if not normalized:
|
||||||
|
return []
|
||||||
|
stop_chars = set("是谁什么哪里如何怎么请问一下帮我你能告诉吗??,,。.")
|
||||||
|
compact = "".join(char for char in normalized if char not in stop_chars)
|
||||||
|
terms = [compact] if compact else []
|
||||||
|
if normalized not in terms:
|
||||||
|
terms.append(normalized)
|
||||||
|
return terms
|
||||||
|
|
||||||
|
|
||||||
def _select_attachments_for_reader(conversation: Conversation, content: str):
|
def _select_attachments_for_reader(conversation: Conversation, content: str):
|
||||||
attachments = list(
|
attachments = list(
|
||||||
FileAttachment.objects.filter(
|
FileAttachment.objects.filter(
|
||||||
|
|||||||
59
tests/test_chat_knowledge_context.py
Normal file
59
tests/test_chat_knowledge_context.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from review_agent.models import KnowledgeBaseDocument
|
||||||
|
from review_agent.services import build_knowledge_context
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_knowledge_context_ignores_irrelevant_rag_chunks(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"review_agent.services.search_knowledge_base",
|
||||||
|
lambda query, n_results=5: {
|
||||||
|
"query": query,
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"source": "附件 4 体外诊断试剂注册申报资料要求及说明.doc",
|
||||||
|
"text": "预期用途应明确产品用于检测的分析物和功能。",
|
||||||
|
"score": 7.636,
|
||||||
|
"metadata": {"source_type": "regulatory_document"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"error_message": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
context = build_knowledge_context("孙之烨是谁")
|
||||||
|
|
||||||
|
assert context == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_knowledge_context_uses_full_document_when_name_matches(settings, tmp_path, monkeypatch, django_user_model):
|
||||||
|
settings.MEDIA_ROOT = tmp_path
|
||||||
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
|
document_path = tmp_path / "resume.txt"
|
||||||
|
document_path.write_text(
|
||||||
|
"孙之烨,负责审核智能体项目。\n完整经历:曾组织技术分享并带队参加竞赛。",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
KnowledgeBaseDocument.objects.create(
|
||||||
|
user=user,
|
||||||
|
display_name="孙之烨简历",
|
||||||
|
original_name="孙之烨-260510.txt",
|
||||||
|
storage_path=str(document_path),
|
||||||
|
file_size=document_path.stat().st_size,
|
||||||
|
status=KnowledgeBaseDocument.Status.ACTIVE,
|
||||||
|
is_active=True,
|
||||||
|
indexed_chunk_count=2,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"review_agent.services.search_knowledge_base",
|
||||||
|
lambda query, n_results=5: {"query": query, "results": [], "error_message": ""},
|
||||||
|
)
|
||||||
|
|
||||||
|
context = build_knowledge_context("孙之烨是谁")
|
||||||
|
|
||||||
|
assert "全文材料" in context
|
||||||
|
assert "来源:用户知识库/孙之烨-260510.txt" in context
|
||||||
|
assert "完整经历:曾组织技术分享并带队参加竞赛" in context
|
||||||
@@ -201,17 +201,36 @@ def test_stream_message_returns_workflow_meta_when_triggered(settings, django_us
|
|||||||
def test_stream_message_uses_normal_llm_path_when_not_triggered(monkeypatch, django_user_model):
|
def test_stream_message_uses_normal_llm_path_when_not_triggered(monkeypatch, django_user_model):
|
||||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
conversation = Conversation.objects.create(user=user, title="会话")
|
conversation = Conversation.objects.create(user=user, title="会话")
|
||||||
|
calls = []
|
||||||
|
|
||||||
def fake_stream_reply(conversation, content):
|
def fake_stream_reply(conversation, content, knowledge_context=""):
|
||||||
|
calls.append(knowledge_context)
|
||||||
yield "普通回复"
|
yield "普通回复"
|
||||||
|
|
||||||
monkeypatch.setattr("review_agent.services.stream_reply", fake_stream_reply)
|
monkeypatch.setattr("review_agent.services.stream_reply", fake_stream_reply)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"review_agent.services.search_knowledge_base",
|
||||||
|
lambda query, n_results=3: {
|
||||||
|
"query": query,
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"source": "用户知识库/1/2/孙之烨-260510.pdf",
|
||||||
|
"text": "孙之烨负责审核智能体项目。",
|
||||||
|
"score": 0.23,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"error_message": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
frames = list(stream_message(conversation, "你好"))
|
frames = list(stream_message(conversation, "孙之烨是谁"))
|
||||||
|
|
||||||
joined = "".join(frames)
|
joined = "".join(frames)
|
||||||
assert "普通回复" in joined
|
assert "普通回复" in joined
|
||||||
assert "workflow_started" not in joined
|
assert "workflow_started" not in joined
|
||||||
|
assert calls
|
||||||
|
assert "孙之烨负责审核智能体项目" in calls[0]
|
||||||
|
assert "用户知识库/1/2/孙之烨-260510.pdf" in calls[0]
|
||||||
|
|
||||||
|
|
||||||
def test_stream_message_meta_uses_first_prompt_title_for_new_conversation(monkeypatch, django_user_model):
|
def test_stream_message_meta_uses_first_prompt_title_for_new_conversation(monkeypatch, django_user_model):
|
||||||
@@ -257,12 +276,15 @@ def test_stream_message_falls_back_to_non_stream_reply_when_stream_breaks(monkey
|
|||||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
conversation = Conversation.objects.create(user=user, title="会话")
|
conversation = Conversation.objects.create(user=user, title="会话")
|
||||||
|
|
||||||
def broken_stream_reply(conversation, content):
|
def broken_stream_reply(conversation, content, knowledge_context=""):
|
||||||
yield "已生成部分内容"
|
yield "已生成部分内容"
|
||||||
raise RuntimeError("provider connection reset")
|
raise RuntimeError("provider connection reset")
|
||||||
|
|
||||||
monkeypatch.setattr("review_agent.services.stream_reply", broken_stream_reply)
|
monkeypatch.setattr("review_agent.services.stream_reply", broken_stream_reply)
|
||||||
monkeypatch.setattr("review_agent.services.generate_reply", lambda conversation, content: "非流式完整回复")
|
monkeypatch.setattr(
|
||||||
|
"review_agent.services.generate_reply",
|
||||||
|
lambda conversation, content, knowledge_context="": "非流式完整回复",
|
||||||
|
)
|
||||||
|
|
||||||
frames = list(stream_message(conversation, "普通问题"))
|
frames = list(stream_message(conversation, "普通问题"))
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from urllib import request
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from review_agent.llm import stream_reply
|
from review_agent.llm import build_messages, stream_reply
|
||||||
from review_agent.models import Conversation
|
from review_agent.models import Conversation
|
||||||
|
|
||||||
|
|
||||||
@@ -39,3 +39,16 @@ def test_stream_reply_skips_malformed_sse_data(monkeypatch, settings, django_use
|
|||||||
chunks = list(stream_reply(conversation, "你好"))
|
chunks = list(stream_reply(conversation, "你好"))
|
||||||
|
|
||||||
assert chunks == ["A", "B"]
|
assert chunks == ["A", "B"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_messages_includes_knowledge_context(django_user_model):
|
||||||
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
|
conversation = Conversation.objects.create(user=user, title="会话")
|
||||||
|
|
||||||
|
messages = build_messages(conversation, "孙之烨是谁", knowledge_context="来源:简历\n孙之烨负责审核智能体项目。")
|
||||||
|
|
||||||
|
assert messages[0]["role"] == "system"
|
||||||
|
assert messages[1]["role"] == "system"
|
||||||
|
assert "全局知识库" in messages[1]["content"]
|
||||||
|
assert "孙之烨负责审核智能体项目" in messages[1]["content"]
|
||||||
|
assert messages[-1] == {"role": "user", "content": "孙之烨是谁"}
|
||||||
|
|||||||
Reference in New Issue
Block a user