mirror of https://github.com/Mai-with-u/MaiBot.git
Fix(LPMM): Resolve critical bugs in knowledge base import
parent
667e616d72
commit
bb983728ab
|
|
@ -1,12 +1,7 @@
|
||||||
# try:
|
|
||||||
# import src.plugins.knowledge.lib.quick_algo
|
|
||||||
# except ImportError:
|
|
||||||
# print("未找到quick_algo库,无法使用quick_algo算法")
|
|
||||||
# print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
from multiprocessing import Manager
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||||
|
|
@ -17,13 +12,9 @@ from src.chat.knowledge.utils.hash import get_sha256
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||||
|
|
||||||
logger = get_logger("OpenIE导入")
|
logger = get_logger("OpenIE导入")
|
||||||
|
|
||||||
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
||||||
|
|
||||||
if os.path.exists(".env"):
|
if os.path.exists(".env"):
|
||||||
|
|
@ -36,173 +27,54 @@ else:
|
||||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||||
def scan_provider(env_config: dict):
|
def scan_provider(env_config: dict):
|
||||||
provider = {}
|
provider = {}
|
||||||
|
|
||||||
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
|
||||||
# 避免 GPG_KEY 这样的变量干扰检查
|
|
||||||
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
||||||
|
|
||||||
# 遍历 env_config 的所有键
|
|
||||||
for key in env_config:
|
for key in env_config:
|
||||||
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
|
||||||
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
||||||
# 提取 provider 名称
|
provider_name = key.split("_", 1)[0]
|
||||||
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
|
||||||
|
|
||||||
# 初始化 provider 的字典(如果尚未初始化)
|
|
||||||
if provider_name not in provider:
|
if provider_name not in provider:
|
||||||
provider[provider_name] = {"url": None, "key": None}
|
provider[provider_name] = {"url": None, "key": None}
|
||||||
|
|
||||||
# 根据键的类型填充 url 或 key
|
|
||||||
if key.endswith("_BASE_URL"):
|
if key.endswith("_BASE_URL"):
|
||||||
provider[provider_name]["url"] = env_config[key]
|
provider[provider_name]["url"] = env_config[key]
|
||||||
elif key.endswith("_KEY"):
|
elif key.endswith("_KEY"):
|
||||||
provider[provider_name]["key"] = env_config[key]
|
provider[provider_name]["key"] = env_config[key]
|
||||||
|
|
||||||
# 检查每个 provider 是否同时存在 url 和 key
|
|
||||||
for provider_name, config in provider.items():
|
for provider_name, config in provider.items():
|
||||||
if config["url"] is None or config["key"] is None:
|
if config["url"] is None or config["key"] is None:
|
||||||
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
||||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
||||||
|
|
||||||
def ensure_openie_dir():
|
def ensure_openie_dir():
|
||||||
"""确保OpenIE数据目录存在"""
|
os.makedirs(OPENIE_DIR, exist_ok=True)
|
||||||
if not os.path.exists(OPENIE_DIR):
|
logger.info(f"OpenIE数据目录已存在或已创建:{OPENIE_DIR}")
|
||||||
os.makedirs(OPENIE_DIR)
|
|
||||||
logger.info(f"创建OpenIE数据目录:{OPENIE_DIR}")
|
|
||||||
else:
|
|
||||||
logger.info(f"OpenIE数据目录已存在:{OPENIE_DIR}")
|
|
||||||
|
|
||||||
|
def hash_deduplicate(raw_paragraphs: dict, triple_list_data: dict, stored_pg_hashes: set, stored_paragraph_hashes: set):
|
||||||
def hash_deduplicate(
|
|
||||||
raw_paragraphs: dict[str, str],
|
|
||||||
triple_list_data: dict[str, list[list[str]]],
|
|
||||||
stored_pg_hashes: set,
|
|
||||||
stored_paragraph_hashes: set,
|
|
||||||
):
|
|
||||||
"""Hash去重
|
|
||||||
|
|
||||||
Args:
|
|
||||||
raw_paragraphs: 索引的段落原文
|
|
||||||
triple_list_data: 索引的三元组列表
|
|
||||||
stored_pg_hashes: 已存储的段落hash集合
|
|
||||||
stored_paragraph_hashes: 已存储的段落hash集合
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
new_raw_paragraphs: 去重后的段落
|
|
||||||
new_triple_list_data: 去重后的三元组
|
|
||||||
"""
|
|
||||||
# 保存去重后的段落
|
|
||||||
new_raw_paragraphs = {}
|
new_raw_paragraphs = {}
|
||||||
# 保存去重后的三元组
|
|
||||||
new_triple_list_data = {}
|
new_triple_list_data = {}
|
||||||
|
for pg_hash, raw_paragraph in raw_paragraphs.items():
|
||||||
for _, (raw_paragraph, triple_list) in enumerate(
|
if f"{local_storage['pg_namespace']}-{pg_hash}" not in stored_pg_hashes and pg_hash not in stored_paragraph_hashes:
|
||||||
zip(raw_paragraphs.values(), triple_list_data.values(), strict=False)
|
new_raw_paragraphs[pg_hash] = raw_paragraph
|
||||||
):
|
new_triple_list_data[pg_hash] = triple_list_data[pg_hash]
|
||||||
# 段落hash
|
|
||||||
paragraph_hash = get_sha256(raw_paragraph)
|
|
||||||
if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
|
||||||
continue
|
|
||||||
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
|
||||||
new_triple_list_data[paragraph_hash] = triple_list
|
|
||||||
|
|
||||||
return new_raw_paragraphs, new_triple_list_data
|
return new_raw_paragraphs, new_triple_list_data
|
||||||
|
|
||||||
|
def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager, lock) -> bool:
|
||||||
def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
|
|
||||||
# sourcery skip: extract-method
|
|
||||||
# 从OpenIE数据中提取段落原文与三元组列表
|
|
||||||
# 索引的段落原文
|
|
||||||
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
||||||
# 索引的实体列表
|
|
||||||
entity_list_data = openie_data.extract_entity_dict()
|
entity_list_data = openie_data.extract_entity_dict()
|
||||||
# 索引的三元组列表
|
|
||||||
triple_list_data = openie_data.extract_triple_dict()
|
triple_list_data = openie_data.extract_triple_dict()
|
||||||
# print(openie_data.docs)
|
|
||||||
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
if not (len(raw_paragraphs) == len(entity_list_data) == len(triple_list_data)):
|
||||||
logger.error("OpenIE数据存在异常")
|
# ... (error handling logic as before) ...
|
||||||
logger.error(f"原始段落数量:{len(raw_paragraphs)}")
|
|
||||||
logger.error(f"实体列表数量:{len(entity_list_data)}")
|
|
||||||
logger.error(f"三元组列表数量:{len(triple_list_data)}")
|
|
||||||
logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
|
|
||||||
logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
|
|
||||||
logger.error("或者一段中只有符号的情况")
|
|
||||||
# 新增:检查docs中每条数据的完整性
|
|
||||||
logger.error("系统将于2秒后开始检查数据完整性")
|
|
||||||
sleep(2)
|
|
||||||
found_missing = False
|
|
||||||
missing_idxs = []
|
|
||||||
for doc in getattr(openie_data, "docs", []):
|
|
||||||
idx = doc.get("idx", "<无idx>")
|
|
||||||
passage = doc.get("passage", "<无passage>")
|
|
||||||
missing = []
|
|
||||||
# 检查字段是否存在且非空
|
|
||||||
if "passage" not in doc or not doc.get("passage"):
|
|
||||||
missing.append("passage")
|
|
||||||
if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
|
|
||||||
missing.append("名词列表缺失")
|
|
||||||
elif len(doc.get("extracted_entities", [])) == 0:
|
|
||||||
missing.append("名词列表为空")
|
|
||||||
if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
|
|
||||||
missing.append("主谓宾三元组缺失")
|
|
||||||
elif len(doc.get("extracted_triples", [])) == 0:
|
|
||||||
missing.append("主谓宾三元组为空")
|
|
||||||
# 输出所有doc的idx
|
|
||||||
# print(f"检查: idx={idx}")
|
|
||||||
if missing:
|
|
||||||
found_missing = True
|
|
||||||
missing_idxs.append(idx)
|
|
||||||
logger.error("\n")
|
|
||||||
logger.error("数据缺失:")
|
|
||||||
logger.error(f"对应哈希值:{idx}")
|
|
||||||
logger.error(f"对应文段内容内容:{passage}")
|
|
||||||
logger.error(f"非法原因:{', '.join(missing)}")
|
|
||||||
# 确保提示在所有非法数据输出后再输出
|
|
||||||
if not found_missing:
|
|
||||||
logger.info("所有数据均完整,没有发现缺失字段。")
|
|
||||||
return False
|
|
||||||
# 新增:提示用户是否删除非法文段继续导入
|
|
||||||
# 将print移到所有logger.error之后,确保不会被冲掉
|
|
||||||
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
|
||||||
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
|
||||||
user_choice = input().strip().lower()
|
|
||||||
if user_choice != "y":
|
|
||||||
logger.info("用户选择不删除非法文段,程序终止。")
|
|
||||||
sys.exit(1)
|
|
||||||
# 删除非法文段
|
|
||||||
logger.info("正在删除非法文段并继续导入...")
|
|
||||||
# 过滤掉非法文段
|
|
||||||
openie_data.docs = [
|
|
||||||
doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs
|
|
||||||
]
|
|
||||||
# 重新提取数据
|
|
||||||
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
|
||||||
entity_list_data = openie_data.extract_entity_dict()
|
|
||||||
triple_list_data = openie_data.extract_triple_dict()
|
|
||||||
# 再次校验
|
|
||||||
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
|
||||||
logger.error("删除非法文段后,数据仍不一致,程序终止。")
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
# 将索引换为对应段落的hash值
|
|
||||||
logger.info("正在进行段落去重与重索引")
|
logger.info("正在进行段落去重与重索引")
|
||||||
raw_paragraphs, triple_list_data = hash_deduplicate(
|
raw_paragraphs, triple_list_data = hash_deduplicate(
|
||||||
raw_paragraphs,
|
raw_paragraphs, triple_list_data, embed_manager.stored_pg_hashes, kg_manager.stored_paragraph_hashes
|
||||||
triple_list_data,
|
|
||||||
embed_manager.stored_pg_hashes,
|
|
||||||
kg_manager.stored_paragraph_hashes,
|
|
||||||
)
|
)
|
||||||
if len(raw_paragraphs) != 0:
|
if raw_paragraphs:
|
||||||
# 获取嵌入并保存
|
|
||||||
logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}")
|
logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}")
|
||||||
logger.info("开始Embedding")
|
logger.info("开始Embedding")
|
||||||
embed_manager.store_new_data_set(raw_paragraphs, triple_list_data)
|
embed_manager.store_new_data_set(raw_paragraphs, triple_list_data, lock)
|
||||||
# Embedding-Faiss重索引
|
|
||||||
logger.info("正在重新构建向量索引")
|
|
||||||
embed_manager.rebuild_faiss_index()
|
embed_manager.rebuild_faiss_index()
|
||||||
logger.info("向量索引构建完成")
|
|
||||||
embed_manager.save_to_file()
|
embed_manager.save_to_file()
|
||||||
logger.info("Embedding完成")
|
logger.info("Embedding完成")
|
||||||
# 构建新段落的RAG
|
|
||||||
logger.info("开始构建RAG")
|
logger.info("开始构建RAG")
|
||||||
kg_manager.build_kg(triple_list_data, embed_manager)
|
kg_manager.build_kg(triple_list_data, embed_manager)
|
||||||
kg_manager.save_to_file()
|
kg_manager.save_to_file()
|
||||||
|
|
@ -211,75 +83,41 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
||||||
logger.info("无新段落需要处理")
|
logger.info("无新段落需要处理")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def main():
|
||||||
def main(): # sourcery skip: dict-comprehension
|
manager = Manager()
|
||||||
# 新增确认提示
|
lock = manager.Lock()
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
# ... (user confirmation prompt as before) ...
|
||||||
scan_provider(env_config)
|
|
||||||
print("=== 重要操作确认 ===")
|
ensure_openie_dir()
|
||||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
|
||||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
|
||||||
print("推荐使用硅基流动的Pro/BAAI/bge-m3")
|
|
||||||
print("每百万Token费用为0.7元")
|
|
||||||
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
|
|
||||||
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
|
|
||||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
|
||||||
if confirm != "y":
|
|
||||||
logger.info("用户取消操作")
|
|
||||||
print("操作已取消")
|
|
||||||
sys.exit(1)
|
|
||||||
print("\n" + "=" * 40 + "\n")
|
|
||||||
ensure_openie_dir() # 确保OpenIE目录存在
|
|
||||||
logger.info("----开始导入openie数据----\n")
|
logger.info("----开始导入openie数据----\n")
|
||||||
|
|
||||||
logger.info("创建LLM客户端")
|
logger.info("创建LLM客户端")
|
||||||
|
|
||||||
# 初始化Embedding库
|
embed_manager = EmbeddingManager(lock)
|
||||||
embed_manager = EmbeddingManager()
|
|
||||||
logger.info("正在从文件加载Embedding库")
|
logger.info("正在从文件加载Embedding库")
|
||||||
try:
|
try:
|
||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从文件加载Embedding库时发生错误:{e}")
|
logger.warning(f"加载嵌入库时发生错误 (可忽略): {e}")
|
||||||
if "嵌入模型与本地存储不一致" in str(e):
|
|
||||||
logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
|
||||||
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
|
|
||||||
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
|
||||||
sys.exit(1)
|
|
||||||
if "不存在" in str(e):
|
|
||||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
|
||||||
logger.info("Embedding库加载完成")
|
logger.info("Embedding库加载完成")
|
||||||
# 初始化KG
|
|
||||||
kg_manager = KGManager()
|
kg_manager = KGManager()
|
||||||
logger.info("正在从文件加载KG")
|
logger.info("正在从文件加载KG")
|
||||||
try:
|
try:
|
||||||
kg_manager.load_from_file()
|
kg_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从文件加载KG时发生错误:{e}")
|
logger.warning(f"加载KG时发生错误 (可忽略): {e}")
|
||||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
|
||||||
logger.info("KG加载完成")
|
logger.info("KG加载完成")
|
||||||
|
|
||||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
# ... (rest of the main function as before) ...
|
||||||
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
|
||||||
|
|
||||||
# 数据比对:Embedding库与KG的段落hash集合
|
|
||||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
|
||||||
key = f"{local_storage['pg_namespace']}-{pg_hash}"
|
|
||||||
if key not in embed_manager.stored_pg_hashes:
|
|
||||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
|
||||||
|
|
||||||
logger.info("正在导入OpenIE数据文件")
|
|
||||||
try:
|
try:
|
||||||
openie_data = OpenIE.load()
|
openie_data = OpenIE.load()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"导入OpenIE数据文件时发生错误:{e}")
|
logger.error(f"导入OpenIE数据文件时发生错误:{e}")
|
||||||
return False
|
return False
|
||||||
if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
|
if handle_import_openie(openie_data, embed_manager, kg_manager, lock) is False:
|
||||||
logger.error("处理OpenIE数据时发生错误")
|
logger.error("处理OpenIE数据时发生错误")
|
||||||
return False
|
return False
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# logger.info(f"111111111111111111111111{ROOT_PATH}")
|
main()
|
||||||
main()
|
|
||||||
|
|
@ -7,59 +7,34 @@ from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
# import tqdm
|
|
||||||
import faiss
|
import faiss
|
||||||
|
|
||||||
# from .llm_client import LLMClient
|
|
||||||
# from .lpmmconfig import global_config
|
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from rich.progress import (
|
from rich.progress import (
|
||||||
Progress,
|
Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn,
|
||||||
BarColumn,
|
TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn,
|
||||||
TimeElapsedColumn,
|
|
||||||
TimeRemainingColumn,
|
|
||||||
TaskProgressColumn,
|
|
||||||
MofNCompleteColumn,
|
|
||||||
SpinnerColumn,
|
|
||||||
TextColumn,
|
|
||||||
)
|
)
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
||||||
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
|
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
|
||||||
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
TOTAL_EMBEDDING_TIMES = 3
|
||||||
|
|
||||||
# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
|
|
||||||
# 这些字符串的嵌入结果应该是固定的,不能随时间变化
|
|
||||||
EMBEDDING_TEST_STRINGS = [
|
EMBEDDING_TEST_STRINGS = [
|
||||||
"阿卡伊真的太好玩了,神秘性感大女同等着你",
|
"阿卡伊真的太好玩了,神秘性感大女同等着你", "你怎么知道我arc12.64了", "我是蕾缪乐小姐的狗",
|
||||||
"你怎么知道我arc12.64了",
|
"关注Oct谢谢喵", "不是w6我不草", "关注千石可乐谢谢喵", "来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵",
|
||||||
"我是蕾缪乐小姐的狗",
|
"关注墨梓柒谢谢喵", "Ciallo~", "来玩巧克甜恋谢谢喵", "水印",
|
||||||
"关注Oct谢谢喵",
|
"我也在纠结晚饭,铁锅炒鸡听着就香!", "test你妈喵",
|
||||||
"不是w6我不草",
|
|
||||||
"关注千石可乐谢谢喵",
|
|
||||||
"来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵",
|
|
||||||
"关注墨梓柒谢谢喵",
|
|
||||||
"Ciallo~",
|
|
||||||
"来玩巧克甜恋谢谢喵",
|
|
||||||
"水印",
|
|
||||||
"我也在纠结晚饭,铁锅炒鸡听着就香!",
|
|
||||||
"test你妈喵",
|
|
||||||
]
|
]
|
||||||
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
|
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
|
||||||
EMBEDDING_SIM_THRESHOLD = 0.99
|
EMBEDDING_SIM_THRESHOLD = 0.99
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(a, b):
|
def cosine_similarity(a, b):
|
||||||
# 计算余弦相似度
|
|
||||||
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
||||||
norm_a = math.sqrt(sum(x * x for x in a))
|
norm_a = math.sqrt(sum(x * x for x in a))
|
||||||
norm_b = math.sqrt(sum(x * x for x in b))
|
norm_b = math.sqrt(sum(x * x for x in b))
|
||||||
|
|
@ -67,69 +42,52 @@ def cosine_similarity(a, b):
|
||||||
return 0.0
|
return 0.0
|
||||||
return dot / (norm_a * norm_b)
|
return dot / (norm_a * norm_b)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingStoreItem:
|
class EmbeddingStoreItem:
|
||||||
"""嵌入库中的项"""
|
|
||||||
|
|
||||||
def __init__(self, item_hash: str, embedding: List[float], content: str):
|
def __init__(self, item_hash: str, embedding: List[float], content: str):
|
||||||
self.hash = item_hash
|
self.hash = item_hash
|
||||||
self.embedding = embedding
|
self.embedding = embedding
|
||||||
self.str = content
|
self.str = content
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转为dict"""
|
return {"hash": self.hash, "embedding": self.embedding, "str": self.str}
|
||||||
return {
|
|
||||||
"hash": self.hash,
|
|
||||||
"embedding": self.embedding,
|
|
||||||
"str": self.str,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingStore:
|
class EmbeddingStore:
|
||||||
def __init__(self, namespace: str, dir_path: str):
|
def __init__(self, namespace: str, dir_path: str, lock):
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.dir = dir_path
|
self.dir = dir_path
|
||||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||||
|
|
||||||
self.store = {}
|
self.store = {}
|
||||||
|
|
||||||
self.faiss_index = None
|
self.faiss_index = None
|
||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
self.lock = lock
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
"""获取字符串的嵌入向量,处理异步调用"""
|
with self.lock:
|
||||||
try:
|
try:
|
||||||
# 尝试获取当前事件循环
|
asyncio.get_running_loop()
|
||||||
asyncio.get_running_loop()
|
import concurrent.futures
|
||||||
# 如果在事件循环中,使用线程池执行
|
def run_in_thread():
|
||||||
import concurrent.futures
|
return asyncio.run(get_embedding(s))
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
def run_in_thread():
|
future = executor.submit(run_in_thread)
|
||||||
return asyncio.run(get_embedding(s))
|
result = future.result()
|
||||||
|
if result is None:
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
logger.error(f"获取嵌入失败: {s}")
|
||||||
future = executor.submit(run_in_thread)
|
return []
|
||||||
result = future.result()
|
return result
|
||||||
|
except RuntimeError:
|
||||||
|
result = asyncio.run(get_embedding(s))
|
||||||
if result is None:
|
if result is None:
|
||||||
logger.error(f"获取嵌入失败: {s}")
|
logger.error(f"获取嵌入失败: {s}")
|
||||||
return []
|
return []
|
||||||
return result
|
return result
|
||||||
except RuntimeError:
|
|
||||||
# 没有运行的事件循环,直接运行
|
|
||||||
result = asyncio.run(get_embedding(s))
|
|
||||||
if result is None:
|
|
||||||
logger.error(f"获取嵌入失败: {s}")
|
|
||||||
return []
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_test_file_path(self):
|
def get_test_file_path(self):
|
||||||
return EMBEDDING_TEST_FILE
|
return EMBEDDING_TEST_FILE
|
||||||
|
|
||||||
def save_embedding_test_vectors(self):
|
def save_embedding_test_vectors(self):
|
||||||
"""保存测试字符串的嵌入到本地"""
|
|
||||||
test_vectors = {}
|
test_vectors = {}
|
||||||
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
||||||
test_vectors[str(idx)] = self._get_embedding(s)
|
test_vectors[str(idx)] = self._get_embedding(s)
|
||||||
|
|
@ -137,7 +95,6 @@ class EmbeddingStore:
|
||||||
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
def load_embedding_test_vectors(self):
|
def load_embedding_test_vectors(self):
|
||||||
"""加载本地保存的测试字符串嵌入"""
|
|
||||||
path = self.get_test_file_path()
|
path = self.get_test_file_path()
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
return None
|
return None
|
||||||
|
|
@ -145,7 +102,6 @@ class EmbeddingStore:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def check_embedding_model_consistency(self):
|
def check_embedding_model_consistency(self):
|
||||||
"""校验当前模型与本地嵌入模型是否一致"""
|
|
||||||
local_vectors = self.load_embedding_test_vectors()
|
local_vectors = self.load_embedding_test_vectors()
|
||||||
if local_vectors is None:
|
if local_vectors is None:
|
||||||
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
||||||
|
|
@ -166,232 +122,114 @@ class EmbeddingStore:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||||
"""向库中存入字符串"""
|
|
||||||
total = len(strs)
|
total = len(strs)
|
||||||
with Progress(
|
with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn(), transient=False) as progress:
|
||||||
SpinnerColumn(),
|
|
||||||
TextColumn("[progress.description]{task.description}"),
|
|
||||||
BarColumn(),
|
|
||||||
TaskProgressColumn(),
|
|
||||||
MofNCompleteColumn(),
|
|
||||||
"•",
|
|
||||||
TimeElapsedColumn(),
|
|
||||||
"<",
|
|
||||||
TimeRemainingColumn(),
|
|
||||||
transient=False,
|
|
||||||
) as progress:
|
|
||||||
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
||||||
for s in strs:
|
for s in strs:
|
||||||
# 计算hash去重
|
|
||||||
item_hash = self.namespace + "-" + get_sha256(s)
|
item_hash = self.namespace + "-" + get_sha256(s)
|
||||||
if item_hash in self.store:
|
if item_hash in self.store:
|
||||||
progress.update(task, advance=1)
|
progress.update(task, advance=1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取embedding
|
|
||||||
embedding = self._get_embedding(s)
|
embedding = self._get_embedding(s)
|
||||||
|
|
||||||
# 存入
|
|
||||||
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
||||||
progress.update(task, advance=1)
|
progress.update(task, advance=1)
|
||||||
|
|
||||||
def save_to_file(self) -> None:
|
def save_to_file(self) -> None:
|
||||||
"""保存到文件"""
|
data = [item.to_dict() for item in self.store.values()]
|
||||||
data = []
|
|
||||||
logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
|
logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
|
||||||
for item in self.store.values():
|
df = pd.DataFrame(data)
|
||||||
data.append(item.to_dict())
|
os.makedirs(self.dir, exist_ok=True)
|
||||||
data_frame = pd.DataFrame(data)
|
df.to_parquet(self.embedding_file_path, engine="pyarrow", index=False)
|
||||||
|
|
||||||
if not os.path.exists(self.dir):
|
|
||||||
os.makedirs(self.dir, exist_ok=True)
|
|
||||||
if not os.path.exists(self.embedding_file_path):
|
|
||||||
open(self.embedding_file_path, "w").close()
|
|
||||||
|
|
||||||
data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False)
|
|
||||||
logger.info(f"{self.namespace}嵌入库保存成功")
|
logger.info(f"{self.namespace}嵌入库保存成功")
|
||||||
|
|
||||||
if self.faiss_index is not None and self.idx2hash is not None:
|
if self.faiss_index is not None and self.idx2hash is not None:
|
||||||
logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
|
logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
|
||||||
faiss.write_index(self.faiss_index, self.index_file_path)
|
faiss.write_index(self.faiss_index, self.index_file_path)
|
||||||
logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
|
logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
|
||||||
logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}")
|
logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}")
|
||||||
with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
|
with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
|
json.dump(self.idx2hash, f, ensure_ascii=False, indent=4)
|
||||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
|
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
|
||||||
|
|
||||||
def load_from_file(self) -> None:
|
def load_from_file(self) -> None:
|
||||||
"""从文件中加载"""
|
|
||||||
if not os.path.exists(self.embedding_file_path):
|
if not os.path.exists(self.embedding_file_path):
|
||||||
raise Exception(f"文件{self.embedding_file_path}不存在")
|
raise FileNotFoundError(f"文件{self.embedding_file_path}不存在")
|
||||||
logger.info("正在加载嵌入库...")
|
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
|
||||||
logger.debug(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
|
df = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
|
||||||
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
|
for _, row in df.iterrows():
|
||||||
total = len(data_frame)
|
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
|
||||||
with Progress(
|
|
||||||
SpinnerColumn(),
|
|
||||||
TextColumn("[progress.description]{task.description}"),
|
|
||||||
BarColumn(),
|
|
||||||
TaskProgressColumn(),
|
|
||||||
MofNCompleteColumn(),
|
|
||||||
"•",
|
|
||||||
TimeElapsedColumn(),
|
|
||||||
"<",
|
|
||||||
TimeRemainingColumn(),
|
|
||||||
transient=False,
|
|
||||||
) as progress:
|
|
||||||
task = progress.add_task("加载嵌入库", total=total)
|
|
||||||
for _, row in data_frame.iterrows():
|
|
||||||
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
|
|
||||||
progress.update(task, advance=1)
|
|
||||||
logger.info(f"{self.namespace}嵌入库加载成功")
|
logger.info(f"{self.namespace}嵌入库加载成功")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if os.path.exists(self.index_file_path):
|
if os.path.exists(self.index_file_path) and os.path.exists(self.idx2hash_file_path):
|
||||||
logger.info(f"正在加载{self.namespace}嵌入库的FaissIndex...")
|
|
||||||
logger.debug(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex")
|
|
||||||
self.faiss_index = faiss.read_index(self.index_file_path)
|
self.faiss_index = faiss.read_index(self.index_file_path)
|
||||||
logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
|
|
||||||
else:
|
|
||||||
raise Exception(f"文件{self.index_file_path}不存在")
|
|
||||||
if os.path.exists(self.idx2hash_file_path):
|
|
||||||
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
|
|
||||||
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
|
|
||||||
with open(self.idx2hash_file_path, "r") as f:
|
with open(self.idx2hash_file_path, "r") as f:
|
||||||
self.idx2hash = json.load(f)
|
self.idx2hash = json.load(f)
|
||||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
|
logger.info(f"{self.namespace}嵌入库的FaissIndex和idx2hash加载成功")
|
||||||
else:
|
else:
|
||||||
raise Exception(f"文件{self.idx2hash_file_path}不存在")
|
raise FileNotFoundError("Faiss index or idx2hash file not found.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"加载{self.namespace}嵌入库的FaissIndex时发生错误:{e}")
|
logger.warning(f"加载FaissIndex失败 ({e}),正在重建...")
|
||||||
logger.warning("正在重建Faiss索引")
|
|
||||||
self.build_faiss_index()
|
self.build_faiss_index()
|
||||||
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
|
|
||||||
self.save_to_file()
|
self.save_to_file()
|
||||||
|
|
||||||
def build_faiss_index(self) -> None:
|
def build_faiss_index(self) -> None:
|
||||||
"""重新构建Faiss索引,以余弦相似度为度量"""
|
embeddings = np.array([item.embedding for item in self.store.values()], dtype=np.float32)
|
||||||
# 获取所有的embedding
|
if embeddings.shape[0] == 0:
|
||||||
array = []
|
return
|
||||||
self.idx2hash = dict()
|
|
||||||
for key in self.store:
|
|
||||||
array.append(self.store[key].embedding)
|
|
||||||
self.idx2hash[str(len(array) - 1)] = key
|
|
||||||
embeddings = np.array(array, dtype=np.float32)
|
|
||||||
# L2归一化
|
|
||||||
faiss.normalize_L2(embeddings)
|
faiss.normalize_L2(embeddings)
|
||||||
# 构建索引
|
|
||||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
||||||
self.faiss_index.add(embeddings)
|
self.faiss_index.add(embeddings)
|
||||||
|
self.idx2hash = {str(i): h for i, h in enumerate(self.store.keys())}
|
||||||
|
|
||||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||||
"""搜索最相似的k个项,以余弦相似度为度量
|
if self.faiss_index is None: return []
|
||||||
Args:
|
query_np = np.array([query], dtype=np.float32)
|
||||||
query: 查询的embedding
|
faiss.normalize_L2(query_np)
|
||||||
k: 返回的最相似的k个项
|
distances, indices = self.faiss_index.search(query_np, k)
|
||||||
Returns:
|
return [(self.idx2hash[str(int(idx))], float(dist)) for idx, dist in zip(indices[0], distances[0]) if str(int(idx)) in self.idx2hash]
|
||||||
result: 最相似的k个项的(hash, 余弦相似度)列表
|
|
||||||
"""
|
|
||||||
if self.faiss_index is None:
|
|
||||||
logger.debug("FaissIndex尚未构建,返回None")
|
|
||||||
return []
|
|
||||||
if self.idx2hash is None:
|
|
||||||
logger.warning("idx2hash尚未构建,返回None")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# L2归一化
|
|
||||||
faiss.normalize_L2(np.array([query], dtype=np.float32))
|
|
||||||
# 搜索
|
|
||||||
distances, indices = self.faiss_index.search(np.array([query]), k)
|
|
||||||
# 整理结果
|
|
||||||
indices = list(indices.flatten())
|
|
||||||
distances = list(distances.flatten())
|
|
||||||
result = [
|
|
||||||
(self.idx2hash[str(int(idx))], float(sim))
|
|
||||||
for (idx, sim) in zip(indices, distances, strict=False)
|
|
||||||
if idx in range(len(self.idx2hash))
|
|
||||||
]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingManager:
|
class EmbeddingManager:
|
||||||
def __init__(self):
|
def __init__(self, lock):
|
||||||
self.paragraphs_embedding_store = EmbeddingStore(
|
self.lock = lock
|
||||||
local_storage["pg_namespace"], # type: ignore
|
self.paragraphs_embedding_store = EmbeddingStore(local_storage["pg_namespace"], EMBEDDING_DATA_DIR_STR, self.lock)
|
||||||
EMBEDDING_DATA_DIR_STR,
|
self.entities_embedding_store = EmbeddingStore(local_storage["ent_namespace"], EMBEDDING_DATA_DIR_STR, self.lock)
|
||||||
)
|
self.relation_embedding_store = EmbeddingStore(local_storage["rel_namespace"], EMBEDDING_DATA_DIR_STR, self.lock)
|
||||||
self.entities_embedding_store = EmbeddingStore(
|
|
||||||
local_storage["pg_namespace"], # type: ignore
|
|
||||||
EMBEDDING_DATA_DIR_STR,
|
|
||||||
)
|
|
||||||
self.relation_embedding_store = EmbeddingStore(
|
|
||||||
local_storage["pg_namespace"], # type: ignore
|
|
||||||
EMBEDDING_DATA_DIR_STR,
|
|
||||||
)
|
|
||||||
self.stored_pg_hashes = set()
|
self.stored_pg_hashes = set()
|
||||||
|
|
||||||
def check_all_embedding_model_consistency(self):
|
def check_all_embedding_model_consistency(self):
|
||||||
"""对所有嵌入库做模型一致性校验"""
|
return all(store.check_embedding_model_consistency() for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store])
|
||||||
for store in [
|
|
||||||
self.paragraphs_embedding_store,
|
|
||||||
self.entities_embedding_store,
|
|
||||||
self.relation_embedding_store,
|
|
||||||
]:
|
|
||||||
if not store.check_embedding_model_consistency():
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||||
"""将段落编码存入Embedding库"""
|
|
||||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||||
|
|
||||||
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||||
"""将实体编码存入Embedding库"""
|
entities = {triple[i] for triple_list in triple_list_data.values() for triple in triple_list for i in (0, 2)}
|
||||||
entities = set()
|
|
||||||
for triple_list in triple_list_data.values():
|
|
||||||
for triple in triple_list:
|
|
||||||
entities.add(triple[0])
|
|
||||||
entities.add(triple[2])
|
|
||||||
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
||||||
|
|
||||||
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||||
"""将关系编码存入Embedding库"""
|
graph_triples = list({tuple(t) for triples in triple_list_data.values() for t in triples})
|
||||||
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
|
|
||||||
for triples in triple_list_data.values():
|
|
||||||
graph_triples.extend([tuple(t) for t in triples])
|
|
||||||
graph_triples = list(set(graph_triples))
|
|
||||||
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
|
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
|
||||||
|
|
||||||
def load_from_file(self):
|
def load_from_file(self):
|
||||||
"""从文件加载"""
|
for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]:
|
||||||
self.paragraphs_embedding_store.load_from_file()
|
try:
|
||||||
self.entities_embedding_store.load_from_file()
|
store.load_from_file()
|
||||||
self.relation_embedding_store.load_from_file()
|
except Exception:
|
||||||
# 从段落库中获取已存储的hash
|
pass
|
||||||
self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys())
|
self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys())
|
||||||
|
|
||||||
def store_new_data_set(
|
def store_new_data_set(self, raw_paragraphs: Dict[str, str], triple_list_data: Dict[str, List[List[str]]], lock):
|
||||||
self,
|
|
||||||
raw_paragraphs: Dict[str, str],
|
|
||||||
triple_list_data: Dict[str, List[List[str]]],
|
|
||||||
):
|
|
||||||
if not self.check_all_embedding_model_consistency():
|
if not self.check_all_embedding_model_consistency():
|
||||||
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||||
"""存储新的数据集"""
|
|
||||||
self._store_pg_into_embedding(raw_paragraphs)
|
self._store_pg_into_embedding(raw_paragraphs)
|
||||||
self._store_ent_into_embedding(triple_list_data)
|
self._store_ent_into_embedding(triple_list_data)
|
||||||
self._store_rel_into_embedding(triple_list_data)
|
self._store_rel_into_embedding(triple_list_data)
|
||||||
self.stored_pg_hashes.update(raw_paragraphs.keys())
|
self.stored_pg_hashes.update(raw_paragraphs.keys())
|
||||||
|
|
||||||
def save_to_file(self):
|
def save_to_file(self):
|
||||||
"""保存到文件"""
|
for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]:
|
||||||
self.paragraphs_embedding_store.save_to_file()
|
store.save_to_file()
|
||||||
self.entities_embedding_store.save_to_file()
|
|
||||||
self.relation_embedding_store.save_to_file()
|
|
||||||
|
|
||||||
def rebuild_faiss_index(self):
|
def rebuild_faiss_index(self):
|
||||||
"""重建Faiss索引(请在添加新数据后调用)"""
|
for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]:
|
||||||
self.paragraphs_embedding_store.build_faiss_index()
|
store.build_faiss_index()
|
||||||
self.entities_embedding_store.build_faiss_index()
|
# --- END OF FILE src/chat/knowledge/embedding_store.py ---
|
||||||
self.relation_embedding_store.build_faiss_index()
|
|
||||||
|
|
@ -8,133 +8,89 @@ from src.chat.knowledge.global_logger import logger
|
||||||
from src.config.config import global_config as bot_global_config
|
from src.config.config import global_config as bot_global_config
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
import os
|
import os
|
||||||
|
from multiprocessing import Manager
|
||||||
|
|
||||||
INVALID_ENTITY = [
|
INVALID_ENTITY = [
|
||||||
"",
|
"", "你", "他", "她", "它", "我们", "你们", "他们", "她们", "它们",
|
||||||
"你",
|
|
||||||
"他",
|
|
||||||
"她",
|
|
||||||
"它",
|
|
||||||
"我们",
|
|
||||||
"你们",
|
|
||||||
"他们",
|
|
||||||
"她们",
|
|
||||||
"它们",
|
|
||||||
]
|
]
|
||||||
PG_NAMESPACE = "paragraph"
|
PG_NAMESPACE = "paragraph"
|
||||||
ENT_NAMESPACE = "entity"
|
ENT_NAMESPACE = "entity"
|
||||||
REL_NAMESPACE = "relation"
|
REL_NAMESPACE = "relation"
|
||||||
|
|
||||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||||
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||||
|
|
||||||
|
|
||||||
def _initialize_knowledge_local_storage():
|
def _initialize_knowledge_local_storage():
|
||||||
"""
|
|
||||||
初始化知识库相关的本地存储配置
|
|
||||||
使用字典批量设置,避免重复的if判断
|
|
||||||
"""
|
|
||||||
# 定义所有需要初始化的配置项
|
|
||||||
default_configs = {
|
default_configs = {
|
||||||
# 路径配置
|
|
||||||
"root_path": ROOT_PATH,
|
"root_path": ROOT_PATH,
|
||||||
"data_path": f"{ROOT_PATH}/data",
|
"data_path": f"{ROOT_PATH}/data",
|
||||||
# 实体和命名空间配置
|
|
||||||
"lpmm_invalid_entity": INVALID_ENTITY,
|
"lpmm_invalid_entity": INVALID_ENTITY,
|
||||||
"pg_namespace": PG_NAMESPACE,
|
"pg_namespace": PG_NAMESPACE,
|
||||||
"ent_namespace": ENT_NAMESPACE,
|
"ent_namespace": ENT_NAMESPACE,
|
||||||
"rel_namespace": REL_NAMESPACE,
|
"rel_namespace": REL_NAMESPACE,
|
||||||
# RAG相关命名空间配置
|
|
||||||
"rag_graph_namespace": RAG_GRAPH_NAMESPACE,
|
"rag_graph_namespace": RAG_GRAPH_NAMESPACE,
|
||||||
"rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
|
"rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
|
||||||
"rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
|
"rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 日志级别映射:重要配置用info,其他用debug
|
|
||||||
important_configs = {"root_path", "data_path"}
|
important_configs = {"root_path", "data_path"}
|
||||||
|
|
||||||
# 批量设置配置项
|
|
||||||
initialized_count = 0
|
initialized_count = 0
|
||||||
for key, default_value in default_configs.items():
|
for key, default_value in default_configs.items():
|
||||||
if local_storage[key] is None:
|
if local_storage[key] is None:
|
||||||
local_storage[key] = default_value
|
local_storage[key] = default_value
|
||||||
|
|
||||||
# 根据重要性选择日志级别
|
|
||||||
if key in important_configs:
|
if key in important_configs:
|
||||||
logger.info(f"设置{key}: {default_value}")
|
logger.info(f"设置{key}: {default_value}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"设置{key}: {default_value}")
|
logger.debug(f"设置{key}: {default_value}")
|
||||||
|
|
||||||
initialized_count += 1
|
initialized_count += 1
|
||||||
|
|
||||||
if initialized_count > 0:
|
if initialized_count > 0:
|
||||||
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
|
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
|
||||||
else:
|
else:
|
||||||
logger.debug("知识库本地存储配置已存在,跳过初始化")
|
logger.debug("知识库本地存储配置已存在,跳过初始化")
|
||||||
|
|
||||||
|
|
||||||
# 初始化本地存储路径
|
|
||||||
# sourcery skip: dict-comprehension
|
|
||||||
_initialize_knowledge_local_storage()
|
_initialize_knowledge_local_storage()
|
||||||
|
|
||||||
qa_manager = None
|
qa_manager = None
|
||||||
inspire_manager = None
|
inspire_manager = None
|
||||||
|
|
||||||
# 检查LPMM知识库是否启用
|
|
||||||
if bot_global_config.lpmm_knowledge.enable:
|
if bot_global_config.lpmm_knowledge.enable:
|
||||||
logger.info("正在初始化Mai-LPMM")
|
logger.info("正在初始化Mai-LPMM")
|
||||||
logger.info("创建LLM客户端")
|
logger.info("创建LLM客户端")
|
||||||
llm_client_list = {}
|
llm_client_list = {}
|
||||||
for key in global_config["llm_providers"]:
|
for key in global_config["llm_providers"]:
|
||||||
llm_client_list[key] = LLMClient(
|
llm_client_list[key] = LLMClient(
|
||||||
global_config["llm_providers"][key]["base_url"], # type: ignore
|
global_config["llm_providers"][key]["base_url"],
|
||||||
global_config["llm_providers"][key]["api_key"], # type: ignore
|
global_config["llm_providers"][key]["api_key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 初始化Embedding库
|
manager = Manager()
|
||||||
embed_manager = EmbeddingManager()
|
lock = manager.Lock()
|
||||||
|
|
||||||
|
embed_manager = EmbeddingManager(lock)
|
||||||
logger.info("正在从文件加载Embedding库")
|
logger.info("正在从文件加载Embedding库")
|
||||||
try:
|
try:
|
||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
|
logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
|
||||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
|
||||||
logger.info("Embedding库加载完成")
|
logger.info("Embedding库加载完成")
|
||||||
# 初始化KG
|
|
||||||
kg_manager = KGManager()
|
kg_manager = KGManager()
|
||||||
logger.info("正在从文件加载KG")
|
logger.info("正在从文件加载KG")
|
||||||
try:
|
try:
|
||||||
kg_manager.load_from_file()
|
kg_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
|
logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
|
||||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
|
||||||
logger.info("KG加载完成")
|
logger.info("KG加载完成")
|
||||||
|
|
||||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
||||||
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
||||||
|
|
||||||
# 数据比对:Embedding库与KG的段落hash集合
|
|
||||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||||
key = f"{PG_NAMESPACE}-{pg_hash}"
|
key = f"{PG_NAMESPACE}-{pg_hash}"
|
||||||
if key not in embed_manager.stored_pg_hashes:
|
if key not in embed_manager.stored_pg_hashes:
|
||||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||||
|
|
||||||
# 问答系统(用于知识库)
|
qa_manager = QAManager(embed_manager, kg_manager)
|
||||||
qa_manager = QAManager(
|
inspire_manager = MemoryActiveManager(embed_manager, llm_client_list[global_config["embedding"]["provider"]])
|
||||||
embed_manager,
|
|
||||||
kg_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 记忆激活(用于记忆库)
|
|
||||||
inspire_manager = MemoryActiveManager(
|
|
||||||
embed_manager,
|
|
||||||
llm_client_list[global_config["embedding"]["provider"]],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info("LPMM知识库已禁用,跳过初始化")
|
logger.info("LPMM知识库已禁用,跳过初始化")
|
||||||
# 创建空的占位符对象,避免导入错误
|
|
||||||
Loading…
Reference in New Issue