feat(regulatory): 增加本地法规RAG索引检索
This commit is contained in:
72
tests/test_regulatory_rag.py
Normal file
72
tests/test_regulatory_rag.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
|
||||
from review_agent.regulatory_review.services.rag_citation import (
|
||||
RagIndexUnavailable,
|
||||
retrieve_citations,
|
||||
)
|
||||
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
|
||||
from review_agent.regulatory_review.services.rag_index import chunk_text
|
||||
|
||||
|
||||
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
    """Verify the request built by ``SiliconFlowEmbeddingProvider.embed``.

    ``httpx.post`` (as referenced from the ``rag_embedding`` module) is
    replaced with a stub that records its arguments and returns a canned
    two-vector response. The test then checks the returned embeddings and
    the URL, ``Authorization`` header, ``model`` and ``dimensions`` of the
    first recorded request.
    """
    recorded_requests = []

    class StubResponse:
        """Minimal response double exposing the two methods the test stubs."""

        def raise_for_status(self):
            # A successful response: nothing to raise.
            return None

        def json(self):
            vectors = ([0.1, 0.2], [0.3, 0.4])
            return {"data": [{"embedding": vector} for vector in vectors]}

    def fake_post(url, headers, json, timeout):
        # Parameter names are kept as-is so keyword-style calls still bind.
        recorded_requests.append(
            {
                "url": url,
                "headers": headers,
                "json": json,
                "timeout": timeout,
            }
        )
        return StubResponse()

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_embedding.httpx.post",
        fake_post,
    )

    provider = SiliconFlowEmbeddingProvider(
        api_key="secret",
        base_url="https://api.siliconflow.cn/v1",
        model="Qwen/Qwen3-Embedding-4B",
        dimensions=1024,
    )

    embeddings = provider.embed(["法规依据", "注册检验报告"])
    assert embeddings == [[0.1, 0.2], [0.3, 0.4]]

    first_request = recorded_requests[0]
    assert first_request["url"] == "https://api.siliconflow.cn/v1/embeddings"
    assert first_request["headers"]["Authorization"] == "Bearer secret"

    payload = first_request["json"]
    assert payload["model"] == "Qwen/Qwen3-Embedding-4B"
    assert payload["dimensions"] == 1024
|
||||
|
||||
|
||||
def test_chunk_text_preserves_source_metadata():
    """Check that ``chunk_text`` splits long text and tags the first chunk.

    A 20-line repeated paragraph is chunked with ``chunk_size=30`` and
    ``overlap=5``; the test expects more than one chunk, and that the first
    chunk carries the given ``source`` in its metadata and has non-empty
    text.
    """
    source_name = "法规.doc"
    document = "第一段法规内容。\n" * 20

    chunks = chunk_text(document, source=source_name, chunk_size=30, overlap=5)

    assert len(chunks) > 1

    first_chunk = chunks[0]
    assert first_chunk.metadata["source"] == source_name
    assert first_chunk.text
|
||||
|
||||
|
||||
def test_retrieve_citations_returns_placeholder_when_no_hits():
    """Check the fallback citation when the collection returns no results.

    The collection double answers every query with empty ``documents``,
    ``metadatas`` and ``distances`` lists; the first returned citation is
    expected to carry the placeholder source string.
    """

    class NoHitCollection:
        """Collection double whose ``query`` always yields zero matches."""

        def query(self, query_embeddings, n_results):
            result_keys = ("documents", "metadatas", "distances")
            return {key: [[]] for key in result_keys}

    def constant_embedding(texts):
        # Fixed vector regardless of input; the stub collection ignores it.
        return [[0.1, 0.2]]

    citations = retrieve_citations(
        "注册检验报告",
        embedding_provider=constant_embedding,
        collection=NoHitCollection(),
    )

    first_citation = citations[0]
    assert first_citation["source"] == "原文依据待补充"
|
||||
|
||||
|
||||
def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
    """Expect ``RagIndexUnavailable`` when the configured index path is absent.

    ``REGULATORY_RAG_CHROMA_PATH`` is pointed at a sub-path of ``tmp_path``
    that is never created, and no ``collection`` is passed in.
    """
    # NOTE(review): `settings` is presumably the pytest-django settings
    # fixture, which restores overrides after the test; confirm.
    nonexistent_index = tmp_path / "missing"
    settings.REGULATORY_RAG_CHROMA_PATH = nonexistent_index

    def constant_embedding(texts):
        return [[0.1]]

    with pytest.raises(RagIndexUnavailable):
        retrieve_citations("注册检验报告", embedding_provider=constant_embedding)
|
||||
Reference in New Issue
Block a user