diff --git a/review_agent/llm.py b/review_agent/llm.py index 9057536..ee2fd5d 100644 --- a/review_agent/llm.py +++ b/review_agent/llm.py @@ -16,7 +16,7 @@ class LLMRequestError(RuntimeError): logger = logging.getLogger(__name__) -def generate_reply(conversation, user_message: str) -> str: +def generate_reply(conversation, user_message: str, knowledge_context: str = "") -> str: """Calls the SiliconFlow OpenAI-compatible chat endpoint and returns assistant text.""" if not settings.LLM_API_KEY: @@ -26,7 +26,7 @@ def generate_reply(conversation, user_message: str) -> str: payload = { "model": settings.LLM_MODEL, - "messages": build_messages(conversation, user_message), + "messages": build_messages(conversation, user_message, knowledge_context=knowledge_context), "temperature": 0.3, } body = json.dumps(payload).encode("utf-8") @@ -98,7 +98,7 @@ def generate_completion(messages: list[dict[str, str]], *, temperature: float = raise LLMRequestError("模型接口返回格式不符合预期。") from exc -def stream_reply(conversation, user_message: str): +def stream_reply(conversation, user_message: str, knowledge_context: str = ""): """Streams incremental assistant text from the SiliconFlow chat endpoint.""" if not settings.LLM_API_KEY: @@ -108,7 +108,7 @@ def stream_reply(conversation, user_message: str): payload = { "model": settings.LLM_MODEL, - "messages": build_messages(conversation, user_message), + "messages": build_messages(conversation, user_message, knowledge_context=knowledge_context), "temperature": 0.3, "stream": True, } @@ -153,10 +153,21 @@ def stream_reply(conversation, user_message: str): raise LLMRequestError(f"模型接口调用失败:{exc.reason}") from exc -def build_messages(conversation, latest_user_message: str) -> list[dict[str, str]]: +def build_messages(conversation, latest_user_message: str, knowledge_context: str = "") -> list[dict[str, str]]: """Builds system and conversation history messages for the provider call.""" messages = [{"role": "system", "content": system_prompt()}] + if 
knowledge_context.strip(): + messages.append( + { + "role": "system", + "content": ( + "以下是全局知识库检索到的材料片段。回答用户时优先依据这些片段;" + "如果片段不足以支持结论,请明确说明信息不足,不要编造。\n\n" + f"{knowledge_context.strip()}" + ), + } + ) for message in conversation.messages.all(): messages.append({"role": message.role, "content": message.content}) diff --git a/review_agent/services.py b/review_agent/services.py index ac3af8d..45b1d74 100644 --- a/review_agent/services.py +++ b/review_agent/services.py @@ -2,6 +2,7 @@ from __future__ import annotations import json import logging +from pathlib import Path from django.db.models import Q, QuerySet from django.conf import settings @@ -9,8 +10,10 @@ from django.utils import timezone from .file_summary.skills.attachment_reader import AttachmentReaderSkill from .file_summary.workflow import create_file_summary_batch, start_file_summary_workflow +from .knowledge_base import search_knowledge_base from .llm import LLMConfigurationError, LLMRequestError, generate_reply, stream_reply -from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, Message +from .models import Conversation, FileAttachment, FileSummaryBatch, FileSummaryBatchAttachment, KnowledgeBaseDocument, Message +from .regulatory_review.services.rag_index import extract_text_from_path from .application_form_fill.workflow import ( create_application_form_fill_batch, find_latest_successful_summary_batch as find_latest_successful_form_fill_summary_batch, @@ -104,8 +107,9 @@ def send_message(conversation: Conversation, content: str) -> tuple[Message, Mes """Stores one user message and one provider-backed assistant reply.""" user_message = append_user_message(conversation, content) + knowledge_context = build_knowledge_context(content) try: - reply_content = generate_reply(conversation, content) + reply_content = generate_reply(conversation, content, knowledge_context=knowledge_context) except (LLMConfigurationError, LLMRequestError) as exc: reply_content = 
def build_knowledge_context(
    content: str,
    *,
    n_results: int = 5,
    max_chars_per_hit: int = 1200,
) -> str:
    """Build the knowledge-base context string injected into normal chat prompts.

    A filename-matched full-document context takes precedence: when
    ``build_filename_matched_document_context`` returns text, it is returned
    as-is and no search is performed. Otherwise the knowledge base is
    searched and the ranked, relevant hits are rendered as numbered entries.

    Args:
        content: The user's message; used both as the search query and for
            relevance filtering.
        n_results: Maximum number of hits requested from the search and
            rendered into the context. Values below 1 produce no entries.
        max_chars_per_hit: Maximum number of characters of each hit's
            whitespace-normalized text to include.

    Returns:
        The formatted context, or ``""`` when nothing usable was found, the
        search raised, or the search payload carried an ``error_message``.
    """
    full_document_context = build_filename_matched_document_context(content)
    if full_document_context:
        return full_document_context
    if n_results <= 0:
        return ""

    try:
        payload = search_knowledge_base(content, n_results=n_results)
    except Exception as exc:
        # Deliberate best-effort boundary: a failing knowledge-base lookup
        # must not prevent the chat reply itself from being generated.
        logger.warning("Knowledge-base search failed", extra={"error": str(exc)})
        return ""

    error_message = payload.get("error_message")
    if error_message:
        # Surface provider-reported errors instead of dropping them silently.
        logger.warning(
            "Knowledge-base search returned an error",
            extra={"error": str(error_message)},
        )
        return ""

    entries: list[str] = []
    for item in _rank_knowledge_results(content, payload.get("results") or []):
        if not _is_relevant_knowledge_result(content, item):
            continue
        text = " ".join(str(item.get("text") or "").split())
        if not text:
            # Skip empty hits *before* numbering so they neither consume one
            # of the n_results slots nor leave gaps in the [n] labels.
            continue
        source = str(item.get("source") or "未知来源")
        score = item.get("score")
        # bool is a subclass of int; a True/False score is not a real score.
        has_numeric_score = isinstance(score, (int, float)) and not isinstance(score, bool)
        score_label = f",score={score:.4f}" if has_numeric_score else ""
        entries.append(
            f"[{len(entries) + 1}] 来源:{source}{score_label}\n{text[:max_chars_per_hit]}"
        )
        if len(entries) >= n_results:
            break
    return "\n\n".join(entries)
# Characters removed from a query before term matching: common Chinese
# question/filler characters plus ASCII and full-width punctuation.
_KNOWLEDGE_QUERY_STOP_CHARS = frozenset("是谁什么哪里如何怎么请问一下帮我你能告诉吗??,,。.")

# Sort value for hits without a usable numeric score, so they rank after
# scored hits within the same direct-hit group.
_UNSCORED_KNOWLEDGE_RESULT = 999999.0


def build_filename_matched_document_context(
    query: str,
    *,
    max_chars: int = 12000,
    max_documents: int = 3,
) -> str:
    """Return full-text context for documents whose name contains a query term.

    Active ``KnowledgeBaseDocument`` rows are scanned newest-first
    (``-updated_at``, ``-id``); a document matches when any query term is a
    substring of ``"<display_name> <original_name>"``. The text of the first
    ``max_documents`` matches is extracted, whitespace-normalized and
    truncated to ``max_chars`` characters each.

    NOTE(review): the query is not restricted to a particular owner, while
    the rendered label says "用户知识库" — confirm that sharing documents
    across users is intended.

    Args:
        query: The user's message.
        max_chars: Maximum characters of text included per document.
        max_documents: Maximum number of matching documents to read.

    Returns:
        A header line followed by one numbered section per readable
        document, or ``""`` when there are no usable terms, no matching
        documents, or no matching document yielded any text.
    """
    terms = _knowledge_query_terms(query)
    if not terms or max_documents <= 0:
        return ""

    documents = KnowledgeBaseDocument.objects.filter(
        status=KnowledgeBaseDocument.Status.ACTIVE,
        is_active=True,
    ).order_by("-updated_at", "-id")

    matches = []
    for document in documents:
        filename = f"{document.display_name} {document.original_name}"
        if any(term in filename for term in terms):
            matches.append(document)
            if len(matches) >= max_documents:
                # Only the first max_documents matches are ever rendered.
                break

    sections: list[str] = []
    for document in matches:
        text = " ".join(_extract_managed_document_text(document).split())
        if not text:
            # Unreadable or whitespace-only documents contribute nothing.
            continue
        sections.append(
            f"[全文材料 {len(sections) + 1}] 来源:用户知识库/{document.original_name}\n"
            f"{text[:max_chars]}"
        )
    if not sections:
        # Without this guard the header alone would be returned, which the
        # caller treats as a real context and therefore skips the search.
        return ""

    header = "以下材料因用户问题中的关键词命中文档名称,已读取全文供回答前比对和总结。"
    return "\n\n".join([header, *sections])


def _extract_managed_document_text(document: "KnowledgeBaseDocument") -> str:
    """Extract a managed document's text, returning ``""`` on any failure."""
    storage_path = document.storage_path
    if not storage_path:
        return ""
    try:
        return extract_text_from_path(Path(storage_path)) or ""
    except Exception as exc:
        # Best-effort: one unreadable file must not break the chat request.
        logger.warning(
            "Managed document full-text extraction failed",
            extra={"document_id": document.pk, "error": str(exc)},
        )
        return ""


def _knowledge_result_mentions(item: dict[str, object], terms: list[str]) -> bool:
    """Return True when any term occurs in the hit's source or text."""
    source = str(item.get("source") or "")
    text = str(item.get("text") or "")
    haystack = f"{source}\n{text}"
    return any(term in haystack for term in terms)


def _rank_knowledge_results(query: str, results: list[dict[str, object]]) -> list[dict[str, object]]:
    """Return ``results`` sorted with direct term hits first, then by score.

    Scores sort ascending, i.e. a lower score ranks higher — presumably the
    search returns a distance; confirm against ``search_knowledge_base``.
    Hits with a missing, non-numeric, or NaN score sort last within their
    group. The sort is stable, so ties keep their incoming order.
    """
    import math

    terms = _knowledge_query_terms(query)

    def sort_key(item: dict[str, object]) -> tuple[int, float]:
        score = item.get("score")
        if isinstance(score, bool) or not isinstance(score, (int, float)):
            numeric_score = _UNSCORED_KNOWLEDGE_RESULT
        else:
            numeric_score = float(score)
            if math.isnan(numeric_score):
                # NaN compares False with everything and would leave the
                # sorted order undefined.
                numeric_score = _UNSCORED_KNOWLEDGE_RESULT
        direct_hit = _knowledge_result_mentions(item, terms)
        return (0 if direct_hit else 1, numeric_score)

    return sorted(results, key=sort_key)


def _is_relevant_knowledge_result(query: str, item: dict[str, object]) -> bool:
    """Decide whether a search hit should be shown for ``query``.

    A hit is relevant when a query term occurs in its source or text, or
    when its metadata marks it as ``source_type == "managed_document"``.
    Queries without usable terms make every hit irrelevant.
    """
    terms = _knowledge_query_terms(query)
    if not terms:
        return False
    if _knowledge_result_mentions(item, terms):
        return True
    metadata = item.get("metadata")
    # Metadata comes from the search backend; tolerate non-dict values
    # instead of raising AttributeError in the middle of a chat request.
    if isinstance(metadata, dict):
        return metadata.get("source_type") == "managed_document"
    return False


def _knowledge_query_terms(query: str) -> list[str]:
    """Derive substring-match terms from a user query.

    Returns up to two terms: the query with all whitespace and stop
    characters removed, followed by the query with only whitespace removed
    (when different). Terms without a single letter or digit are discarded:
    a punctuation-only term such as "." would otherwise match nearly every
    filename and text chunk.
    """
    normalized = "".join((query or "").split())
    if not normalized:
        return []
    compact = "".join(
        char for char in normalized if char not in _KNOWLEDGE_QUERY_STOP_CHARS
    )
    terms: list[str] = []
    for candidate in (compact, normalized):
        if candidate in terms:
            continue
        if any(char.isalnum() for char in candidate):
            terms.append(candidate)
    return terms
def test_stream_message_uses_normal_llm_path_when_not_triggered(monkeypatch, django_user_model):
    """A plain question streams the LLM reply and passes the knowledge hit along."""
    owner = django_user_model.objects.create_user(username="owner", password="pass")
    conversation = Conversation.objects.create(user=owner, title="会话")
    captured_contexts = []
    knowledge_hit = {
        "source": "用户知识库/1/2/孙之烨-260510.pdf",
        "text": "孙之烨负责审核智能体项目。",
        "score": 0.23,
    }

    def fake_search(query, n_results=3):
        return {"query": query, "results": [knowledge_hit], "error_message": ""}

    def fake_stream_reply(conversation, content, knowledge_context=""):
        # Record the context handed to the provider call for the assertions below.
        captured_contexts.append(knowledge_context)
        yield "普通回复"

    monkeypatch.setattr("review_agent.services.stream_reply", fake_stream_reply)
    monkeypatch.setattr("review_agent.services.search_knowledge_base", fake_search)

    joined = "".join(stream_message(conversation, "孙之烨是谁"))

    assert "普通回复" in joined
    assert "workflow_started" not in joined
    assert captured_contexts
    first_context = captured_contexts[0]
    assert "孙之烨负责审核智能体项目" in first_context
    assert "用户知识库/1/2/孙之烨-260510.pdf" in first_context
def test_build_messages_includes_knowledge_context(django_user_model):
    """Knowledge context becomes a second system message ahead of the user turn."""
    question = "孙之烨是谁"
    owner = django_user_model.objects.create_user(username="owner", password="pass")
    conversation = Conversation.objects.create(user=owner, title="会话")

    messages = build_messages(
        conversation,
        question,
        knowledge_context="来源:简历\n孙之烨负责审核智能体项目。",
    )

    base_prompt, knowledge_prompt = messages[0], messages[1]
    assert base_prompt["role"] == "system"
    assert knowledge_prompt["role"] == "system"
    assert "全局知识库" in knowledge_prompt["content"]
    assert "孙之烨负责审核智能体项目" in knowledge_prompt["content"]
    assert messages[-1] == {"role": "user", "content": question}