mirror of https://github.com/Mai-with-u/MaiBot.git
fix:补全#1386号PR中关于embedding_store.py的相关基础函数
引入了一个 “dirty” 标志,用于跟踪嵌入存储(embedding store)是否需要重新构建 Faiss 索引;新增了 delete_items 方法,支持按 key 删除嵌入向量。 同时改进了 Faiss 索引的重建逻辑,在不必要时跳过重建操作;EmbeddingManager 也利用了这些增强功能,以优化索引管理流程。 另外还包含了一些 小规模重构,以及一个用于 文本哈希的静态方法。pull/1446/head
parent
35c16d2bf3
commit
8939a02d86
|
|
@ -104,7 +104,9 @@ class EmbeddingStore:
|
||||||
self.dir = dir_path
|
self.dir = dir_path
|
||||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
|
||||||
|
|
||||||
|
self.dirty = False # 标记是否有新增数据需要重建索引
|
||||||
|
|
||||||
# 多线程配置参数验证和设置
|
# 多线程配置参数验证和设置
|
||||||
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||||
|
|
@ -125,6 +127,11 @@ class EmbeddingStore:
|
||||||
self.faiss_index = None
|
self.faiss_index = None
|
||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_texts(namespace: str, texts: List[str]) -> List[str]:
|
||||||
|
"""将原文计算为带前缀的键"""
|
||||||
|
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||||
# 创建新的事件循环并在完成后立即关闭
|
# 创建新的事件循环并在完成后立即关闭
|
||||||
|
|
@ -412,6 +419,7 @@ class EmbeddingStore:
|
||||||
item_hash = self.namespace + "-" + get_sha256(s)
|
item_hash = self.namespace + "-" + get_sha256(s)
|
||||||
if embedding: # 只有成功获取到嵌入才存入
|
if embedding: # 只有成功获取到嵌入才存入
|
||||||
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
||||||
|
self.dirty = True
|
||||||
else:
|
else:
|
||||||
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
|
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
|
||||||
|
|
||||||
|
|
@ -488,9 +496,17 @@ class EmbeddingStore:
|
||||||
self.build_faiss_index()
|
self.build_faiss_index()
|
||||||
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
|
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
|
||||||
self.save_to_file()
|
self.save_to_file()
|
||||||
|
self.dirty = False
|
||||||
|
|
||||||
def build_faiss_index(self) -> None:
|
def build_faiss_index(self) -> None:
|
||||||
"""重新构建Faiss索引,以余弦相似度为度量"""
|
"""重新构建Faiss索引,以余弦相似度为度量"""
|
||||||
|
# 空库直接跳过,清空索引映射
|
||||||
|
if not self.store:
|
||||||
|
self.idx2hash = {}
|
||||||
|
self.faiss_index = None
|
||||||
|
self.dirty = False
|
||||||
|
return
|
||||||
|
|
||||||
# 获取所有的embedding
|
# 获取所有的embedding
|
||||||
array = []
|
array = []
|
||||||
self.idx2hash = dict()
|
self.idx2hash = dict()
|
||||||
|
|
@ -498,11 +514,44 @@ class EmbeddingStore:
|
||||||
array.append(self.store[key].embedding)
|
array.append(self.store[key].embedding)
|
||||||
self.idx2hash[str(len(array) - 1)] = key
|
self.idx2hash[str(len(array) - 1)] = key
|
||||||
embeddings = np.array(array, dtype=np.float32)
|
embeddings = np.array(array, dtype=np.float32)
|
||||||
|
if embeddings.size == 0:
|
||||||
|
self.idx2hash = {}
|
||||||
|
self.faiss_index = None
|
||||||
|
self.dirty = False
|
||||||
|
return
|
||||||
# L2归一化
|
# L2归一化
|
||||||
faiss.normalize_L2(embeddings)
|
faiss.normalize_L2(embeddings)
|
||||||
# 构建索引
|
# 构建索引
|
||||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
||||||
self.faiss_index.add(embeddings)
|
self.faiss_index.add(embeddings)
|
||||||
|
self.dirty = False
|
||||||
|
|
||||||
|
def delete_items(self, hashes: List[str]) -> Tuple[int, int]:
|
||||||
|
"""删除指定键的嵌入并重建 idx2hash(不直接重建 faiss)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hashes: 需要删除的完整键列表(如 paragraph-xxx)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(deleted, skipped)
|
||||||
|
"""
|
||||||
|
deleted = 0
|
||||||
|
skipped = 0
|
||||||
|
for h in hashes:
|
||||||
|
if h in self.store:
|
||||||
|
self.store.pop(h)
|
||||||
|
deleted += 1
|
||||||
|
else:
|
||||||
|
skipped += 1
|
||||||
|
|
||||||
|
# 重新构建 idx2hash 映射
|
||||||
|
self.idx2hash = {}
|
||||||
|
for idx, key in enumerate(self.store.keys()):
|
||||||
|
self.idx2hash[str(idx)] = key
|
||||||
|
|
||||||
|
# 删除后标记 dirty,faiss 重建由上层统一调用
|
||||||
|
self.dirty = True
|
||||||
|
return deleted, skipped
|
||||||
|
|
||||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||||
"""搜索最相似的k个项,以余弦相似度为度量
|
"""搜索最相似的k个项,以余弦相似度为度量
|
||||||
|
|
@ -536,7 +585,7 @@ class EmbeddingStore:
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingManager:
|
class EmbeddingManager:
|
||||||
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
def __init__(self, max_workers: int | None = None, chunk_size: int | None = None):
|
||||||
"""
|
"""
|
||||||
初始化EmbeddingManager
|
初始化EmbeddingManager
|
||||||
|
|
||||||
|
|
@ -544,6 +593,8 @@ class EmbeddingManager:
|
||||||
max_workers: 最大线程数
|
max_workers: 最大线程数
|
||||||
chunk_size: 每个线程处理的数据块大小
|
chunk_size: 每个线程处理的数据块大小
|
||||||
"""
|
"""
|
||||||
|
max_workers = max_workers if max_workers is not None else global_config.lpmm_knowledge.max_embedding_workers
|
||||||
|
chunk_size = chunk_size if chunk_size is not None else global_config.lpmm_knowledge.embedding_chunk_size
|
||||||
self.paragraphs_embedding_store = EmbeddingStore(
|
self.paragraphs_embedding_store = EmbeddingStore(
|
||||||
"paragraph", # type: ignore
|
"paragraph", # type: ignore
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
|
|
@ -617,7 +668,19 @@ class EmbeddingManager:
|
||||||
self.relation_embedding_store.save_to_file()
|
self.relation_embedding_store.save_to_file()
|
||||||
|
|
||||||
def rebuild_faiss_index(self):
|
def rebuild_faiss_index(self):
|
||||||
"""重建Faiss索引(请在添加新数据后调用)"""
|
"""重建Faiss索引,新增数据后调用,带跳过逻辑"""
|
||||||
self.paragraphs_embedding_store.build_faiss_index()
|
|
||||||
self.entities_embedding_store.build_faiss_index()
|
def _rebuild_if_needed(store: EmbeddingStore):
|
||||||
self.relation_embedding_store.build_faiss_index()
|
if (
|
||||||
|
not store.dirty
|
||||||
|
and store.faiss_index is not None
|
||||||
|
and store.idx2hash is not None
|
||||||
|
and getattr(store.faiss_index, "ntotal", 0) == len(store.idx2hash) == len(store.store)
|
||||||
|
):
|
||||||
|
logger.info(f"{store.namespace} FaissIndex 已是最新,跳过重建")
|
||||||
|
return
|
||||||
|
store.build_faiss_index()
|
||||||
|
|
||||||
|
_rebuild_if_needed(self.paragraphs_embedding_store)
|
||||||
|
_rebuild_if_needed(self.entities_embedding_store)
|
||||||
|
_rebuild_if_needed(self.relation_embedding_store)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue