From 8939a02d861c0f1c5a91eba4d1797dbba40f3c89 Mon Sep 17 00:00:00 2001 From: DawnARC Date: Fri, 19 Dec 2025 18:03:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E8=A1=A5=E5=85=A8#1386=E5=8F=B7PR=E4=B8=AD?= =?UTF-8?q?=E5=85=B3=E4=BA=8Eembedding=5Fstore.py=E7=9A=84=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E5=9F=BA=E7=A1=80=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入了一个 “dirty” 标志,用于跟踪嵌入存储(embedding store)是否需要重新构建 Faiss 索引;新增了 delete_items 方法,支持按 key 删除嵌入向量。 同时改进了 Faiss 索引的重建逻辑,在不必要时跳过重建操作;EmbeddingManager 也利用了这些增强功能,以优化索引管理流程。 另外还包含了一些 小规模重构,以及一个用于 文本哈希的静态方法。 --- src/chat/knowledge/embedding_store.py | 75 ++++++++++++++++++++++++--- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 768373cf..9c460a0d 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -104,7 +104,9 @@ class EmbeddingStore: self.dir = dir_path self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.index_file_path = f"{dir_path}/{namespace}.index" - self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" + self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json" + + self.dirty = False # 标记是否有新增数据需要重建索引 # 多线程配置参数验证和设置 self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers)) @@ -125,6 +127,11 @@ class EmbeddingStore: self.faiss_index = None self.idx2hash = None + @staticmethod + def hash_texts(namespace: str, texts: List[str]) -> List[str]: + """将原文计算为带前缀的键""" + return [f"{namespace}-{get_sha256(t)}" for t in texts] + def _get_embedding(self, s: str) -> List[float]: """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" # 创建新的事件循环并在完成后立即关闭 @@ -412,6 +419,7 @@ class EmbeddingStore: item_hash = self.namespace + "-" + get_sha256(s) if embedding: # 只有成功获取到嵌入才存入 self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) + self.dirty = True else: logger.warning(f"跳过存储失败的嵌入: {s[:50]}...") @@ -488,9 +496,17 @@ class EmbeddingStore: self.build_faiss_index() logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功") self.save_to_file() + self.dirty = False def build_faiss_index(self) -> None: """重新构建Faiss索引,以余弦相似度为度量""" + # 空库直接跳过,清空索引映射 + if not self.store: + self.idx2hash = {} + self.faiss_index = None + self.dirty = False + return + # 获取所有的embedding array = [] self.idx2hash = dict() @@ -498,11 +514,44 @@ class EmbeddingStore: array.append(self.store[key].embedding) self.idx2hash[str(len(array) - 1)] = key embeddings = np.array(array, dtype=np.float32) + if embeddings.size == 0: + self.idx2hash = {} + self.faiss_index = None + self.dirty = False + return # L2归一化 faiss.normalize_L2(embeddings) # 构建索引 self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension) self.faiss_index.add(embeddings) + self.dirty = False + + def delete_items(self, hashes: List[str]) -> Tuple[int, int]: + """删除指定键的嵌入并重建 idx2hash(不直接重建 faiss) + + Args: + hashes: 需要删除的完整键列表(如 paragraph-xxx) + + Returns: + (deleted, skipped) + """ + deleted = 0 + skipped = 0 + for h in hashes: + if h in self.store: + self.store.pop(h) + deleted += 1 + else: + skipped += 1 + + # 重新构建 idx2hash 映射 + self.idx2hash = {} + for idx, key in enumerate(self.store.keys()): + self.idx2hash[str(idx)] = key + + # 删除后标记 dirty,faiss 重建由上层统一调用 + self.dirty = True + return deleted, skipped def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]: """搜索最相似的k个项,以余弦相似度为度量 @@ -536,7 +585,7 @@ class EmbeddingStore: class EmbeddingManager: - def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): + def __init__(self, max_workers: int | None = None, chunk_size: int | None = None): """ 初始化EmbeddingManager @@ -544,6 +593,8 @@ class EmbeddingManager: max_workers: 最大线程数 chunk_size: 每个线程处理的数据块大小 """ + max_workers = max_workers if max_workers is not None else global_config.lpmm_knowledge.max_embedding_workers + chunk_size = chunk_size if chunk_size is not None else global_config.lpmm_knowledge.embedding_chunk_size self.paragraphs_embedding_store = EmbeddingStore( "paragraph", # type: ignore EMBEDDING_DATA_DIR_STR, @@ -617,7 +668,19 @@ class EmbeddingManager: self.relation_embedding_store.save_to_file() def rebuild_faiss_index(self): - """重建Faiss索引(请在添加新数据后调用)""" - self.paragraphs_embedding_store.build_faiss_index() - self.entities_embedding_store.build_faiss_index() - self.relation_embedding_store.build_faiss_index() + """重建Faiss索引,新增数据后调用,带跳过逻辑""" + + def _rebuild_if_needed(store: EmbeddingStore): + if ( + not store.dirty + and store.faiss_index is not None + and store.idx2hash is not None + and getattr(store.faiss_index, "ntotal", 0) == len(store.idx2hash) == len(store.store) + ): + logger.info(f"{store.namespace} FaissIndex 已是最新,跳过重建") + return + store.build_faiss_index() + + _rebuild_if_needed(self.paragraphs_embedding_store) + _rebuild_if_needed(self.entities_embedding_store) + _rebuild_if_needed(self.relation_embedding_store)