mirror of https://github.com/Mai-with-u/MaiBot.git
commit
6fdfec798a
|
|
@ -104,7 +104,9 @@ class EmbeddingStore:
|
|||
self.dir = dir_path
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
|
||||
|
||||
self.dirty = False # 标记是否有新增数据需要重建索引
|
||||
|
||||
# 多线程配置参数验证和设置
|
||||
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||
|
|
@ -125,6 +127,11 @@ class EmbeddingStore:
|
|||
self.faiss_index = None
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def hash_texts(namespace: str, texts: List[str]) -> List[str]:
|
||||
"""将原文计算为带前缀的键"""
|
||||
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
|
|
@ -412,6 +419,7 @@ class EmbeddingStore:
|
|||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if embedding: # 只有成功获取到嵌入才存入
|
||||
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
||||
self.dirty = True
|
||||
else:
|
||||
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
|
||||
|
||||
|
|
@ -488,9 +496,17 @@ class EmbeddingStore:
|
|||
self.build_faiss_index()
|
||||
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
|
||||
self.save_to_file()
|
||||
self.dirty = False
|
||||
|
||||
def build_faiss_index(self) -> None:
|
||||
"""重新构建Faiss索引,以余弦相似度为度量"""
|
||||
# 空库直接跳过,清空索引映射
|
||||
if not self.store:
|
||||
self.idx2hash = {}
|
||||
self.faiss_index = None
|
||||
self.dirty = False
|
||||
return
|
||||
|
||||
# 获取所有的embedding
|
||||
array = []
|
||||
self.idx2hash = dict()
|
||||
|
|
@ -498,11 +514,44 @@ class EmbeddingStore:
|
|||
array.append(self.store[key].embedding)
|
||||
self.idx2hash[str(len(array) - 1)] = key
|
||||
embeddings = np.array(array, dtype=np.float32)
|
||||
if embeddings.size == 0:
|
||||
self.idx2hash = {}
|
||||
self.faiss_index = None
|
||||
self.dirty = False
|
||||
return
|
||||
# L2归一化
|
||||
faiss.normalize_L2(embeddings)
|
||||
# 构建索引
|
||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
||||
self.faiss_index.add(embeddings)
|
||||
self.dirty = False
|
||||
|
||||
def delete_items(self, hashes: List[str]) -> Tuple[int, int]:
    """Remove the given keys from the store and renumber ``idx2hash``.

    The Faiss index itself is NOT rebuilt here; the store is only flagged
    as dirty so that a later rebuild can bring the index back in sync.

    Args:
        hashes: Full keys to delete (for example ``paragraph-xxx``).

    Returns:
        A ``(deleted, skipped)`` tuple: how many keys were removed and how
        many were not present in the store when they were processed.
    """
    removed_count = 0
    missing_count = 0
    for item_key in hashes:
        if item_key not in self.store:
            # Unknown key (or a repeat of one already removed in this call).
            missing_count += 1
            continue
        self.store.pop(item_key)
        removed_count += 1

    # Renumber the remaining keys 0..n-1 (as strings) in store iteration order.
    self.idx2hash = {str(position): item_key for position, item_key in enumerate(self.store.keys())}

    # Always flag the store as dirty; the Faiss rebuild is left to the caller.
    self.dirty = True
    return removed_count, missing_count
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
"""搜索最相似的k个项,以余弦相似度为度量
|
||||
|
|
@ -536,7 +585,7 @@ class EmbeddingStore:
|
|||
|
||||
|
||||
class EmbeddingManager:
|
||||
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
def __init__(self, max_workers: int | None = None, chunk_size: int | None = None):
|
||||
"""
|
||||
初始化EmbeddingManager
|
||||
|
||||
|
|
@ -544,6 +593,8 @@ class EmbeddingManager:
|
|||
max_workers: 最大线程数
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
"""
|
||||
max_workers = max_workers if max_workers is not None else global_config.lpmm_knowledge.max_embedding_workers
|
||||
chunk_size = chunk_size if chunk_size is not None else global_config.lpmm_knowledge.embedding_chunk_size
|
||||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
"paragraph", # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
|
|
@ -617,7 +668,19 @@ class EmbeddingManager:
|
|||
self.relation_embedding_store.save_to_file()
|
||||
|
||||
def rebuild_faiss_index(self):
    """Rebuild the Faiss indexes of the three embedding stores where needed.

    Intended to be called after data has been added. A store is skipped
    when it is not flagged dirty, already has both a Faiss index and an
    ``idx2hash`` mapping, and the index size, mapping size and store size
    all agree. Any other store has ``build_faiss_index()`` called on it
    exactly once.
    """

    def _rebuild_if_needed(store: "EmbeddingStore") -> None:
        # A missing ``ntotal`` attribute is treated as an index of size 0.
        up_to_date = (
            not store.dirty
            and store.faiss_index is not None
            and store.idx2hash is not None
            and getattr(store.faiss_index, "ntotal", 0) == len(store.idx2hash) == len(store.store)
        )
        if up_to_date:
            logger.info(f"{store.namespace} FaissIndex 已是最新,跳过重建")
            return
        store.build_faiss_index()

    # No unconditional rebuild before the check: that would make the skip
    # logic pointless by always paying for a full rebuild first.
    for store in (
        self.paragraphs_embedding_store,
        self.entities_embedding_store,
        self.relation_embedding_store,
    ):
        _rebuild_if_needed(store)
|
||||
|
|
|
|||
Loading…
Reference in New Issue