fix(kb): 完善知识库入库和重建索引
This commit is contained in:
@@ -23,6 +23,8 @@ from .rag_embedding import EmbeddingFunction
|
||||
|
||||
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
|
||||
|
||||
EXCLUDED_SOURCE_KEYWORDS = ("模拟题二", "试剂盒临床注册文件准备与审核Agent")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextChunk:
|
||||
@@ -227,6 +229,8 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
||||
for path in sorted(source_dir.rglob("*")):
|
||||
if not path.is_file():
|
||||
continue
|
||||
if is_excluded_source_path(path.relative_to(source_dir)):
|
||||
continue
|
||||
try:
|
||||
text = extract_text_from_path(path)
|
||||
except RuntimeError as exc:
|
||||
@@ -238,6 +242,11 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
||||
return chunks
|
||||
|
||||
|
||||
def is_excluded_source_path(path: Path | str) -> bool:
|
||||
normalized = str(path)
|
||||
return any(keyword in normalized for keyword in EXCLUDED_SOURCE_KEYWORDS)
|
||||
|
||||
|
||||
def _is_attachment4(path: Path) -> bool:
|
||||
normalized = path.name.replace(" ", "")
|
||||
return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized
|
||||
@@ -249,6 +258,7 @@ def build_chroma_index(
|
||||
embedding_provider: EmbeddingFunction,
|
||||
persist_path: Path | None = None,
|
||||
collection_name: str | None = None,
|
||||
reset: bool = False,
|
||||
) -> int:
|
||||
try:
|
||||
import chromadb
|
||||
@@ -259,7 +269,22 @@ def build_chroma_index(
|
||||
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
|
||||
persist_path.mkdir(parents=True, exist_ok=True)
|
||||
chunks = collect_source_chunks(source_dir)
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
try:
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
except Exception:
|
||||
if not reset:
|
||||
raise
|
||||
clear_chroma_system_cache()
|
||||
clear_chroma_index_dir(persist_path)
|
||||
persist_path.mkdir(parents=True, exist_ok=True)
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
if reset:
|
||||
try:
|
||||
client.delete_collection(collection_name)
|
||||
clear_chroma_system_cache()
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
except Exception:
|
||||
pass
|
||||
collection = client.get_or_create_collection(collection_name)
|
||||
if not chunks:
|
||||
return 0
|
||||
@@ -276,3 +301,22 @@ def build_chroma_index(
|
||||
embeddings=embeddings,
|
||||
)
|
||||
return len(chunks)
|
||||
|
||||
|
||||
def clear_chroma_index_dir(persist_path: Path | str | None = None) -> None:
|
||||
chroma_path = Path(persist_path or settings.REGULATORY_RAG_CHROMA_PATH).resolve()
|
||||
media_root = Path(settings.MEDIA_ROOT).resolve()
|
||||
try:
|
||||
chroma_path.relative_to(media_root)
|
||||
except ValueError as exc:
|
||||
raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。") from exc
|
||||
if chroma_path.exists():
|
||||
shutil.rmtree(chroma_path)
|
||||
|
||||
|
||||
def clear_chroma_system_cache() -> None:
|
||||
try:
|
||||
from chromadb.api.shared_system_client import SharedSystemClient
|
||||
except Exception:
|
||||
return
|
||||
SharedSystemClient.clear_system_cache()
|
||||
|
||||
Reference in New Issue
Block a user