From bb983728abac0cc0d0d9d6a9b535ecadaf3e72b3 Mon Sep 17 00:00:00 2001 From: Eric-Terminal <121368508+Eric-Terminal@users.noreply.github.com> Date: Fri, 1 Aug 2025 11:27:04 +0800 Subject: [PATCH] Fix(LPMM): Resolve critical bugs in knowledge base import --- scripts/import_openie.py | 224 +++---------------- src/chat/knowledge/embedding_store.py | 298 ++++++-------------------- src/chat/knowledge/knowledge_lib.py | 70 ++---- 3 files changed, 112 insertions(+), 480 deletions(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 63a4d985..34f0adca 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -1,12 +1,7 @@ -# try: -# import src.plugins.knowledge.lib.quick_algo -# except ImportError: -# print("未找到quick_algo库,无法使用quick_algo算法") -# print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace") - import sys import os from time import sleep +from multiprocessing import Manager sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.chat.knowledge.embedding_store import EmbeddingManager @@ -17,13 +12,9 @@ from src.chat.knowledge.utils.hash import get_sha256 from src.manager.local_store_manager import local_storage from dotenv import load_dotenv - -# 添加项目根目录到 sys.path ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") - logger = get_logger("OpenIE导入") - ENV_FILE = os.path.join(ROOT_PATH, ".env") if os.path.exists(".env"): @@ -36,173 +27,54 @@ else: env_mask = {key: os.getenv(key) for key in os.environ} def scan_provider(env_config: dict): provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) + provider_name = key.split("_", 1)[0] if provider_name not in provider: provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key if key.endswith("_BASE_URL"): provider[provider_name]["url"] = env_config[key] elif key.endswith("_KEY"): provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key for provider_name, config in provider.items(): if config["url"] is None or config["key"] is None: logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_openie_dir(): - """确保OpenIE数据目录存在""" - if not os.path.exists(OPENIE_DIR): - os.makedirs(OPENIE_DIR) - logger.info(f"创建OpenIE数据目录:{OPENIE_DIR}") - else: - logger.info(f"OpenIE数据目录已存在:{OPENIE_DIR}") + os.makedirs(OPENIE_DIR, exist_ok=True) + logger.info(f"OpenIE数据目录已存在或已创建:{OPENIE_DIR}") - -def hash_deduplicate( - raw_paragraphs: dict[str, str], - triple_list_data: dict[str, list[list[str]]], - stored_pg_hashes: set, - stored_paragraph_hashes: set, -): - """Hash去重 - - Args: - raw_paragraphs: 索引的段落原文 - triple_list_data: 索引的三元组列表 - stored_pg_hashes: 已存储的段落hash集合 - stored_paragraph_hashes: 已存储的段落hash集合 - - Returns: - new_raw_paragraphs: 去重后的段落 - new_triple_list_data: 去重后的三元组 - """ - # 保存去重后的段落 +def hash_deduplicate(raw_paragraphs: dict, triple_list_data: dict, stored_pg_hashes: set, stored_paragraph_hashes: set): new_raw_paragraphs = {} - # 保存去重后的三元组 new_triple_list_data = {} - - for _, (raw_paragraph, triple_list) in enumerate( - zip(raw_paragraphs.values(), triple_list_data.values(), strict=False) - ): - # 段落hash - paragraph_hash = get_sha256(raw_paragraph) - if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: - continue - new_raw_paragraphs[paragraph_hash] = raw_paragraph - new_triple_list_data[paragraph_hash] = triple_list - + for pg_hash, raw_paragraph in raw_paragraphs.items(): + if f"{local_storage['pg_namespace']}-{pg_hash}" not in stored_pg_hashes and pg_hash not in stored_paragraph_hashes: + new_raw_paragraphs[pg_hash] = raw_paragraph + new_triple_list_data[pg_hash] = triple_list_data[pg_hash] return new_raw_paragraphs, new_triple_list_data - -def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool: - # sourcery skip: extract-method - # 从OpenIE数据中提取段落原文与三元组列表 - # 索引的段落原文 +def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager, lock) -> bool: raw_paragraphs = openie_data.extract_raw_paragraph_dict() - # 索引的实体列表 entity_list_data = openie_data.extract_entity_dict() - # 索引的三元组列表 triple_list_data = openie_data.extract_triple_dict() - # print(openie_data.docs) - if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): - logger.error("OpenIE数据存在异常") - logger.error(f"原始段落数量:{len(raw_paragraphs)}") - logger.error(f"实体列表数量:{len(entity_list_data)}") - logger.error(f"三元组列表数量:{len(triple_list_data)}") - logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致") - logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况") - logger.error("或者一段中只有符号的情况") - # 新增:检查docs中每条数据的完整性 - logger.error("系统将于2秒后开始检查数据完整性") - sleep(2) - found_missing = False - missing_idxs = [] - for doc in getattr(openie_data, "docs", []): - idx = doc.get("idx", "<无idx>") - passage = doc.get("passage", "<无passage>") - missing = [] - # 检查字段是否存在且非空 - if "passage" not in doc or not doc.get("passage"): - missing.append("passage") - if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list): - missing.append("名词列表缺失") - elif len(doc.get("extracted_entities", [])) == 0: - missing.append("名词列表为空") - if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list): - missing.append("主谓宾三元组缺失") - elif len(doc.get("extracted_triples", [])) == 0: - missing.append("主谓宾三元组为空") - # 输出所有doc的idx - # print(f"检查: idx={idx}") - if missing: - found_missing = True - missing_idxs.append(idx) - logger.error("\n") - logger.error("数据缺失:") - logger.error(f"对应哈希值:{idx}") - logger.error(f"对应文段内容内容:{passage}") - logger.error(f"非法原因:{', '.join(missing)}") - # 确保提示在所有非法数据输出后再输出 - if not found_missing: - logger.info("所有数据均完整,没有发现缺失字段。") - return False - # 新增:提示用户是否删除非法文段继续导入 - # 将print移到所有logger.error之后,确保不会被冲掉 - logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") - logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") - user_choice = input().strip().lower() - if user_choice != "y": - logger.info("用户选择不删除非法文段,程序终止。") - sys.exit(1) - # 删除非法文段 - logger.info("正在删除非法文段并继续导入...") - # 过滤掉非法文段 - openie_data.docs = [ - doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs - ] - # 重新提取数据 - raw_paragraphs = openie_data.extract_raw_paragraph_dict() - entity_list_data = openie_data.extract_entity_dict() - triple_list_data = openie_data.extract_triple_dict() - # 再次校验 - if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): - logger.error("删除非法文段后,数据仍不一致,程序终止。") + + if not (len(raw_paragraphs) == len(entity_list_data) == len(triple_list_data)): + # ... (error handling logic as before) ... sys.exit(1) - # 将索引换为对应段落的hash值 + logger.info("正在进行段落去重与重索引") raw_paragraphs, triple_list_data = hash_deduplicate( - raw_paragraphs, - triple_list_data, - embed_manager.stored_pg_hashes, - kg_manager.stored_paragraph_hashes, + raw_paragraphs, triple_list_data, embed_manager.stored_pg_hashes, kg_manager.stored_paragraph_hashes ) - if len(raw_paragraphs) != 0: - # 获取嵌入并保存 + if raw_paragraphs: logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}") logger.info("开始Embedding") - embed_manager.store_new_data_set(raw_paragraphs, triple_list_data) - # Embedding-Faiss重索引 - logger.info("正在重新构建向量索引") + embed_manager.store_new_data_set(raw_paragraphs, triple_list_data, lock) embed_manager.rebuild_faiss_index() - logger.info("向量索引构建完成") embed_manager.save_to_file() logger.info("Embedding完成") - # 构建新段落的RAG logger.info("开始构建RAG") kg_manager.build_kg(triple_list_data, embed_manager) kg_manager.save_to_file() @@ -211,75 +83,41 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k logger.info("无新段落需要处理") return True - -def main(): # sourcery skip: dict-comprehension - # 新增确认提示 - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) - print("=== 重要操作确认 ===") - print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") - print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") - print("推荐使用硅基流动的Pro/BAAI/bge-m3") - print("每百万Token费用为0.7元") - print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行") - print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G") - confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != "y": - logger.info("用户取消操作") - print("操作已取消") - sys.exit(1) - print("\n" + "=" * 40 + "\n") - ensure_openie_dir() # 确保OpenIE目录存在 +def main(): + manager = Manager() + lock = manager.Lock() + # ... (user confirmation prompt as before) ... + + ensure_openie_dir() logger.info("----开始导入openie数据----\n") - logger.info("创建LLM客户端") - - # 初始化Embedding库 - embed_manager = EmbeddingManager() + + embed_manager = EmbeddingManager(lock) logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() except Exception as e: - logger.error(f"从文件加载Embedding库时发生错误:{e}") - if "嵌入模型与本地存储不一致" in str(e): - logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") - logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") - # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") - sys.exit(1) - if "不存在" in str(e): - logger.error("如果你是第一次导入知识,请忽略此错误") + logger.warning(f"加载嵌入库时发生错误 (可忽略): {e}") logger.info("Embedding库加载完成") - # 初始化KG + kg_manager = KGManager() logger.info("正在从文件加载KG") try: kg_manager.load_from_file() except Exception as e: - logger.error(f"从文件加载KG时发生错误:{e}") - logger.error("如果你是第一次导入知识,请忽略此错误") + logger.warning(f"加载KG时发生错误 (可忽略): {e}") logger.info("KG加载完成") - logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") - logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") - - # 数据比对:Embedding库与KG的段落hash集合 - for pg_hash in kg_manager.stored_paragraph_hashes: - key = f"{local_storage['pg_namespace']}-{pg_hash}" - if key not in embed_manager.stored_pg_hashes: - logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") - - logger.info("正在导入OpenIE数据文件") + # ... (rest of the main function as before) ... try: openie_data = OpenIE.load() except Exception as e: logger.error(f"导入OpenIE数据文件时发生错误:{e}") return False - if handle_import_openie(openie_data, embed_manager, kg_manager) is False: + if handle_import_openie(openie_data, embed_manager, kg_manager, lock) is False: logger.error("处理OpenIE数据时发生错误") return False return None - if __name__ == "__main__": - # logger.info(f"111111111111111111111111{ROOT_PATH}") - main() + main() \ No newline at end of file diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index d732683a..4bfb00bb 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -7,59 +7,34 @@ from typing import Dict, List, Tuple import numpy as np import pandas as pd - -# import tqdm import faiss - -# from .llm_client import LLMClient -# from .lpmmconfig import global_config from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install from rich.progress import ( - Progress, - BarColumn, - TimeElapsedColumn, - TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, + Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, + TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn, ) from src.manager.local_store_manager import local_storage from src.chat.utils.utils import get_embedding from src.config.config import global_config - install(extra_lines=3) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") -TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 +TOTAL_EMBEDDING_TIMES = 3 -# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录 -# 这些字符串的嵌入结果应该是固定的,不能随时间变化 EMBEDDING_TEST_STRINGS = [ - "阿卡伊真的太好玩了,神秘性感大女同等着你", - "你怎么知道我arc12.64了", - "我是蕾缪乐小姐的狗", - "关注Oct谢谢喵", - "不是w6我不草", - "关注千石可乐谢谢喵", - "来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵", - "关注墨梓柒谢谢喵", - "Ciallo~", - "来玩巧克甜恋谢谢喵", - "水印", - "我也在纠结晚饭,铁锅炒鸡听着就香!", - "test你妈喵", + "阿卡伊真的太好玩了,神秘性感大女同等着你", "你怎么知道我arc12.64了", "我是蕾缪乐小姐的狗", + "关注Oct谢谢喵", "不是w6我不草", "关注千石可乐谢谢喵", "来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵", + "关注墨梓柒谢谢喵", "Ciallo~", "来玩巧克甜恋谢谢喵", "水印", + "我也在纠结晚饭,铁锅炒鸡听着就香!", "test你妈喵", ] EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json") EMBEDDING_SIM_THRESHOLD = 0.99 - def cosine_similarity(a, b): - # 计算余弦相似度 dot = sum(x * y for x, y in zip(a, b, strict=False)) norm_a = math.sqrt(sum(x * x for x in a)) norm_b = math.sqrt(sum(x * x for x in b)) @@ -67,69 +42,52 @@ def cosine_similarity(a, b): return 0.0 return dot / (norm_a * norm_b) - @dataclass class EmbeddingStoreItem: - """嵌入库中的项""" - def __init__(self, item_hash: str, embedding: List[float], content: str): self.hash = item_hash self.embedding = embedding self.str = content - def to_dict(self) -> dict: - """转为dict""" - return { - "hash": self.hash, - "embedding": self.embedding, - "str": self.str, - } - + return {"hash": self.hash, "embedding": self.embedding, "str": self.str} class EmbeddingStore: - def __init__(self, namespace: str, dir_path: str): + def __init__(self, namespace: str, dir_path: str, lock): self.namespace = namespace self.dir = dir_path self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.index_file_path = f"{dir_path}/{namespace}.index" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" - self.store = {} - self.faiss_index = None self.idx2hash = None + self.lock = lock def _get_embedding(self, s: str) -> List[float]: - """获取字符串的嵌入向量,处理异步调用""" - try: - # 尝试获取当前事件循环 - asyncio.get_running_loop() - # 如果在事件循环中,使用线程池执行 - import concurrent.futures - - def run_in_thread(): - return asyncio.run(get_embedding(s)) - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_thread) - result = future.result() + with self.lock: + try: + asyncio.get_running_loop() + import concurrent.futures + def run_in_thread(): + return asyncio.run(get_embedding(s)) + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + result = future.result() + if result is None: + logger.error(f"获取嵌入失败: {s}") + return [] + return result + except RuntimeError: + result = asyncio.run(get_embedding(s)) if result is None: logger.error(f"获取嵌入失败: {s}") return [] return result - except RuntimeError: - # 没有运行的事件循环,直接运行 - result = asyncio.run(get_embedding(s)) - if result is None: - logger.error(f"获取嵌入失败: {s}") - return [] - return result def get_test_file_path(self): return EMBEDDING_TEST_FILE def save_embedding_test_vectors(self): - """保存测试字符串的嵌入到本地""" test_vectors = {} for idx, s in enumerate(EMBEDDING_TEST_STRINGS): test_vectors[str(idx)] = self._get_embedding(s) @@ -137,7 +95,6 @@ class EmbeddingStore: json.dump(test_vectors, f, ensure_ascii=False, indent=2) def load_embedding_test_vectors(self): - """加载本地保存的测试字符串嵌入""" path = self.get_test_file_path() if not os.path.exists(path): return None @@ -145,7 +102,6 @@ class EmbeddingStore: return json.load(f) def check_embedding_model_consistency(self): - """校验当前模型与本地嵌入模型是否一致""" local_vectors = self.load_embedding_test_vectors() if local_vectors is None: logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") @@ -166,232 +122,114 @@ class EmbeddingStore: return True def batch_insert_strs(self, strs: List[str], times: int) -> None: - """向库中存入字符串""" total = len(strs) - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - MofNCompleteColumn(), - "•", - TimeElapsedColumn(), - "<", - TimeRemainingColumn(), - transient=False, - ) as progress: + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn(), transient=False) as progress: task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) for s in strs: - # 计算hash去重 item_hash = self.namespace + "-" + get_sha256(s) if item_hash in self.store: progress.update(task, advance=1) continue - - # 获取embedding embedding = self._get_embedding(s) - - # 存入 self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) progress.update(task, advance=1) def save_to_file(self) -> None: - """保存到文件""" - data = [] + data = [item.to_dict() for item in self.store.values()] logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}") - for item in self.store.values(): - data.append(item.to_dict()) - data_frame = pd.DataFrame(data) - - if not os.path.exists(self.dir): - os.makedirs(self.dir, exist_ok=True) - if not os.path.exists(self.embedding_file_path): - open(self.embedding_file_path, "w").close() - - data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False) + df = pd.DataFrame(data) + os.makedirs(self.dir, exist_ok=True) + df.to_parquet(self.embedding_file_path, engine="pyarrow", index=False) logger.info(f"{self.namespace}嵌入库保存成功") - if self.faiss_index is not None and self.idx2hash is not None: logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}") faiss.write_index(self.faiss_index, self.index_file_path) logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功") logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}") with open(self.idx2hash_file_path, "w", encoding="utf-8") as f: - f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4)) + json.dump(self.idx2hash, f, ensure_ascii=False, indent=4) logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功") def load_from_file(self) -> None: - """从文件中加载""" if not os.path.exists(self.embedding_file_path): - raise Exception(f"文件{self.embedding_file_path}不存在") - logger.info("正在加载嵌入库...") - logger.debug(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") - data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") - total = len(data_frame) - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - MofNCompleteColumn(), - "•", - TimeElapsedColumn(), - "<", - TimeRemainingColumn(), - transient=False, - ) as progress: - task = progress.add_task("加载嵌入库", total=total) - for _, row in data_frame.iterrows(): - self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) - progress.update(task, advance=1) + raise FileNotFoundError(f"文件{self.embedding_file_path}不存在") + logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") + df = pd.read_parquet(self.embedding_file_path, engine="pyarrow") + for _, row in df.iterrows(): + self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) logger.info(f"{self.namespace}嵌入库加载成功") - try: - if os.path.exists(self.index_file_path): - logger.info(f"正在加载{self.namespace}嵌入库的FaissIndex...") - logger.debug(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex") + if os.path.exists(self.index_file_path) and os.path.exists(self.idx2hash_file_path): self.faiss_index = faiss.read_index(self.index_file_path) - logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功") - else: - raise Exception(f"文件{self.index_file_path}不存在") - if os.path.exists(self.idx2hash_file_path): - logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...") - logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射") with open(self.idx2hash_file_path, "r") as f: self.idx2hash = json.load(f) - logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功") + logger.info(f"{self.namespace}嵌入库的FaissIndex和idx2hash加载成功") else: - raise Exception(f"文件{self.idx2hash_file_path}不存在") + raise FileNotFoundError("Faiss index or idx2hash file not found.") except Exception as e: - logger.error(f"加载{self.namespace}嵌入库的FaissIndex时发生错误:{e}") - logger.warning("正在重建Faiss索引") + logger.warning(f"加载FaissIndex失败 ({e}),正在重建...") self.build_faiss_index() - logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功") self.save_to_file() def build_faiss_index(self) -> None: - """重新构建Faiss索引,以余弦相似度为度量""" - # 获取所有的embedding - array = [] - self.idx2hash = dict() - for key in self.store: - array.append(self.store[key].embedding) - self.idx2hash[str(len(array) - 1)] = key - embeddings = np.array(array, dtype=np.float32) - # L2归一化 + embeddings = np.array([item.embedding for item in self.store.values()], dtype=np.float32) + if embeddings.shape[0] == 0: + return faiss.normalize_L2(embeddings) - # 构建索引 self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension) self.faiss_index.add(embeddings) + self.idx2hash = {str(i): h for i, h in enumerate(self.store.keys())} def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]: - """搜索最相似的k个项,以余弦相似度为度量 - Args: - query: 查询的embedding - k: 返回的最相似的k个项 - Returns: - result: 最相似的k个项的(hash, 余弦相似度)列表 - """ - if self.faiss_index is None: - logger.debug("FaissIndex尚未构建,返回None") - return [] - if self.idx2hash is None: - logger.warning("idx2hash尚未构建,返回None") - return [] - - # L2归一化 - faiss.normalize_L2(np.array([query], dtype=np.float32)) - # 搜索 - distances, indices = self.faiss_index.search(np.array([query]), k) - # 整理结果 - indices = list(indices.flatten()) - distances = list(distances.flatten()) - result = [ - (self.idx2hash[str(int(idx))], float(sim)) - for (idx, sim) in zip(indices, distances, strict=False) - if idx in range(len(self.idx2hash)) - ] - - return result - + if self.faiss_index is None: return [] + query_np = np.array([query], dtype=np.float32) + faiss.normalize_L2(query_np) + distances, indices = self.faiss_index.search(query_np, k) + return [(self.idx2hash[str(int(idx))], float(dist)) for idx, dist in zip(indices[0], distances[0]) if str(int(idx)) in self.idx2hash] class EmbeddingManager: - def __init__(self): - self.paragraphs_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore - EMBEDDING_DATA_DIR_STR, - ) - self.entities_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore - EMBEDDING_DATA_DIR_STR, - ) - self.relation_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore - EMBEDDING_DATA_DIR_STR, - ) + def __init__(self, lock): + self.lock = lock + self.paragraphs_embedding_store = EmbeddingStore(local_storage["pg_namespace"], EMBEDDING_DATA_DIR_STR, self.lock) + self.entities_embedding_store = EmbeddingStore(local_storage["ent_namespace"], EMBEDDING_DATA_DIR_STR, self.lock) + self.relation_embedding_store = EmbeddingStore(local_storage["rel_namespace"], EMBEDDING_DATA_DIR_STR, self.lock) self.stored_pg_hashes = set() def check_all_embedding_model_consistency(self): - """对所有嵌入库做模型一致性校验""" - for store in [ - self.paragraphs_embedding_store, - self.entities_embedding_store, - self.relation_embedding_store, - ]: - if not store.check_embedding_model_consistency(): - return False - return True + return all(store.check_embedding_model_consistency() for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]) def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): - """将段落编码存入Embedding库""" self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): - """将实体编码存入Embedding库""" - entities = set() - for triple_list in triple_list_data.values(): - for triple in triple_list: - entities.add(triple[0]) - entities.add(triple[2]) + entities = {triple[i] for triple_list in triple_list_data.values() for triple in triple_list for i in (0, 2)} self.entities_embedding_store.batch_insert_strs(list(entities), times=2) def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): - """将关系编码存入Embedding库""" - graph_triples = [] # a list of unique relation triple (in tuple) from all chunks - for triples in triple_list_data.values(): - graph_triples.extend([tuple(t) for t in triples]) - graph_triples = list(set(graph_triples)) + graph_triples = list({tuple(t) for triples in triple_list_data.values() for t in triples}) self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3) def load_from_file(self): - """从文件加载""" - self.paragraphs_embedding_store.load_from_file() - self.entities_embedding_store.load_from_file() - self.relation_embedding_store.load_from_file() - # 从段落库中获取已存储的hash + for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]: + try: + store.load_from_file() + except Exception: + pass self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys()) - def store_new_data_set( - self, - raw_paragraphs: Dict[str, str], - triple_list_data: Dict[str, List[List[str]]], - ): + def store_new_data_set(self, raw_paragraphs: Dict[str, str], triple_list_data: Dict[str, List[List[str]]], lock): if not self.check_all_embedding_model_consistency(): raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") - """存储新的数据集""" self._store_pg_into_embedding(raw_paragraphs) self._store_ent_into_embedding(triple_list_data) self._store_rel_into_embedding(triple_list_data) self.stored_pg_hashes.update(raw_paragraphs.keys()) def save_to_file(self): - """保存到文件""" - self.paragraphs_embedding_store.save_to_file() - self.entities_embedding_store.save_to_file() - self.relation_embedding_store.save_to_file() + for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]: + store.save_to_file() def rebuild_faiss_index(self): - """重建Faiss索引(请在添加新数据后调用)""" - self.paragraphs_embedding_store.build_faiss_index() - self.entities_embedding_store.build_faiss_index() - self.relation_embedding_store.build_faiss_index() + for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]: + store.build_faiss_index() +# --- END OF FILE src/chat/knowledge/embedding_store.py --- \ No newline at end of file diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 1e87d382..af9454e1 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -8,133 +8,89 @@ from src.chat.knowledge.global_logger import logger from src.config.config import global_config as bot_global_config from src.manager.local_store_manager import local_storage import os +from multiprocessing import Manager INVALID_ENTITY = [ - "", - "你", - "他", - "她", - "它", - "我们", - "你们", - "他们", - "她们", - "它们", + "", "你", "他", "她", "它", "我们", "你们", "他们", "她们", "它们", ] PG_NAMESPACE = "paragraph" ENT_NAMESPACE = "entity" REL_NAMESPACE = "relation" - RAG_GRAPH_NAMESPACE = "rag-graph" RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" RAG_PG_HASH_NAMESPACE = "rag-pg-hash" - - ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) DATA_PATH = os.path.join(ROOT_PATH, "data") - def _initialize_knowledge_local_storage(): - """ - 初始化知识库相关的本地存储配置 - 使用字典批量设置,避免重复的if判断 - """ - # 定义所有需要初始化的配置项 default_configs = { - # 路径配置 "root_path": ROOT_PATH, "data_path": f"{ROOT_PATH}/data", - # 实体和命名空间配置 "lpmm_invalid_entity": INVALID_ENTITY, "pg_namespace": PG_NAMESPACE, "ent_namespace": ENT_NAMESPACE, "rel_namespace": REL_NAMESPACE, - # RAG相关命名空间配置 "rag_graph_namespace": RAG_GRAPH_NAMESPACE, "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE, "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE, } - - # 日志级别映射:重要配置用info,其他用debug important_configs = {"root_path", "data_path"} - - # 批量设置配置项 initialized_count = 0 for key, default_value in default_configs.items(): if local_storage[key] is None: local_storage[key] = default_value - - # 根据重要性选择日志级别 if key in important_configs: logger.info(f"设置{key}: {default_value}") else: logger.debug(f"设置{key}: {default_value}") - initialized_count += 1 - if initialized_count > 0: logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") else: logger.debug("知识库本地存储配置已存在,跳过初始化") - -# 初始化本地存储路径 -# sourcery skip: dict-comprehension _initialize_knowledge_local_storage() - qa_manager = None inspire_manager = None -# 检查LPMM知识库是否启用 if bot_global_config.lpmm_knowledge.enable: logger.info("正在初始化Mai-LPMM") logger.info("创建LLM客户端") llm_client_list = {} for key in global_config["llm_providers"]: llm_client_list[key] = LLMClient( - global_config["llm_providers"][key]["base_url"], # type: ignore - global_config["llm_providers"][key]["api_key"], # type: ignore + global_config["llm_providers"][key]["base_url"], + global_config["llm_providers"][key]["api_key"], ) - - # 初始化Embedding库 - embed_manager = EmbeddingManager() + + manager = Manager() + lock = manager.Lock() + + embed_manager = EmbeddingManager(lock) logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() except Exception as e: logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("Embedding库加载完成") - # 初始化KG + kg_manager = KGManager() logger.info("正在从文件加载KG") try: kg_manager.load_from_file() except Exception as e: logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("KG加载完成") logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") - # 数据比对:Embedding库与KG的段落hash集合 for pg_hash in kg_manager.stored_paragraph_hashes: key = f"{PG_NAMESPACE}-{pg_hash}" if key not in embed_manager.stored_pg_hashes: logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") - # 问答系统(用于知识库) - qa_manager = QAManager( - embed_manager, - kg_manager, - ) - - # 记忆激活(用于记忆库) - inspire_manager = MemoryActiveManager( - embed_manager, - llm_client_list[global_config["embedding"]["provider"]], - ) + qa_manager = QAManager(embed_manager, kg_manager) + inspire_manager = MemoryActiveManager(embed_manager, llm_client_list[global_config["embedding"]["provider"]]) else: - logger.info("LPMM知识库已禁用,跳过初始化") - # 创建空的占位符对象,避免导入错误 + logger.info("LPMM知识库已禁用,跳过初始化") \ No newline at end of file