MaiBot/src/chat/knowledge/embedding_store.py

320 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from dataclasses import dataclass
import json
import os
import math
import asyncio
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import faiss
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding
from src.config.config import global_config
install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
TOTAL_EMBEDDING_TIMES = 3
EMBEDDING_TEST_STRINGS = [
"阿卡伊真的太好玩了,神秘性感大女同等着你",
"你怎么知道我arc12.64了",
"我是蕾缪乐小姐的狗",
"关注Oct谢谢喵",
"不是w6我不草",
"关注千石可乐谢谢喵",
"来玩CLANNADAIR樱之诗樱之刻谢谢喵",
"关注墨梓柒谢谢喵",
"Ciallo~",
"来玩巧克甜恋谢谢喵",
"水印",
"我也在纠结晚饭,铁锅炒鸡听着就香!",
"test你妈喵",
]
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
EMBEDDING_SIM_THRESHOLD = 0.99
def cosine_similarity(a, b):
dot = sum(x * y for x, y in zip(a, b, strict=False))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
@dataclass
class EmbeddingStoreItem:
def __init__(self, item_hash: str, embedding: List[float], content: str):
self.hash = item_hash
self.embedding = embedding
self.str = content
def to_dict(self) -> dict:
return {
"hash": self.hash,
"embedding": self.embedding,
"str": self.str,
}
class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str, lock):
self.namespace = namespace
self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
self.store = {}
self.faiss_index = None
self.idx2hash = None
self.lock = lock
def _get_embedding(self, s: str) -> List[float]:
with self.lock:
try:
asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
return asyncio.run(get_embedding(s))
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
result = future.result()
if result is None:
logger.error(f"获取嵌入失败: {s}")
return []
return result
except RuntimeError:
result = asyncio.run(get_embedding(s))
if result is None:
logger.error(f"获取嵌入失败: {s}")
return []
return result
def get_test_file_path(self):
return EMBEDDING_TEST_FILE
def save_embedding_test_vectors(self):
test_vectors = {}
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
def load_embedding_test_vectors(self):
path = self.get_test_file_path()
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def check_embedding_model_consistency(self):
local_vectors = self.load_embedding_test_vectors()
if local_vectors is None:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
self.save_embedding_test_vectors()
return True
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
local_emb = local_vectors.get(str(idx))
if local_emb is None:
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
self.save_embedding_test_vectors()
return True
new_emb = self._get_embedding(s)
sim = cosine_similarity(local_emb, new_emb)
if sim < EMBEDDING_SIM_THRESHOLD:
logger.error("嵌入模型一致性校验失败")
return False
logger.info("嵌入模型一致性校验通过。")
return True
def batch_insert_strs(self, strs: List[str], times: int) -> None:
total = len(strs)
with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "", TimeElapsedColumn(), "<", TimeRemainingColumn(), transient=False) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
for s in strs:
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash in self.store:
progress.update(task, advance=1)
continue
embedding = self._get_embedding(s)
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
progress.update(task, advance=1)
def save_to_file(self) -> None:
data = []
logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
for item in self.store.values():
data.append(item.to_dict())
data_frame = pd.DataFrame(data)
if not os.path.exists(self.dir):
os.makedirs(self.dir, exist_ok=True)
if not os.path.exists(self.embedding_file_path):
open(self.embedding_file_path, "w").close()
data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False)
logger.info(f"{self.namespace}嵌入库保存成功")
if self.faiss_index is not None and self.idx2hash is not None:
logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
faiss.write_index(self.faiss_index, self.index_file_path)
logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}")
with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
def load_from_file(self) -> None:
if not os.path.exists(self.embedding_file_path):
raise Exception(f"文件{self.embedding_file_path}不存在")
logger.info("正在加载嵌入库...")
logger.debug(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
total = len(data_frame)
with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "", TimeElapsedColumn(), "<", TimeRemainingColumn(), transient=False) as progress:
task = progress.add_task("加载嵌入库", total=total)
for _, row in data_frame.iterrows():
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
progress.update(task, advance=1)
logger.info(f"{self.namespace}嵌入库加载成功")
try:
if os.path.exists(self.index_file_path):
logger.info(f"正在加载{self.namespace}嵌入库的FaissIndex...")
logger.debug(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex")
self.faiss_index = faiss.read_index(self.index_file_path)
logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
else:
raise Exception(f"文件{self.index_file_path}不存在")
if os.path.exists(self.idx2hash_file_path):
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
with open(self.idx2hash_file_path, "r") as f:
self.idx2hash = json.load(f)
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
else:
raise Exception(f"文件{self.idx2hash_file_path}不存在")
except Exception as e:
logger.error(f"加载{self.namespace}嵌入库的FaissIndex时发生错误{e}")
logger.warning("正在重建Faiss索引")
self.build_faiss_index()
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
self.save_to_file()
def build_faiss_index(self) -> None:
array = []
self.idx2hash = dict()
for key in self.store:
array.append(self.store[key].embedding)
self.idx2hash[str(len(array) - 1)] = key
embeddings = np.array(array, dtype=np.float32)
faiss.normalize_L2(embeddings)
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
if self.faiss_index is None:
logger.debug("FaissIndex尚未构建,返回None")
return []
if self.idx2hash is None:
logger.warning("idx2hash尚未构建,返回None")
return []
faiss.normalize_L2(np.array([query], dtype=np.float32))
distances, indices = self.faiss_index.search(np.array([query]), k)
indices = list(indices.flatten())
distances = list(distances.flatten())
result = [
(self.idx2hash[str(int(idx))], float(sim))
for (idx, sim) in zip(indices, distances, strict=False)
if idx in range(len(self.idx2hash))
]
return result
class EmbeddingManager:
def __init__(self, lock):
self.lock = lock
self.paragraphs_embedding_store = EmbeddingStore(
local_storage["pg_namespace"],
EMBEDDING_DATA_DIR_STR,
self.lock,
)
self.entities_embedding_store = EmbeddingStore(
local_storage["ent_namespace"],
EMBEDDING_DATA_DIR_STR,
self.lock,
)
self.relation_embedding_store = EmbeddingStore(
local_storage["rel_namespace"],
EMBEDDING_DATA_DIR_STR,
self.lock,
)
self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self):
for store in [self.paragraphs_embedding_store, self.entities_embedding_store, self.relation_embedding_store]:
if not store.check_embedding_model_consistency():
return False
return True
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
entities = set()
for triple_list in triple_list_data.values():
for triple in triple_list:
entities.add(triple[0])
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
graph_triples = []
for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples))
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
def load_from_file(self):
try:
self.paragraphs_embedding_store.load_from_file()
except Exception: pass
try:
self.entities_embedding_store.load_from_file()
except Exception: pass
try:
self.relation_embedding_store.load_from_file()
except Exception: pass
self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys())
def store_new_data_set(
self,
raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]],
lock
):
if not self.check_all_embedding_model_consistency():
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
self._store_pg_into_embedding(raw_paragraphs)
self._store_ent_into_embedding(triple_list_data)
self._store_rel_into_embedding(triple_list_data)
self.stored_pg_hashes.update(raw_paragraphs.keys())
def save_to_file(self):
self.paragraphs_embedding_store.save_to_file()
self.entities_embedding_store.save_to_file()
self.relation_embedding_store.save_to_file()
def rebuild_faiss_index(self):
self.paragraphs_embedding_store.build_faiss_index()
self.entities_embedding_store.build_faiss_index()
self.relation_embedding_store.build_faiss_index()