Fix(LPMM): Resolve critical bugs in knowledge base import

pull/1152/head
Eric-Terminal 2025-08-01 01:51:20 +08:00
parent 39891eef6d
commit b9b8005589
3 changed files with 70 additions and 206 deletions

View File

@ -1,12 +1,7 @@
# try:
# import src.plugins.knowledge.lib.quick_algo
# except ImportError:
# print("未找到quick_algo库无法使用quick_algo算法")
# print("请安装quick_algo库 - 在lib.quick_algo中执行命令python setup.py build_ext --inplace")
import sys import sys
import os import os
from time import sleep from time import sleep
from multiprocessing import Manager
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.embedding_store import EmbeddingManager
@ -17,13 +12,9 @@ from src.chat.knowledge.utils.hash import get_sha256
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
from dotenv import load_dotenv from dotenv import load_dotenv
# 添加项目根目录到 sys.path
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入") logger = get_logger("OpenIE导入")
ENV_FILE = os.path.join(ROOT_PATH, ".env") ENV_FILE = os.path.join(ROOT_PATH, ".env")
if os.path.exists(".env"): if os.path.exists(".env"):
@ -36,89 +27,50 @@ else:
env_mask = {key: os.getenv(key) for key in os.environ} env_mask = {key: os.getenv(key) for key in os.environ}
def scan_provider(env_config: dict): def scan_provider(env_config: dict):
provider = {} provider = {}
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
# 避免 GPG_KEY 这样的变量干扰检查
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
# 遍历 env_config 的所有键
for key in env_config: for key in env_config:
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
if key.endswith("_BASE_URL") or key.endswith("_KEY"): if key.endswith("_BASE_URL") or key.endswith("_KEY"):
# 提取 provider 名称 provider_name = key.split("_", 1)[0]
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
# 初始化 provider 的字典(如果尚未初始化)
if provider_name not in provider: if provider_name not in provider:
provider[provider_name] = {"url": None, "key": None} provider[provider_name] = {"url": None, "key": None}
# 根据键的类型填充 url 或 key
if key.endswith("_BASE_URL"): if key.endswith("_BASE_URL"):
provider[provider_name]["url"] = env_config[key] provider[provider_name]["url"] = env_config[key]
elif key.endswith("_KEY"): elif key.endswith("_KEY"):
provider[provider_name]["key"] = env_config[key] provider[provider_name]["key"] = env_config[key]
# 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items(): for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None: if config["url"] is None or config["key"] is None:
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_openie_dir(): def ensure_openie_dir():
"""确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR): if not os.path.exists(OPENIE_DIR):
os.makedirs(OPENIE_DIR) os.makedirs(OPENIE_DIR)
logger.info(f"创建OpenIE数据目录{OPENIE_DIR}") logger.info(f"创建OpenIE数据目录{OPENIE_DIR}")
else: else:
logger.info(f"OpenIE数据目录已存在{OPENIE_DIR}") logger.info(f"OpenIE数据目录已存在{OPENIE_DIR}")
def hash_deduplicate( def hash_deduplicate(
raw_paragraphs: dict[str, str], raw_paragraphs: dict[str, str],
triple_list_data: dict[str, list[list[str]]], triple_list_data: dict[str, list[list[str]]],
stored_pg_hashes: set, stored_pg_hashes: set,
stored_paragraph_hashes: set, stored_paragraph_hashes: set,
): ):
"""Hash去重
Args:
raw_paragraphs: 索引的段落原文
triple_list_data: 索引的三元组列表
stored_pg_hashes: 已存储的段落hash集合
stored_paragraph_hashes: 已存储的段落hash集合
Returns:
new_raw_paragraphs: 去重后的段落
new_triple_list_data: 去重后的三元组
"""
# 保存去重后的段落
new_raw_paragraphs = {} new_raw_paragraphs = {}
# 保存去重后的三元组
new_triple_list_data = {} new_triple_list_data = {}
for _, (raw_paragraph, triple_list) in enumerate( for _, (raw_paragraph, triple_list) in enumerate(
zip(raw_paragraphs.values(), triple_list_data.values(), strict=False) zip(raw_paragraphs.values(), triple_list_data.values(), strict=False)
): ):
# 段落hash
paragraph_hash = get_sha256(raw_paragraph) paragraph_hash = get_sha256(raw_paragraph)
if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
continue continue
new_raw_paragraphs[paragraph_hash] = raw_paragraph new_raw_paragraphs[paragraph_hash] = raw_paragraph
new_triple_list_data[paragraph_hash] = triple_list new_triple_list_data[paragraph_hash] = triple_list
return new_raw_paragraphs, new_triple_list_data return new_raw_paragraphs, new_triple_list_data
def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager, lock) -> bool:
def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
# sourcery skip: extract-method
# 从OpenIE数据中提取段落原文与三元组列表
# 索引的段落原文
raw_paragraphs = openie_data.extract_raw_paragraph_dict() raw_paragraphs = openie_data.extract_raw_paragraph_dict()
# 索引的实体列表
entity_list_data = openie_data.extract_entity_dict() entity_list_data = openie_data.extract_entity_dict()
# 索引的三元组列表
triple_list_data = openie_data.extract_triple_dict() triple_list_data = openie_data.extract_triple_dict()
# print(openie_data.docs)
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
logger.error("OpenIE数据存在异常") logger.error("OpenIE数据存在异常")
logger.error(f"原始段落数量:{len(raw_paragraphs)}") logger.error(f"原始段落数量:{len(raw_paragraphs)}")
@ -127,7 +79,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致") logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况") logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
logger.error("或者一段中只有符号的情况") logger.error("或者一段中只有符号的情况")
# 新增检查docs中每条数据的完整性
logger.error("系统将于2秒后开始检查数据完整性") logger.error("系统将于2秒后开始检查数据完整性")
sleep(2) sleep(2)
found_missing = False found_missing = False
@ -136,7 +87,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
idx = doc.get("idx", "<无idx>") idx = doc.get("idx", "<无idx>")
passage = doc.get("passage", "<无passage>") passage = doc.get("passage", "<无passage>")
missing = [] missing = []
# 检查字段是否存在且非空
if "passage" not in doc or not doc.get("passage"): if "passage" not in doc or not doc.get("passage"):
missing.append("passage") missing.append("passage")
if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list): if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
@ -147,8 +97,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
missing.append("主谓宾三元组缺失") missing.append("主谓宾三元组缺失")
elif len(doc.get("extracted_triples", [])) == 0: elif len(doc.get("extracted_triples", [])) == 0:
missing.append("主谓宾三元组为空") missing.append("主谓宾三元组为空")
# 输出所有doc的idx
# print(f"检查: idx={idx}")
if missing: if missing:
found_missing = True found_missing = True
missing_idxs.append(idx) missing_idxs.append(idx)
@ -157,33 +105,25 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
logger.error(f"对应哈希值:{idx}") logger.error(f"对应哈希值:{idx}")
logger.error(f"对应文段内容内容:{passage}") logger.error(f"对应文段内容内容:{passage}")
logger.error(f"非法原因:{', '.join(missing)}") logger.error(f"非法原因:{', '.join(missing)}")
# 确保提示在所有非法数据输出后再输出
if not found_missing: if not found_missing:
logger.info("所有数据均完整,没有发现缺失字段。") logger.info("所有数据均完整,没有发现缺失字段。")
return False return False
# 新增:提示用户是否删除非法文段继续导入
# 将print移到所有logger.error之后确保不会被冲掉
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
user_choice = input().strip().lower() user_choice = input().strip().lower()
if user_choice != "y": if user_choice != "y":
logger.info("用户选择不删除非法文段,程序终止。") logger.info("用户选择不删除非法文段,程序终止。")
sys.exit(1) sys.exit(1)
# 删除非法文段
logger.info("正在删除非法文段并继续导入...") logger.info("正在删除非法文段并继续导入...")
# 过滤掉非法文段
openie_data.docs = [ openie_data.docs = [
doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs
] ]
# 重新提取数据
raw_paragraphs = openie_data.extract_raw_paragraph_dict() raw_paragraphs = openie_data.extract_raw_paragraph_dict()
entity_list_data = openie_data.extract_entity_dict() entity_list_data = openie_data.extract_entity_dict()
triple_list_data = openie_data.extract_triple_dict() triple_list_data = openie_data.extract_triple_dict()
# 再次校验
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
logger.error("删除非法文段后,数据仍不一致,程序终止。") logger.error("删除非法文段后,数据仍不一致,程序终止。")
sys.exit(1) sys.exit(1)
# 将索引换为对应段落的hash值
logger.info("正在进行段落去重与重索引") logger.info("正在进行段落去重与重索引")
raw_paragraphs, triple_list_data = hash_deduplicate( raw_paragraphs, triple_list_data = hash_deduplicate(
raw_paragraphs, raw_paragraphs,
@ -192,28 +132,25 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
kg_manager.stored_paragraph_hashes, kg_manager.stored_paragraph_hashes,
) )
if len(raw_paragraphs) != 0: if len(raw_paragraphs) != 0:
# 获取嵌入并保存
logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}") logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}")
logger.info("开始Embedding") logger.info("开始Embedding")
embed_manager.store_new_data_set(raw_paragraphs, triple_list_data) embed_manager.store_new_data_set(raw_paragraphs, triple_list_data, lock)
# Embedding-Faiss重索引
logger.info("正在重新构建向量索引") logger.info("正在重新构建向量索引")
embed_manager.rebuild_faiss_index() embed_manager.rebuild_faiss_index()
logger.info("向量索引构建完成") logger.info("向量索引构建完成")
embed_manager.save_to_file() embed_manager.save_to_file()
logger.info("Embedding完成") logger.info("Embedding完成")
# 构建新段落的RAG
logger.info("开始构建RAG") logger.info("开始构建RAG")
kg_manager.build_kg(triple_list_data, embed_manager) kg_manager.build_kg(triple_list_data, embed_manager) # <--- 最终修正
kg_manager.save_to_file() kg_manager.save_to_file()
logger.info("RAG构建完成") logger.info("RAG构建完成")
else: else:
logger.info("无新段落需要处理") logger.info("无新段落需要处理")
return True return True
def main():
def main(): # sourcery skip: dict-comprehension manager = Manager()
# 新增确认提示 lock = manager.Lock()
env_config = {key: os.getenv(key) for key in os.environ} env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config) scan_provider(env_config)
print("=== 重要操作确认 ===") print("=== 重要操作确认 ===")
@ -229,27 +166,23 @@ def main(): # sourcery skip: dict-comprehension
print("操作已取消") print("操作已取消")
sys.exit(1) sys.exit(1)
print("\n" + "=" * 40 + "\n") print("\n" + "=" * 40 + "\n")
ensure_openie_dir() # 确保OpenIE目录存在 ensure_openie_dir()
logger.info("----开始导入openie数据----\n") logger.info("----开始导入openie数据----\n")
logger.info("创建LLM客户端") logger.info("创建LLM客户端")
# 初始化Embedding库 embed_manager = EmbeddingManager(lock)
embed_manager = EmbeddingManager()
logger.info("正在从文件加载Embedding库") logger.info("正在从文件加载Embedding库")
try: try:
embed_manager.load_from_file() embed_manager.load_from_file()
except Exception as e: except Exception as e:
logger.error(f"从文件加载Embedding库时发生错误{e}")
if "嵌入模型与本地存储不一致" in str(e):
logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
sys.exit(1)
if "不存在" in str(e): if "不存在" in str(e):
logger.error("如果你是第一次导入知识,请忽略此错误") logger.error("如果你是第一次导入知识,请忽略此错误")
else:
logger.error(f"从文件加载Embedding库时发生错误{e}")
if "嵌入模型与本地存储不一致" in str(e):
sys.exit(1)
logger.info("Embedding库加载完成") logger.info("Embedding库加载完成")
# 初始化KG
kg_manager = KGManager() kg_manager = KGManager()
logger.info("正在从文件加载KG") logger.info("正在从文件加载KG")
try: try:
@ -262,7 +195,6 @@ def main(): # sourcery skip: dict-comprehension
logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}") logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")
logger.info(f"KG边数量{len(kg_manager.graph.get_edge_list())}") logger.info(f"KG边数量{len(kg_manager.graph.get_edge_list())}")
# 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes: for pg_hash in kg_manager.stored_paragraph_hashes:
key = f"{local_storage['pg_namespace']}-{pg_hash}" key = f"{local_storage['pg_namespace']}-{pg_hash}"
if key not in embed_manager.stored_pg_hashes: if key not in embed_manager.stored_pg_hashes:
@ -274,12 +206,11 @@ def main(): # sourcery skip: dict-comprehension
except Exception as e: except Exception as e:
logger.error(f"导入OpenIE数据文件时发生错误{e}") logger.error(f"导入OpenIE数据文件时发生错误{e}")
return False return False
if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
if handle_import_openie(openie_data, embed_manager, kg_manager, lock) is False:
logger.error("处理OpenIE数据时发生错误") logger.error("处理OpenIE数据时发生错误")
return False return False
return None return None
if __name__ == "__main__": if __name__ == "__main__":
# logger.info(f"111111111111111111111111{ROOT_PATH}") main()
main()

View File

@ -1,3 +1,4 @@
# --- START OF FILE src/chat/knowledge/embedding_store.py ---
from dataclasses import dataclass from dataclasses import dataclass
import json import json
import os import os
@ -7,12 +8,7 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
# import tqdm
import faiss import faiss
# from .llm_client import LLMClient
# from .lpmmconfig import global_config
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .global_logger import logger from .global_logger import logger
from rich.traceback import install from rich.traceback import install
@ -30,15 +26,12 @@ from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.config.config import global_config from src.config.config import global_config
install(extra_lines=3) install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 TOTAL_EMBEDDING_TIMES = 3
# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
# 这些字符串的嵌入结果应该是固定的,不能随时间变化
EMBEDDING_TEST_STRINGS = [ EMBEDDING_TEST_STRINGS = [
"阿卡伊真的太好玩了,神秘性感大女同等着你", "阿卡伊真的太好玩了,神秘性感大女同等着你",
"你怎么知道我arc12.64了", "你怎么知道我arc12.64了",
@ -57,9 +50,7 @@ EMBEDDING_TEST_STRINGS = [
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json") EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
EMBEDDING_SIM_THRESHOLD = 0.99 EMBEDDING_SIM_THRESHOLD = 0.99
def cosine_similarity(a, b): def cosine_similarity(a, b):
# 计算余弦相似度
dot = sum(x * y for x, y in zip(a, b, strict=False)) dot = sum(x * y for x, y in zip(a, b, strict=False))
norm_a = math.sqrt(sum(x * x for x in a)) norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b)) norm_b = math.sqrt(sum(x * x for x in b))
@ -67,69 +58,57 @@ def cosine_similarity(a, b):
return 0.0 return 0.0
return dot / (norm_a * norm_b) return dot / (norm_a * norm_b)
@dataclass @dataclass
class EmbeddingStoreItem: class EmbeddingStoreItem:
"""嵌入库中的项"""
def __init__(self, item_hash: str, embedding: List[float], content: str): def __init__(self, item_hash: str, embedding: List[float], content: str):
self.hash = item_hash self.hash = item_hash
self.embedding = embedding self.embedding = embedding
self.str = content self.str = content
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""转为dict"""
return { return {
"hash": self.hash, "hash": self.hash,
"embedding": self.embedding, "embedding": self.embedding,
"str": self.str, "str": self.str,
} }
class EmbeddingStore: class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str): def __init__(self, namespace: str, dir_path: str, lock):
self.namespace = namespace self.namespace = namespace
self.dir = dir_path self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index" self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
self.store = {} self.store = {}
self.faiss_index = None self.faiss_index = None
self.idx2hash = None self.idx2hash = None
self.lock = lock
def _get_embedding(self, s: str) -> List[float]: def _get_embedding(self, s: str) -> List[float]:
"""获取字符串的嵌入向量,处理异步调用""" with self.lock:
try: try:
# 尝试获取当前事件循环 asyncio.get_running_loop()
asyncio.get_running_loop() import concurrent.futures
# 如果在事件循环中,使用线程池执行 def run_in_thread():
import concurrent.futures return asyncio.run(get_embedding(s))
with concurrent.futures.ThreadPoolExecutor() as executor:
def run_in_thread(): future = executor.submit(run_in_thread)
return asyncio.run(get_embedding(s)) result = future.result()
if result is None:
with concurrent.futures.ThreadPoolExecutor() as executor: logger.error(f"获取嵌入失败: {s}")
future = executor.submit(run_in_thread) return []
result = future.result() return result
except RuntimeError:
result = asyncio.run(get_embedding(s))
if result is None: if result is None:
logger.error(f"获取嵌入失败: {s}") logger.error(f"获取嵌入失败: {s}")
return [] return []
return result return result
except RuntimeError:
# 没有运行的事件循环,直接运行
result = asyncio.run(get_embedding(s))
if result is None:
logger.error(f"获取嵌入失败: {s}")
return []
return result
def get_test_file_path(self): def get_test_file_path(self):
return EMBEDDING_TEST_FILE return EMBEDDING_TEST_FILE
def save_embedding_test_vectors(self): def save_embedding_test_vectors(self):
"""保存测试字符串的嵌入到本地"""
test_vectors = {} test_vectors = {}
for idx, s in enumerate(EMBEDDING_TEST_STRINGS): for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
test_vectors[str(idx)] = self._get_embedding(s) test_vectors[str(idx)] = self._get_embedding(s)
@ -137,7 +116,6 @@ class EmbeddingStore:
json.dump(test_vectors, f, ensure_ascii=False, indent=2) json.dump(test_vectors, f, ensure_ascii=False, indent=2)
def load_embedding_test_vectors(self): def load_embedding_test_vectors(self):
"""加载本地保存的测试字符串嵌入"""
path = self.get_test_file_path() path = self.get_test_file_path()
if not os.path.exists(path): if not os.path.exists(path):
return None return None
@ -145,7 +123,6 @@ class EmbeddingStore:
return json.load(f) return json.load(f)
def check_embedding_model_consistency(self): def check_embedding_model_consistency(self):
"""校验当前模型与本地嵌入模型是否一致"""
local_vectors = self.load_embedding_test_vectors() local_vectors = self.load_embedding_test_vectors()
if local_vectors is None: if local_vectors is None:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
@ -166,51 +143,30 @@ class EmbeddingStore:
return True return True
def batch_insert_strs(self, strs: List[str], times: int) -> None: def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""向库中存入字符串"""
total = len(strs) total = len(strs)
with Progress( with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "", TimeElapsedColumn(), "<", TimeRemainingColumn(), transient=False) as progress:
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
for s in strs: for s in strs:
# 计算hash去重
item_hash = self.namespace + "-" + get_sha256(s) item_hash = self.namespace + "-" + get_sha256(s)
if item_hash in self.store: if item_hash in self.store:
progress.update(task, advance=1) progress.update(task, advance=1)
continue continue
# 获取embedding
embedding = self._get_embedding(s) embedding = self._get_embedding(s)
# 存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
progress.update(task, advance=1) progress.update(task, advance=1)
def save_to_file(self) -> None: def save_to_file(self) -> None:
"""保存到文件"""
data = [] data = []
logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}") logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
for item in self.store.values(): for item in self.store.values():
data.append(item.to_dict()) data.append(item.to_dict())
data_frame = pd.DataFrame(data) data_frame = pd.DataFrame(data)
if not os.path.exists(self.dir): if not os.path.exists(self.dir):
os.makedirs(self.dir, exist_ok=True) os.makedirs(self.dir, exist_ok=True)
if not os.path.exists(self.embedding_file_path): if not os.path.exists(self.embedding_file_path):
open(self.embedding_file_path, "w").close() open(self.embedding_file_path, "w").close()
data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False) data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False)
logger.info(f"{self.namespace}嵌入库保存成功") logger.info(f"{self.namespace}嵌入库保存成功")
if self.faiss_index is not None and self.idx2hash is not None: if self.faiss_index is not None and self.idx2hash is not None:
logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}") logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
faiss.write_index(self.faiss_index, self.index_file_path) faiss.write_index(self.faiss_index, self.index_file_path)
@ -221,31 +177,18 @@ class EmbeddingStore:
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功") logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
def load_from_file(self) -> None: def load_from_file(self) -> None:
"""从文件中加载"""
if not os.path.exists(self.embedding_file_path): if not os.path.exists(self.embedding_file_path):
raise Exception(f"文件{self.embedding_file_path}不存在") raise Exception(f"文件{self.embedding_file_path}不存在")
logger.info("正在加载嵌入库...") logger.info("正在加载嵌入库...")
logger.debug(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") logger.debug(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
total = len(data_frame) total = len(data_frame)
with Progress( with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "", TimeElapsedColumn(), "<", TimeRemainingColumn(), transient=False) as progress:
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("加载嵌入库", total=total) task = progress.add_task("加载嵌入库", total=total)
for _, row in data_frame.iterrows(): for _, row in data_frame.iterrows():
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
progress.update(task, advance=1) progress.update(task, advance=1)
logger.info(f"{self.namespace}嵌入库加载成功") logger.info(f"{self.namespace}嵌入库加载成功")
try: try:
if os.path.exists(self.index_file_path): if os.path.exists(self.index_file_path):
logger.info(f"正在加载{self.namespace}嵌入库的FaissIndex...") logger.info(f"正在加载{self.namespace}嵌入库的FaissIndex...")
@ -270,40 +213,25 @@ class EmbeddingStore:
self.save_to_file() self.save_to_file()
def build_faiss_index(self) -> None: def build_faiss_index(self) -> None:
"""重新构建Faiss索引以余弦相似度为度量"""
# 获取所有的embedding
array = [] array = []
self.idx2hash = dict() self.idx2hash = dict()
for key in self.store: for key in self.store:
array.append(self.store[key].embedding) array.append(self.store[key].embedding)
self.idx2hash[str(len(array) - 1)] = key self.idx2hash[str(len(array) - 1)] = key
embeddings = np.array(array, dtype=np.float32) embeddings = np.array(array, dtype=np.float32)
# L2归一化
faiss.normalize_L2(embeddings) faiss.normalize_L2(embeddings)
# 构建索引
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension) self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
self.faiss_index.add(embeddings) self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]: def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
"""搜索最相似的k个项以余弦相似度为度量
Args:
query: 查询的embedding
k: 返回的最相似的k个项
Returns:
result: 最相似的k个项的(hash, 余弦相似度)列表
"""
if self.faiss_index is None: if self.faiss_index is None:
logger.debug("FaissIndex尚未构建,返回None") logger.debug("FaissIndex尚未构建,返回None")
return [] return []
if self.idx2hash is None: if self.idx2hash is None:
logger.warning("idx2hash尚未构建,返回None") logger.warning("idx2hash尚未构建,返回None")
return [] return []
# L2归一化
faiss.normalize_L2(np.array([query], dtype=np.float32)) faiss.normalize_L2(np.array([query], dtype=np.float32))
# 搜索
distances, indices = self.faiss_index.search(np.array([query]), k) distances, indices = self.faiss_index.search(np.array([query]), k)
# 整理结果
indices = list(indices.flatten()) indices = list(indices.flatten())
distances = list(distances.flatten()) distances = list(distances.flatten())
result = [ result = [
@ -311,43 +239,38 @@ class EmbeddingStore:
for (idx, sim) in zip(indices, distances, strict=False) for (idx, sim) in zip(indices, distances, strict=False)
if idx in range(len(self.idx2hash)) if idx in range(len(self.idx2hash))
] ]
return result return result
class EmbeddingManager: class EmbeddingManager:
def __init__(self): def __init__(self, lock):
self.lock = lock
self.paragraphs_embedding_store = EmbeddingStore( self.paragraphs_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore local_storage["pg_namespace"],
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
self.lock,
) )
self.entities_embedding_store = EmbeddingStore( self.entities_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore local_storage["ent_namespace"],
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
self.lock,
) )
self.relation_embedding_store = EmbeddingStore( self.relation_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore local_storage["rel_namespace"],
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
self.lock,
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self): def check_all_embedding_model_consistency(self):
"""对所有嵌入库做模型一致性校验""" for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]:
for store in [
self.paragraphs_embedding_store,
self.entities_embedding_store,
self.relation_embedding_store,
]:
if not store.check_embedding_model_consistency(): if not store.check_embedding_model_consistency():
return False return False
return True return True
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将实体编码存入Embedding库"""
entities = set() entities = set()
for triple_list in triple_list_data.values(): for triple_list in triple_list_data.values():
for triple in triple_list: for triple in triple_list:
@ -356,42 +279,44 @@ class EmbeddingManager:
self.entities_embedding_store.batch_insert_strs(list(entities), times=2) self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将关系编码存入Embedding库""" graph_triples = []
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
for triples in triple_list_data.values(): for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples]) graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples)) graph_triples = list(set(graph_triples))
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3) self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
def load_from_file(self): def load_from_file(self):
"""从文件加载""" try:
self.paragraphs_embedding_store.load_from_file() self.paragraphs_embedding_store.load_from_file()
self.entities_embedding_store.load_from_file() except Exception: pass
self.relation_embedding_store.load_from_file() try:
# 从段落库中获取已存储的hash self.entities_embedding_store.load_from_file()
except Exception: pass
try:
self.relation_embedding_store.load_from_file()
except Exception: pass
self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys()) self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys())
def store_new_data_set( def store_new_data_set(
self, self,
raw_paragraphs: Dict[str, str], raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]], triple_list_data: Dict[str, List[List[str]]],
lock
): ):
if not self.check_all_embedding_model_consistency(): if not self.check_all_embedding_model_consistency():
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
"""存储新的数据集"""
self._store_pg_into_embedding(raw_paragraphs) self._store_pg_into_embedding(raw_paragraphs)
self._store_ent_into_embedding(triple_list_data) self._store_ent_into_embedding(triple_list_data)
self._store_rel_into_embedding(triple_list_data) self._store_rel_into_embedding(triple_list_data)
self.stored_pg_hashes.update(raw_paragraphs.keys()) self.stored_pg_hashes.update(raw_paragraphs.keys())
def save_to_file(self): def save_to_file(self):
"""保存到文件"""
self.paragraphs_embedding_store.save_to_file() self.paragraphs_embedding_store.save_to_file()
self.entities_embedding_store.save_to_file() self.entities_embedding_store.save_to_file()
self.relation_embedding_store.save_to_file() self.relation_embedding_store.save_to_file()
def rebuild_faiss_index(self): def rebuild_faiss_index(self):
"""重建Faiss索引请在添加新数据后调用"""
self.paragraphs_embedding_store.build_faiss_index() self.paragraphs_embedding_store.build_faiss_index()
self.entities_embedding_store.build_faiss_index() self.entities_embedding_store.build_faiss_index()
self.relation_embedding_store.build_faiss_index() self.relation_embedding_store.build_faiss_index()
# --- END OF FILE src/chat/knowledge/embedding_store.py ---

View File

@ -1,3 +1,4 @@
# --- START OF FINAL src/chat/knowledge/knowledge_lib.py ---
from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.llm_client import LLMClient
@ -8,6 +9,7 @@ from src.chat.knowledge.global_logger import logger
from src.config.config import global_config as bot_global_config from src.config.config import global_config as bot_global_config
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
import os import os
from multiprocessing import Manager # <--- MODIFIED: 导入Manager
INVALID_ENTITY = [ INVALID_ENTITY = [
"", "",
@ -96,8 +98,13 @@ if bot_global_config.lpmm_knowledge.enable:
global_config["llm_providers"][key]["api_key"], # type: ignore global_config["llm_providers"][key]["api_key"], # type: ignore
) )
# <--- MODIFIED: 创建Manager和Lock ---
manager = Manager()
lock = manager.Lock()
# --->
# 初始化Embedding库 # 初始化Embedding库
embed_manager = EmbeddingManager() embed_manager = EmbeddingManager(lock) # <--- MODIFIED: 传递lock
logger.info("正在从文件加载Embedding库") logger.info("正在从文件加载Embedding库")
try: try:
embed_manager.load_from_file() embed_manager.load_from_file()
@ -138,3 +145,4 @@ if bot_global_config.lpmm_knowledge.enable:
else: else:
logger.info("LPMM知识库已禁用跳过初始化") logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误 # 创建空的占位符对象,避免导入错误
# --- END OF FINAL src/chat/knowledge/knowledge_lib.py ---