# Files
# DEMO-AGENT/tests/test_regulatory_rag.py
#
# 89 lines
# 3.1 KiB
# Python

import pytest
from review_agent.regulatory_review.services.rag_citation import (
RagIndexUnavailable,
retrieve_citations,
)
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
from review_agent.regulatory_review.services.rag_index import chunk_text
from review_agent.regulatory_review.services.rag_index import collect_source_chunks
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
    """Verify the request sent by ``embed`` and the vectors it returns.

    ``httpx.post`` as referenced from the ``rag_embedding`` module is swapped
    for a recorder that captures its arguments and returns a canned response.
    """
    recorded = []

    class StubResponse:
        """Response stand-in exposing only the two methods stubbed here."""

        def raise_for_status(self):
            return None

        def json(self):
            first = {"embedding": [0.1, 0.2]}
            second = {"embedding": [0.3, 0.4]}
            return {"data": [first, second]}

    def record_post(url, headers, json, timeout):
        # Parameter names are kept unchanged: the caller may pass them by keyword.
        recorded.append(
            {"url": url, "headers": headers, "json": json, "timeout": timeout}
        )
        return StubResponse()

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_embedding.httpx.post",
        record_post,
    )

    provider = SiliconFlowEmbeddingProvider(
        api_key="secret",
        base_url="https://api.siliconflow.cn/v1",
        model="Qwen/Qwen3-Embedding-4B",
        dimensions=1024,
    )

    vectors = provider.embed(["法规依据", "注册检验报告"])
    assert vectors == [[0.1, 0.2], [0.3, 0.4]]

    # Inspect the first recorded call only after the returned vectors are checked.
    request = recorded[0]
    assert request["url"] == "https://api.siliconflow.cn/v1/embeddings"
    assert request["headers"]["Authorization"] == "Bearer secret"
    payload = request["json"]
    assert payload["model"] == "Qwen/Qwen3-Embedding-4B"
    assert payload["dimensions"] == 1024
def test_chunk_text_preserves_source_metadata():
    """Chunking a repeated line yields several chunks that carry the source name."""
    repeated_text = "第一段法规内容。\n" * 20
    pieces = chunk_text(repeated_text, source="法规.doc", chunk_size=30, overlap=5)

    # The length check comes first so an empty result fails on this assertion.
    assert len(pieces) > 1

    leading = pieces[0]
    assert leading.metadata["source"] == "法规.doc"
    assert leading.text
def test_retrieve_citations_returns_placeholder_when_no_hits():
    """A collection whose query returns no documents yields the placeholder source."""

    class NoHitCollection:
        """Collection stand-in whose ``query`` always reports zero matches."""

        def query(self, query_embeddings, n_results):
            # Each field holds one empty inner list; fresh lists are built per call.
            return {
                field: [[]] for field in ("documents", "metadatas", "distances")
            }

    def fixed_embedding(texts):
        return [[0.1, 0.2]]

    results = retrieve_citations(
        "注册检验报告",
        embedding_provider=fixed_embedding,
        collection=NoHitCollection(),
    )

    first_citation = results[0]
    assert first_citation["source"] == "原文依据待补充"
def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
    """``retrieve_citations`` raises ``RagIndexUnavailable`` for a missing index path.

    The configured path is a child of ``tmp_path`` that this test never
    creates, and no ``collection`` argument is supplied.
    """
    absent_index = tmp_path / "missing"
    settings.REGULATORY_RAG_CHROMA_PATH = absent_index

    def fixed_embedding(texts):
        return [[0.1]]

    with pytest.raises(RagIndexUnavailable):
        retrieve_citations("注册检验报告", embedding_provider=fixed_embedding)
def test_collect_source_chunks_requires_attachment4_extraction(monkeypatch, tmp_path):
    """A failed extraction of the Attachment 4 file surfaces as ``RuntimeError``.

    The expected message must match "附件 4". The stubbed extractor's own
    message does not contain that text, so the match presumably comes from an
    error raised by ``collect_source_chunks`` itself -- confirm against the
    implementation.
    """
    regulations_dir = tmp_path / "sources"
    regulations_dir.mkdir()

    # Placeholder bytes only: the extractor is stubbed, so content is never parsed here.
    legacy_doc = regulations_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc"
    legacy_doc.write_bytes(b"legacy-doc")

    def always_fail(path):
        raise RuntimeError("无法通过 LibreOffice 转换法规 .doc 材料")

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_index.extract_text_from_path",
        always_fail,
    )

    with pytest.raises(RuntimeError, match="附件 4"):
        collect_source_chunks(regulations_dir)