# File viewer metadata: 230 lines, 8.0 KiB, Python.
import sys
|
|
|
|
import pytest
|
|
|
|
from review_agent.regulatory_review.services.rag_citation import (
|
|
RagIndexUnavailable,
|
|
retrieve_citations,
|
|
)
|
|
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
|
|
from review_agent.regulatory_review.services.rag_index import chunk_text
|
|
from review_agent.regulatory_review.services.rag_index import collect_source_chunks
|
|
from review_agent.regulatory_review.services.rag_index import build_chroma_index
|
|
|
|
|
|
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
    """Check the request that ``SiliconFlowEmbeddingProvider.embed`` sends.

    ``httpx.post`` inside the embedding module is replaced by a recorder that
    returns a canned two-vector response, so the test can inspect the URL,
    the ``Authorization`` header and the JSON body of the first request.
    """
    recorded_requests = []

    class StubResponse:
        """Minimal stand-in exposing the two response methods the test stubs."""

        def raise_for_status(self):
            return None

        def json(self):
            return {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}

    def fake_post(url, headers, json, timeout):
        # Capture every argument so the assertions below can inspect them.
        recorded_requests.append(
            {"url": url, "headers": headers, "json": json, "timeout": timeout}
        )
        return StubResponse()

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_embedding.httpx.post",
        fake_post,
    )

    provider = SiliconFlowEmbeddingProvider(
        api_key="secret",
        base_url="https://api.siliconflow.cn/v1",
        model="Qwen/Qwen3-Embedding-4B",
        dimensions=1024,
    )

    vectors = provider.embed(["法规依据", "注册检验报告"])
    assert vectors == [[0.1, 0.2], [0.3, 0.4]]

    first_request = recorded_requests[0]
    assert first_request["url"] == "https://api.siliconflow.cn/v1/embeddings"
    assert first_request["headers"]["Authorization"] == "Bearer secret"

    payload = first_request["json"]
    assert payload["model"] == "Qwen/Qwen3-Embedding-4B"
    assert payload["dimensions"] == 1024
|
|
|
|
|
|
def test_chunk_text_preserves_source_metadata():
    """Chunking a long text yields several chunks, the first tagged with its source."""
    repeated_text = "第一段法规内容。\n" * 20

    chunks = chunk_text(repeated_text, source="法规.doc", chunk_size=30, overlap=5)

    assert len(chunks) > 1

    first_chunk = chunks[0]
    assert first_chunk.metadata["source"] == "法规.doc"
    # The first chunk must carry non-empty text.
    assert first_chunk.text
|
|
|
|
|
|
def test_retrieve_citations_returns_placeholder_when_no_hits():
    """A collection with no hits makes the first citation a placeholder source."""

    class EmptyCollection:
        """Collection stub whose query result contains no documents."""

        def query(self, query_embeddings, n_results):
            return {
                "documents": [[]],
                "metadatas": [[]],
                "distances": [[]],
            }

    def constant_embedding(texts):
        # Always answers with a single fixed vector, whatever the input.
        return [[0.1, 0.2]]

    citations = retrieve_citations(
        "注册检验报告",
        embedding_provider=constant_embedding,
        collection=EmptyCollection(),
    )

    first_citation = citations[0]
    assert first_citation["source"] == "原文依据待补充"
|
|
|
|
|
|
def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
    """``retrieve_citations`` raises ``RagIndexUnavailable`` for a missing index path."""
    # Point the setting at a directory that is never created.
    missing_index = tmp_path / "missing"
    settings.REGULATORY_RAG_CHROMA_PATH = missing_index

    def constant_embedding(texts):
        return [[0.1]]

    with pytest.raises(RagIndexUnavailable):
        retrieve_citations("注册检验报告", embedding_provider=constant_embedding)
|
|
|
|
|
|
def test_collect_source_chunks_requires_attachment4_extraction(monkeypatch, tmp_path):
    """A failed extraction of the attachment-4 file surfaces as a ``RuntimeError``.

    The stubbed extractor's own message does not contain "附件 4", yet the
    raised error is required to match it.
    """
    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    legacy_doc = source_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc"
    legacy_doc.write_bytes(b"legacy-doc")

    def fail_extract(path):
        # Simulate the extractor being unable to convert the legacy .doc file.
        raise RuntimeError("无法通过 LibreOffice 转换法规 .doc 材料")

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_index.extract_text_from_path",
        fail_extract,
    )

    with pytest.raises(RuntimeError, match="附件 4"):
        collect_source_chunks(source_dir)
|
|
|
|
|
|
def test_collect_source_chunks_excludes_demo_agent_materials(monkeypatch, tmp_path):
    """No collected chunk may come from the "模拟题二" demo materials.

    The source tree holds a demo directory, a demo ``.docx`` and one real
    regulation file; text extraction is stubbed out.
    """
    source_dir = tmp_path / "sources"
    source_dir.mkdir()

    # Demo material: a sub-directory with a markdown file, plus a top-level docx.
    demo_name = "【模拟题二】试剂盒临床注册文件准备与审核Agent"
    demo_dir = source_dir / demo_name
    demo_dir.mkdir()
    (demo_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.md").write_text(
        "题目材料", encoding="utf-8"
    )
    (source_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.docx").write_bytes(b"demo")

    # The single genuine regulation source.
    real_source = source_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc"
    real_source.write_bytes(b"rule")

    def fake_extract(path):
        if path == real_source:
            return "附件4 正文"
        return "不应被抽取"

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_index.extract_text_from_path",
        fake_extract,
    )

    chunks = collect_source_chunks(source_dir)

    assert chunks
    sources = [chunk.metadata["source"] for chunk in chunks]
    assert all("模拟题二" not in source for source in sources)
|
|
|
|
|
|
def test_build_chroma_index_reset_recreates_collection_without_deleting_index_dir(settings, monkeypatch, tmp_path):
    """``reset=True`` drops the collection while existing index files stay on disk.

    ``chromadb`` and ``chromadb.api.shared_system_client`` are replaced in
    ``sys.modules`` by fakes that log each client construction and cache clear
    together with whether the pre-existing sqlite file is still present.
    """
    settings.MEDIA_ROOT = tmp_path

    # An index directory that already contains a file from an earlier build.
    persist_path = tmp_path / "chroma"
    persist_path.mkdir()
    stale_file = persist_path / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")

    # One small source document to index.
    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    (source_dir / "rule.md").write_text("注册检验报告要求", encoding="utf-8")

    client_states = []
    deleted_collections = []

    def snapshot(label):
        # Log an event together with the current on-disk state of the stale file.
        client_states.append({"path": label, "stale_exists": stale_file.exists()})

    class FakeCollection:
        def upsert(self, **kwargs):
            return None

    class FakeClient:
        def __init__(self, path):
            snapshot(path)

        def delete_collection(self, name):
            deleted_collections.append(name)

        def get_or_create_collection(self, name):
            return FakeCollection()

    class FakeSharedSystemClient:
        @staticmethod
        def clear_system_cache():
            snapshot("cache-cleared")

    class FakeChromaModule:
        PersistentClient = FakeClient

    class FakeSharedSystemClientModule:
        SharedSystemClient = FakeSharedSystemClient

    monkeypatch.setitem(sys.modules, "chromadb", FakeChromaModule)
    monkeypatch.setitem(
        sys.modules,
        "chromadb.api.shared_system_client",
        FakeSharedSystemClientModule,
    )

    def fixed_embeddings(texts):
        return [[0.1, 0.2] for _ in texts]

    count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=fixed_embeddings,
        persist_path=persist_path,
        collection_name="test",
        reset=True,
    )

    assert count == 1

    # Client, cache clear, client again, with the stale file present throughout.
    expected_states = [
        {"path": str(persist_path), "stale_exists": True},
        {"path": "cache-cleared", "stale_exists": True},
        {"path": str(persist_path), "stale_exists": True},
    ]
    assert client_states == expected_states
    assert stale_file.exists()
    assert deleted_collections == ["test"]
|
|
|
|
|
|
def test_build_chroma_index_reset_clears_bad_index_dir_after_chroma_cache_reset(settings, monkeypatch, tmp_path):
    """When the first client fails, the old index file is gone before the retry.

    The fake client raises ``ValueError`` on its first construction only. The
    recorded events must show the cache being cleared while the stale file
    still exists, and the second construction seeing it removed.
    """
    settings.MEDIA_ROOT = tmp_path

    # An index directory holding a leftover sqlite file.
    persist_path = tmp_path / "chroma"
    persist_path.mkdir()
    stale_file = persist_path / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")

    # One small source document to index.
    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    (source_dir / "rule.md").write_text("注册检验报告要求", encoding="utf-8")

    events = []

    class FakeCollection:
        def upsert(self, **kwargs):
            return None

    class BrokenThenFreshClient:
        # Number of constructions so far, shared across instances.
        attempts = 0

        def __init__(self, path):
            BrokenThenFreshClient.attempts += 1
            attempt = BrokenThenFreshClient.attempts
            events.append(("client", attempt, stale_file.exists()))
            if attempt == 1:
                raise ValueError("Could not connect to tenant default_tenant")

        def get_or_create_collection(self, name):
            return FakeCollection()

    class FakeSharedSystemClient:
        @staticmethod
        def clear_system_cache():
            events.append(("clear_cache", stale_file.exists()))

    class FakeChromaModule:
        PersistentClient = BrokenThenFreshClient

    class FakeSharedSystemClientModule:
        SharedSystemClient = FakeSharedSystemClient

    monkeypatch.setitem(sys.modules, "chromadb", FakeChromaModule)
    monkeypatch.setitem(
        sys.modules,
        "chromadb.api.shared_system_client",
        FakeSharedSystemClientModule,
    )

    def fixed_embeddings(texts):
        return [[0.1, 0.2] for _ in texts]

    count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=fixed_embeddings,
        persist_path=persist_path,
        collection_name="test",
        reset=True,
    )

    assert count == 1

    expected_events = [
        ("client", 1, True),
        ("clear_cache", True),
        ("client", 2, False),
    ]
    assert events == expected_events
    assert not stale_file.exists()
|