feat(regulatory): 增加本地法规RAG索引检索

This commit is contained in:
2026-06-07 00:30:53 +08:00
parent 2a4dd6cfab
commit 26490f7c46
7 changed files with 411 additions and 0 deletions

View File

@@ -0,0 +1,72 @@
import pytest
from review_agent.regulatory_review.services.rag_citation import (
RagIndexUnavailable,
retrieve_citations,
)
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
from review_agent.regulatory_review.services.rag_index import chunk_text
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
    """Check the request sent by SiliconFlowEmbeddingProvider.embed and its result.

    ``httpx.post`` inside the ``rag_embedding`` module is replaced by a
    recorder that returns a canned response holding two embedding vectors.
    """
    recorded_requests = []

    class StubResponse:
        """Minimal response stand-in exposing the two methods defined below."""

        def raise_for_status(self):
            return None

        def json(self):
            return {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}

    def fake_post(url, headers, json, timeout):
        # Parameter names are kept as-is: the provider may pass them by keyword.
        recorded_requests.append(
            dict(url=url, headers=headers, json=json, timeout=timeout)
        )
        return StubResponse()

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.rag_embedding.httpx.post",
        fake_post,
    )

    provider = SiliconFlowEmbeddingProvider(
        api_key="secret",
        base_url="https://api.siliconflow.cn/v1",
        model="Qwen/Qwen3-Embedding-4B",
        dimensions=1024,
    )
    vectors = provider.embed(["法规依据", "注册检验报告"])

    # The vectors come back in the order the stub response lists them.
    assert vectors == [[0.1, 0.2], [0.3, 0.4]]

    first_request = recorded_requests[0]
    assert first_request["url"] == "https://api.siliconflow.cn/v1/embeddings"
    assert first_request["headers"]["Authorization"] == "Bearer secret"

    payload = first_request["json"]
    assert payload["model"] == "Qwen/Qwen3-Embedding-4B"
    assert payload["dimensions"] == 1024
def test_chunk_text_preserves_source_metadata():
    """chunk_text yields several chunks and the first one carries the source name."""
    repeated_paragraph = "第一段法规内容。\n" * 20

    pieces = chunk_text(
        repeated_paragraph,
        source="法规.doc",
        chunk_size=30,
        overlap=5,
    )

    # Check the count before indexing so an empty result fails on this assert.
    assert len(pieces) > 1

    first_piece = pieces[0]
    assert first_piece.metadata["source"] == "法规.doc"
    assert first_piece.text
def test_retrieve_citations_returns_placeholder_when_no_hits():
    """With a collection whose query returns no hits, the first citation is the placeholder."""

    class NoHitCollection:
        """Collection stand-in whose ``query`` always reports empty results."""

        def query(self, query_embeddings, n_results):
            # Each key maps to its own list holding one empty inner list.
            return {
                field: [[]]
                for field in ("documents", "metadatas", "distances")
            }

    def fixed_embedding(texts):
        # Same constant vector regardless of the input texts.
        return [[0.1, 0.2]]

    citations = retrieve_citations(
        "注册检验报告",
        embedding_provider=fixed_embedding,
        collection=NoHitCollection(),
    )

    first_citation = citations[0]
    assert first_citation["source"] == "原文依据待补充"
def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
    """retrieve_citations raises RagIndexUnavailable when no collection is passed
    and REGULATORY_RAG_CHROMA_PATH names a directory this test never creates.
    """
    absent_index_dir = tmp_path / "missing"
    settings.REGULATORY_RAG_CHROMA_PATH = absent_index_dir

    def fixed_embedding(texts):
        # Same constant one-element vector regardless of the input texts.
        return [[0.1]]

    with pytest.raises(RagIndexUnavailable):
        retrieve_citations("注册检验报告", embedding_provider=fixed_embedding)