Merge pull request #1446 from A-Dawn/dev

fix:补全#1386号PR中关于embedding_store.py的相关基础函数
pull/1450/head
墨梓柒 2025-12-19 18:19:38 +08:00 committed by GitHub
commit 6fdfec798a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 69 additions and 6 deletions

View File

@ -104,7 +104,9 @@ class EmbeddingStore:
self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
self.dirty = False # 标记是否有新增数据需要重建索引
# 多线程配置参数验证和设置
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
@ -125,6 +127,11 @@ class EmbeddingStore:
self.faiss_index = None
self.idx2hash = None
@staticmethod
def hash_texts(namespace: str, texts: List[str]) -> List[str]:
"""将原文计算为带前缀的键"""
return [f"{namespace}-{get_sha256(t)}" for t in texts]
def _get_embedding(self, s: str) -> List[float]:
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
@ -412,6 +419,7 @@ class EmbeddingStore:
item_hash = self.namespace + "-" + get_sha256(s)
if embedding: # 只有成功获取到嵌入才存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
self.dirty = True
else:
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
@ -488,9 +496,17 @@ class EmbeddingStore:
self.build_faiss_index()
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
self.save_to_file()
self.dirty = False
def build_faiss_index(self) -> None:
"""重新构建Faiss索引以余弦相似度为度量"""
# 空库直接跳过,清空索引映射
if not self.store:
self.idx2hash = {}
self.faiss_index = None
self.dirty = False
return
# 获取所有的embedding
array = []
self.idx2hash = dict()
@ -498,11 +514,44 @@ class EmbeddingStore:
array.append(self.store[key].embedding)
self.idx2hash[str(len(array) - 1)] = key
embeddings = np.array(array, dtype=np.float32)
if embeddings.size == 0:
self.idx2hash = {}
self.faiss_index = None
self.dirty = False
return
# L2归一化
faiss.normalize_L2(embeddings)
# 构建索引
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
self.faiss_index.add(embeddings)
self.dirty = False
def delete_items(self, hashes: List[str]) -> Tuple[int, int]:
"""删除指定键的嵌入并重建 idx2hash不直接重建 faiss
Args:
hashes: 需要删除的完整键列表 paragraph-xxx
Returns:
(deleted, skipped)
"""
deleted = 0
skipped = 0
for h in hashes:
if h in self.store:
self.store.pop(h)
deleted += 1
else:
skipped += 1
# 重新构建 idx2hash 映射
self.idx2hash = {}
for idx, key in enumerate(self.store.keys()):
self.idx2hash[str(idx)] = key
# 删除后标记 dirtyfaiss 重建由上层统一调用
self.dirty = True
return deleted, skipped
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
"""搜索最相似的k个项以余弦相似度为度量
@ -536,7 +585,7 @@ class EmbeddingStore:
class EmbeddingManager:
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
def __init__(self, max_workers: int | None = None, chunk_size: int | None = None):
"""
初始化EmbeddingManager
@ -544,6 +593,8 @@ class EmbeddingManager:
max_workers: 最大线程数
chunk_size: 每个线程处理的数据块大小
"""
max_workers = max_workers if max_workers is not None else global_config.lpmm_knowledge.max_embedding_workers
chunk_size = chunk_size if chunk_size is not None else global_config.lpmm_knowledge.embedding_chunk_size
self.paragraphs_embedding_store = EmbeddingStore(
"paragraph", # type: ignore
EMBEDDING_DATA_DIR_STR,
@ -617,7 +668,19 @@ class EmbeddingManager:
self.relation_embedding_store.save_to_file()
def rebuild_faiss_index(self):
"""重建Faiss索引请在添加新数据后调用"""
self.paragraphs_embedding_store.build_faiss_index()
self.entities_embedding_store.build_faiss_index()
self.relation_embedding_store.build_faiss_index()
"""重建Faiss索引新增数据后调用带跳过逻辑"""
def _rebuild_if_needed(store: EmbeddingStore):
if (
not store.dirty
and store.faiss_index is not None
and store.idx2hash is not None
and getattr(store.faiss_index, "ntotal", 0) == len(store.idx2hash) == len(store.store)
):
logger.info(f"{store.namespace} FaissIndex 已是最新,跳过重建")
return
store.build_faiss_index()
_rebuild_if_needed(self.paragraphs_embedding_store)
_rebuild_if_needed(self.entities_embedding_store)
_rebuild_if_needed(self.relation_embedding_store)