feat(chat): 接入全局知识库上下文
This commit is contained in:
59
tests/test_chat_knowledge_context.py
Normal file
59
tests/test_chat_knowledge_context.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
|
||||
from review_agent.models import KnowledgeBaseDocument
|
||||
from review_agent.services import build_knowledge_context
|
||||
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
def test_build_knowledge_context_ignores_irrelevant_rag_chunks(monkeypatch):
    """Expect an empty context when the only retrieved chunk is unrelated.

    The knowledge-base search is replaced with a stub that always returns a
    single chunk taken from a regulatory document; the chunk text does not
    mention the person named in the query. ``build_knowledge_context`` must
    then return an empty string.
    """

    def fake_search(query, n_results=5):
        # A fresh payload per call, echoing the query back like the lambda
        # this stub stands in for.
        regulatory_chunk = {
            "source": "附件 4 体外诊断试剂注册申报资料要求及说明.doc",
            "text": "预期用途应明确产品用于检测的分析物和功能。",
            "score": 7.636,
            "metadata": {"source_type": "regulatory_document"},
        }
        return {
            "query": query,
            "results": [regulatory_chunk],
            "error_message": "",
        }

    monkeypatch.setattr("review_agent.services.search_knowledge_base", fake_search)

    built = build_knowledge_context("孙之烨是谁")

    assert built == ""
|
||||
|
||||
|
||||
def test_build_knowledge_context_uses_full_document_when_name_matches(settings, tmp_path, monkeypatch, django_user_model):
    """Expect full-document text in the context when a document name matches.

    An active ``KnowledgeBaseDocument`` whose display and original names
    contain the name used in the query is stored, backed by a UTF-8 text file
    under ``tmp_path``. The knowledge-base search is stubbed to return no
    results, so everything asserted in the context has to originate from the
    stored document itself.
    """

    def empty_search(query, n_results=5):
        # The search yields nothing, so retrieved chunks cannot supply the text.
        return {"query": query, "results": [], "error_message": ""}

    settings.MEDIA_ROOT = tmp_path
    owner = django_user_model.objects.create_user(username="owner", password="pass")

    resume_file = tmp_path / "resume.txt"
    resume_file.write_text(
        "孙之烨,负责审核智能体项目。\n完整经历:曾组织技术分享并带队参加竞赛。",
        encoding="utf-8",
    )
    resume_size = resume_file.stat().st_size

    KnowledgeBaseDocument.objects.create(
        user=owner,
        display_name="孙之烨简历",
        original_name="孙之烨-260510.txt",
        storage_path=str(resume_file),
        file_size=resume_size,
        status=KnowledgeBaseDocument.Status.ACTIVE,
        is_active=True,
        indexed_chunk_count=2,
    )
    monkeypatch.setattr("review_agent.services.search_knowledge_base", empty_search)

    built = build_knowledge_context("孙之烨是谁")

    # Checked in order: full-text marker, source label, then a sentence that
    # appears only on the second line of the stored file.
    expected_fragments = (
        "全文材料",
        "来源:用户知识库/孙之烨-260510.txt",
        "完整经历:曾组织技术分享并带队参加竞赛",
    )
    for fragment in expected_fragments:
        assert fragment in built
|
||||
@@ -201,17 +201,36 @@ def test_stream_message_returns_workflow_meta_when_triggered(settings, django_us
|
||||
def test_stream_message_uses_normal_llm_path_when_not_triggered(monkeypatch, django_user_model):
    """Check the plain LLM path streams a reply and receives knowledge context.

    ``stream_reply`` is replaced with a generator that records the
    ``knowledge_context`` it is called with, and the knowledge-base search is
    stubbed to return one chunk from a user document. The test asserts that
    the streamed frames contain the stub reply, contain no
    ``workflow_started`` marker, and that the recorded context includes both
    the chunk text and its source path.
    """
    user = django_user_model.objects.create_user(username="owner", password="pass")
    conversation = Conversation.objects.create(user=user, title="会话")
    # Every knowledge_context value handed to the stubbed stream_reply.
    calls = []

    # Only the three-argument stub is kept; the older two-argument header was
    # a superseded line and could not accept the knowledge context.
    def fake_stream_reply(conversation, content, knowledge_context=""):
        calls.append(knowledge_context)
        yield "普通回复"

    monkeypatch.setattr("review_agent.services.stream_reply", fake_stream_reply)
    monkeypatch.setattr(
        "review_agent.services.search_knowledge_base",
        lambda query, n_results=3: {
            "query": query,
            "results": [
                {
                    "source": "用户知识库/1/2/孙之烨-260510.pdf",
                    "text": "孙之烨负责审核智能体项目。",
                    "score": 0.23,
                }
            ],
            "error_message": "",
        },
    )

    # Stream exactly once, with the prompt that names the person, so that
    # calls[0] is the context built for this prompt.
    frames = list(stream_message(conversation, "孙之烨是谁"))

    joined = "".join(frames)
    assert "普通回复" in joined
    assert "workflow_started" not in joined
    assert calls
    assert "孙之烨负责审核智能体项目" in calls[0]
    assert "用户知识库/1/2/孙之烨-260510.pdf" in calls[0]
|
||||
|
||||
|
||||
def test_stream_message_meta_uses_first_prompt_title_for_new_conversation(monkeypatch, django_user_model):
|
||||
@@ -257,12 +276,15 @@ def test_stream_message_falls_back_to_non_stream_reply_when_stream_breaks(monkey
|
||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||
conversation = Conversation.objects.create(user=user, title="会话")
|
||||
|
||||
def broken_stream_reply(conversation, content):
|
||||
def broken_stream_reply(conversation, content, knowledge_context=""):
|
||||
yield "已生成部分内容"
|
||||
raise RuntimeError("provider connection reset")
|
||||
|
||||
monkeypatch.setattr("review_agent.services.stream_reply", broken_stream_reply)
|
||||
monkeypatch.setattr("review_agent.services.generate_reply", lambda conversation, content: "非流式完整回复")
|
||||
monkeypatch.setattr(
|
||||
"review_agent.services.generate_reply",
|
||||
lambda conversation, content, knowledge_context="": "非流式完整回复",
|
||||
)
|
||||
|
||||
frames = list(stream_message(conversation, "普通问题"))
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from urllib import request
|
||||
|
||||
import pytest
|
||||
|
||||
from review_agent.llm import stream_reply
|
||||
from review_agent.llm import build_messages, stream_reply
|
||||
from review_agent.models import Conversation
|
||||
|
||||
|
||||
@@ -39,3 +39,16 @@ def test_stream_reply_skips_malformed_sse_data(monkeypatch, settings, django_use
|
||||
chunks = list(stream_reply(conversation, "你好"))
|
||||
|
||||
assert chunks == ["A", "B"]
|
||||
|
||||
|
||||
def test_build_messages_includes_knowledge_context(django_user_model):
    """Check that ``build_messages`` adds the knowledge context as a system message.

    The first two messages must have the ``system`` role, the second must
    contain both the knowledge-base label and the supplied context text, and
    the final message must be the user's question.
    """
    question = "孙之烨是谁"
    knowledge = "来源:简历\n孙之烨负责审核智能体项目。"

    owner = django_user_model.objects.create_user(username="owner", password="pass")
    chat = Conversation.objects.create(user=owner, title="会话")

    built = build_messages(chat, question, knowledge_context=knowledge)

    assert built[0]["role"] == "system"

    knowledge_message = built[1]
    assert knowledge_message["role"] == "system"

    knowledge_text = knowledge_message["content"]
    assert "全局知识库" in knowledge_text
    assert "孙之烨负责审核智能体项目" in knowledge_text

    assert built[-1] == {"role": "user", "content": question}
|
||||
|
||||
Reference in New Issue
Block a user