import sys

import pytest

from review_agent.regulatory_review.services.rag_citation import (
    RagIndexUnavailable,
    retrieve_citations,
)
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
from review_agent.regulatory_review.services.rag_index import chunk_text
from review_agent.regulatory_review.services.rag_index import collect_source_chunks
from review_agent.regulatory_review.services.rag_index import build_chroma_index


def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
    """The provider posts to ``<base_url>/embeddings`` and returns the vectors in order."""
    recorded_requests = []

    class FakeResponse:
        """Stand-in HTTP response carrying two embedding rows."""

        def raise_for_status(self):
            return None

        def json(self):
            return {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}

    def fake_post(url, headers, json, timeout):
        # Capture the outgoing request instead of touching the network.
        recorded_requests.append(
            {"url": url, "headers": headers, "json": json, "timeout": timeout}
        )
        return FakeResponse()

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_embedding.httpx.post",
        fake_post,
    )

    provider = SiliconFlowEmbeddingProvider(
        api_key="secret",
        base_url="https://api.siliconflow.cn/v1",
        model="Qwen/Qwen3-Embedding-4B",
        dimensions=1024,
    )
    vectors = provider.embed(["法规依据", "注册检验报告"])

    assert vectors == [[0.1, 0.2], [0.3, 0.4]]
    first_request = recorded_requests[0]
    assert first_request["url"] == "https://api.siliconflow.cn/v1/embeddings"
    assert first_request["headers"]["Authorization"] == "Bearer secret"
    payload = first_request["json"]
    assert payload["model"] == "Qwen/Qwen3-Embedding-4B"
    assert payload["dimensions"] == 1024


def test_chunk_text_preserves_source_metadata():
    """Chunking a long text yields several chunks that carry the source name."""
    long_text = "第一段法规内容。\n" * 20

    chunks = chunk_text(long_text, source="法规.doc", chunk_size=30, overlap=5)

    assert len(chunks) > 1
    first_chunk = chunks[0]
    assert first_chunk.metadata["source"] == "法规.doc"
    assert first_chunk.text


def test_retrieve_citations_returns_placeholder_when_no_hits():
    """An empty query result still produces a placeholder citation entry."""

    class EmptyCollection:
        """Collection double whose query never matches anything."""

        def query(self, query_embeddings, n_results):
            return {"documents": [[]], "metadatas": [[]], "distances": [[]]}

    citations = retrieve_citations(
        "注册检验报告",
        embedding_provider=lambda texts: [[0.1, 0.2]],
        collection=EmptyCollection(),
    )

    placeholder = citations[0]
    assert placeholder["source"] == "原文依据待补充"


def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
    """Pointing the Chroma path setting at a missing directory raises RagIndexUnavailable."""
    settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "missing"

    with pytest.raises(RagIndexUnavailable):
        retrieve_citations("注册检验报告", embedding_provider=lambda texts: [[0.1]])


def test_collect_source_chunks_requires_attachment4_extraction(monkeypatch, tmp_path):
    """A failing extraction of the attachment 4 document surfaces as a RuntimeError."""
    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    (source_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc").write_bytes(b"legacy-doc")

    def fail_extract(path):
        raise RuntimeError("无法通过 LibreOffice 转换法规 .doc 材料")

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_index.extract_text_from_path",
        fail_extract,
    )

    with pytest.raises(RuntimeError, match="附件 4"):
        collect_source_chunks(source_dir)


def test_collect_source_chunks_excludes_demo_agent_materials(monkeypatch, tmp_path):
    """No collected chunk may originate from the demo-task files or folder."""
    source_dir = tmp_path / "sources"
    source_dir.mkdir()

    # Demo material appears both as a sub-directory and as a top-level file.
    demo_dir = source_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent"
    demo_dir.mkdir()
    demo_markdown = demo_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.md"
    demo_markdown.write_text("题目材料", encoding="utf-8")
    demo_docx = source_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.docx"
    demo_docx.write_bytes(b"demo")

    real_source = source_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc"
    real_source.write_bytes(b"rule")

    def fake_extract(path):
        if path == real_source:
            return "附件4 正文"
        return "不应被抽取"

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_index.extract_text_from_path",
        fake_extract,
    )

    chunks = collect_source_chunks(source_dir)

    assert chunks
    assert not any("模拟题二" in chunk.metadata["source"] for chunk in chunks)


def test_build_chroma_index_reset_recreates_collection_without_deleting_index_dir(
    settings, monkeypatch, tmp_path
):
    """With a working client, ``reset=True`` deletes the collection but leaves index files alone."""
    settings.MEDIA_ROOT = tmp_path

    persist_path = tmp_path / "chroma"
    persist_path.mkdir()
    stale_file = persist_path / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")

    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    rule_file = source_dir / "rule.md"
    rule_file.write_text("注册检验报告要求", encoding="utf-8")

    client_states = []
    deleted_collections = []

    def snapshot(path):
        # Record who was called and whether the pre-existing file still exists.
        client_states.append({"path": path, "stale_exists": stale_file.exists()})

    class FakeCollection:
        def upsert(self, **kwargs):
            return None

    class FakeClient:
        def __init__(self, path):
            snapshot(path)

        def delete_collection(self, name):
            deleted_collections.append(name)

        def get_or_create_collection(self, name):
            return FakeCollection()

    class FakeSharedSystemClient:
        @staticmethod
        def clear_system_cache():
            snapshot("cache-cleared")

    # Plain classes stand in for the modules the code under test imports.
    fake_chromadb = type("FakeChromaModule", (), {"PersistentClient": FakeClient})
    fake_shared_module = type(
        "FakeSharedSystemClientModule",
        (),
        {"SharedSystemClient": FakeSharedSystemClient},
    )
    monkeypatch.setitem(sys.modules, "chromadb", fake_chromadb)
    monkeypatch.setitem(sys.modules, "chromadb.api.shared_system_client", fake_shared_module)

    count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=lambda texts: [[0.1, 0.2] for _ in texts],
        persist_path=persist_path,
        collection_name="test",
        reset=True,
    )

    index_dir = str(persist_path)
    assert count == 1
    assert client_states == [
        {"path": index_dir, "stale_exists": True},
        {"path": "cache-cleared", "stale_exists": True},
        {"path": index_dir, "stale_exists": True},
    ]
    assert stale_file.exists()
    assert deleted_collections == ["test"]


def test_build_chroma_index_reset_clears_bad_index_dir_after_chroma_cache_reset(
    settings, monkeypatch, tmp_path
):
    """If the first client construction fails, the stale file is gone before the second attempt."""
    settings.MEDIA_ROOT = tmp_path

    persist_path = tmp_path / "chroma"
    persist_path.mkdir()
    stale_file = persist_path / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")

    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    rule_file = source_dir / "rule.md"
    rule_file.write_text("注册检验报告要求", encoding="utf-8")

    events = []

    class FakeCollection:
        def upsert(self, **kwargs):
            return None

    class BrokenThenFreshClient:
        """Client double that raises on its first construction only."""

        attempts = 0

        def __init__(self, path):
            BrokenThenFreshClient.attempts += 1
            attempt = BrokenThenFreshClient.attempts
            events.append(("client", attempt, stale_file.exists()))
            if attempt == 1:
                raise ValueError("Could not connect to tenant default_tenant")

        def get_or_create_collection(self, name):
            return FakeCollection()

    class FakeSharedSystemClient:
        @staticmethod
        def clear_system_cache():
            events.append(("clear_cache", stale_file.exists()))

    fake_chromadb = type(
        "FakeChromaModule",
        (),
        {"PersistentClient": BrokenThenFreshClient},
    )
    fake_shared_module = type(
        "FakeSharedSystemClientModule",
        (),
        {"SharedSystemClient": FakeSharedSystemClient},
    )
    monkeypatch.setitem(sys.modules, "chromadb", fake_chromadb)
    monkeypatch.setitem(sys.modules, "chromadb.api.shared_system_client", fake_shared_module)

    count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=lambda texts: [[0.1, 0.2] for _ in texts],
        persist_path=persist_path,
        collection_name="test",
        reset=True,
    )

    assert count == 1
    assert events == [
        ("client", 1, True),
        ("clear_cache", True),
        ("client", 2, False),
    ]
    assert not stale_file.exists()