73 lines
2.4 KiB
Python
73 lines
2.4 KiB
Python
import pytest
|
|
|
|
from review_agent.regulatory_review.services.rag_citation import (
|
|
RagIndexUnavailable,
|
|
retrieve_citations,
|
|
)
|
|
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
|
|
from review_agent.regulatory_review.services.rag_index import chunk_text
|
|
|
|
|
|
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
    """``SiliconFlowEmbeddingProvider.embed`` builds the expected HTTP request.

    ``httpx.post`` inside ``rag_embedding`` is replaced by a fake that records
    each call and answers with two canned embeddings. This lets the test check
    both the request the provider sends (URL, auth header, JSON body) and how
    it unpacks the response.
    """
    calls = []

    class FakeResponse:
        """Minimal stand-in for a successful ``httpx.Response``."""

        def raise_for_status(self):
            return None

        def json(self):
            # One embedding per input text, in input order.
            return {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}

    def fake_post(url, headers, json, timeout):
        # Record the request instead of hitting the network.
        calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout})
        return FakeResponse()

    monkeypatch.setattr("review_agent.regulatory_review.services.rag_embedding.httpx.post", fake_post)

    provider = SiliconFlowEmbeddingProvider(
        api_key="secret",
        base_url="https://api.siliconflow.cn/v1",
        model="Qwen/Qwen3-Embedding-4B",
        dimensions=1024,
    )
    texts = ["法规依据", "注册检验报告"]

    assert provider.embed(texts) == [[0.1, 0.2], [0.3, 0.4]]
    # The fake answers with both embeddings at once, so a single request is
    # expected; checking this also gives a clear failure if no call was made.
    assert len(calls) == 1
    request = calls[0]
    assert request["url"] == "https://api.siliconflow.cn/v1/embeddings"
    assert request["headers"]["Authorization"] == "Bearer secret"
    assert request["json"]["model"] == "Qwen/Qwen3-Embedding-4B"
    assert request["json"]["dimensions"] == 1024
    # The OpenAI-compatible /embeddings endpoint takes the texts under "input";
    # without this check the "expected payload" never covers the actual texts.
    assert request["json"]["input"] == texts
|
|
|
|
|
|
def test_chunk_text_preserves_source_metadata():
    """Every chunk produced by ``chunk_text`` carries the original source name.

    The input (20 repeated lines) is long compared with ``chunk_size`` so that
    several chunks are produced. All of them are checked, not just the first,
    so that metadata or text dropped on later chunks is caught.
    """
    chunks = chunk_text(
        "第一段法规内容。\n" * 20,
        source="法规.doc",
        chunk_size=30,
        overlap=5,
    )

    assert len(chunks) > 1
    for chunk in chunks:
        assert chunk.metadata["source"] == "法规.doc"
        assert chunk.text
|
|
|
|
|
|
def test_retrieve_citations_returns_placeholder_when_no_hits():
    """An index query with no hits yields a placeholder citation first."""

    class NoHitsCollection:
        """Collection double whose ``query`` always returns empty result lists."""

        def query(self, query_embeddings, n_results):
            # A fresh ``[[]]`` per key, i.e. one empty hit list for the single query.
            return {key: [[]] for key in ("documents", "metadatas", "distances")}

    def fake_embed(texts):
        return [[0.1, 0.2]]

    citations = retrieve_citations(
        "注册检验报告",
        embedding_provider=fake_embed,
        collection=NoHitsCollection(),
    )

    first_citation = citations[0]
    assert first_citation["source"] == "原文依据待补充"
|
|
|
|
|
|
def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
    """A Chroma path pointing at a directory that was never created raises ``RagIndexUnavailable``."""
    # ``tmp_path`` is empty, so this subdirectory does not exist.
    missing_index_dir = tmp_path / "missing"
    settings.REGULATORY_RAG_CHROMA_PATH = missing_index_dir

    def fake_embed(texts):
        return [[0.1]]

    with pytest.raises(RagIndexUnavailable):
        retrieve_citations("注册检验报告", embedding_provider=fake_embed)
|