mirror of https://github.com/Mai-with-u/MaiBot.git
fix: 修复导入知识库的拆分文本、加载env变量、导入openie3个问题
parent
c65f9a572b
commit
bf98434d3b
|
|
@ -200,8 +200,26 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
||||||
logger.info("正在重新构建向量索引")
|
logger.info("正在重新构建向量索引")
|
||||||
embed_manager.rebuild_faiss_index()
|
embed_manager.rebuild_faiss_index()
|
||||||
logger.info("向量索引构建完成")
|
logger.info("向量索引构建完成")
|
||||||
|
# 对实体进行嵌入并存储
|
||||||
embed_manager.save_to_file()
|
embed_manager.save_to_file()
|
||||||
logger.info("Embedding完成")
|
logger.info("段落Embedding完成")
|
||||||
|
|
||||||
|
# 新增实体Embedding
|
||||||
|
all_entities = set()
|
||||||
|
for entity_list in entity_list_data.values():
|
||||||
|
all_entities.update(entity_list)
|
||||||
|
|
||||||
|
entity_hash_map = {
|
||||||
|
get_sha256(entity): entity
|
||||||
|
for entity in all_entities
|
||||||
|
}
|
||||||
|
|
||||||
|
if entity_hash_map:
|
||||||
|
logger.info(f"开始对 {len(entity_hash_map)} 个实体生成向量嵌入")
|
||||||
|
embed_manager.store_entity_data_set(entity_hash_map)
|
||||||
|
embed_manager.save_to_file()
|
||||||
|
logger.info("实体嵌入完成")
|
||||||
|
|
||||||
# 构建新段落的RAG
|
# 构建新段落的RAG
|
||||||
logger.info("开始构建RAG")
|
logger.info("开始构建RAG")
|
||||||
kg_manager.build_kg(triple_list_data, embed_manager)
|
kg_manager.build_kg(triple_list_data, embed_manager)
|
||||||
|
|
|
||||||
|
|
@ -72,4 +72,31 @@ def load_raw_data() -> tuple[list[str], list[str]]:
|
||||||
raw_data.append(item)
|
raw_data.append(item)
|
||||||
logger.info(f"共读取到{len(raw_data)}条数据")
|
logger.info(f"共读取到{len(raw_data)}条数据")
|
||||||
|
|
||||||
return sha256_list, raw_data
|
return sha256_list, raw_data
|
||||||
|
|
||||||
|
# chatgpt
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: load the raw paragraphs and dump them to a timestamped JSON file.
    _sha256_values, paragraphs = load_raw_data()

    # The export directory lives under the project root; create it on first use.
    export_dir = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
    os.makedirs(export_dir, exist_ok=True)

    # The file name is prefixed with the current month-day-hour-minute.
    timestamp = datetime.now().strftime("%m-%d-%H-%M")
    output_path = os.path.join(export_dir, f"{timestamp}-imported-data.json")

    # Only the paragraph list is written (the hashes are not). Per the original author's note,
    # this matches the file format produced by the older version, which info_extraction.py and
    # import_openie.py are expected to consume.
    with open(output_path, "w", encoding="utf-8") as out_file:
        json.dump(paragraphs, out_file, ensure_ascii=False, indent=2)

    logger.info(f"数据处理完成,已写入:{output_path}")
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import sys # 新增系统模块导入
|
||||||
|
import datetime # 新增导入
|
||||||
|
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.knowledge.lpmmconfig import global_config
|
||||||
|
|
||||||
|
logger = get_logger("lpmm")

# The project root is the parent of the directory that contains this script.
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")

# Make sure the raw-data directory exists as soon as this module is loaded.
if not os.path.exists(RAW_DATA_PATH):
    os.makedirs(RAW_DATA_PATH, exist_ok=True)
    logger.info(f"已创建目录: {RAW_DATA_PATH}")

# Output directory: use persistence.raw_data_path from the config when it is set,
# otherwise fall back to the default location under data/.
_configured_output_dir = global_config.get("persistence", {}).get("raw_data_path")
IMPORTED_DATA_PATH = os.path.join(
    ROOT_PATH,
    _configured_output_dir if _configured_output_dir is not None else "data/imported_lpmm_data",
)
|
||||||
|
|
||||||
|
# NOTE: the project root is added to sys.path next to the imports at the top of this file.
|
||||||
|
|
||||||
|
|
||||||
|
def check_and_create_dirs():
    """Create the raw-data and imported-data directories if they do not exist yet.

    Logs one message per directory that had to be created. Paths that already
    exist are left untouched and produce no log output.
    """
    for dir_path in (RAW_DATA_PATH, IMPORTED_DATA_PATH):
        if os.path.exists(dir_path):
            continue
        # exist_ok=True keeps this safe if the directory appears between the check
        # above and this call, and matches the module-level creation of RAW_DATA_PATH.
        os.makedirs(dir_path, exist_ok=True)
        logger.info(f"已创建目录: {dir_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def process_text_file(file_path):
    """Split one text file into paragraphs separated by blank lines.

    Args:
        file_path: Path of the UTF-8 text file to read (anything ``open`` accepts).

    Returns:
        A list of paragraph strings in file order. Consecutive non-blank lines are
        joined with a newline, and each paragraph has its leading and trailing
        whitespace stripped. A file without any non-blank line yields an empty list.
    """
    # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8" but also drops a leading
    # byte-order mark, which would otherwise stay glued to the first paragraph
    # (str.strip() does not treat U+FEFF as whitespace).
    with open(file_path, "r", encoding="utf-8-sig") as f:
        raw = f.read()

    paragraphs = []
    current_lines = []  # non-blank lines of the paragraph being collected
    for line in raw.split("\n"):
        if line.strip():
            current_lines.append(line)
        elif current_lines:
            # A blank (or whitespace-only) line closes the current paragraph.
            paragraphs.append("\n".join(current_lines).strip())
            current_lines = []

    # Flush the last paragraph when the file does not end with a blank line.
    if current_lines:
        paragraphs.append("\n".join(current_lines).strip())

    return paragraphs
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Interactively merge every raw ``.txt`` file into one imported-data JSON file.

    Steps:
        1. Print what the script will do and ask the user to confirm.
        2. Make sure the input and output directories exist.
        3. Split every ``*.txt`` file under ``RAW_DATA_PATH`` into paragraphs.
        4. Write the combined paragraph list as JSON into ``IMPORTED_DATA_PATH``.

    Exits with status 1 (via ``sys.exit``) when the user does not answer "y" or
    when no ``.txt`` file is found.
    """
    # Ask for confirmation before touching anything.
    print("=== 数据预处理脚本 ===")
    print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。")
    print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。")
    print("请确保原始数据已放置在正确的目录中。")
    confirm = input("确认继续执行?(y/n): ").strip().lower()
    if confirm != "y":
        logger.info("操作已取消")
        sys.exit(1)
    print("\n" + "=" * 40 + "\n")

    check_and_create_dirs()

    # glob() yields entries in a filesystem-dependent order; sorting makes the order
    # of the merged paragraphs reproducible from run to run.
    raw_files = sorted(Path(RAW_DATA_PATH).glob("*.txt"))
    if not raw_files:
        logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
        sys.exit(1)

    # Collect the paragraphs of all files into a single flat list.
    all_paragraphs = []
    for file in raw_files:
        logger.info(f"正在处理文件: {file.name}")
        all_paragraphs.extend(process_text_file(file))

    # NOTE(review): the name is month-day-hour-SECOND (no minutes), as announced in the
    # prompt above (MM-DD-HH-SS). Two runs in the same hour and second-of-minute would
    # overwrite each other; confirm the format is intended before changing it.
    now = datetime.datetime.now()
    filename = now.strftime("%m-%d-%H-%S-imported-data.json")
    output_path = os.path.join(IMPORTED_DATA_PATH, filename)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)

    logger.info(f"处理完成,结果已保存到: {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Report the resolved input and output directories first, then start the
    # interactive preprocessing flow.
    logger.info(f"原始数据路径: {RAW_DATA_PATH}")
    logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
    main()
|
||||||
|
|
@ -29,7 +29,7 @@ from rich.progress import (
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
|
@ -317,18 +317,31 @@ class EmbeddingStore:
|
||||||
|
|
||||||
class EmbeddingManager:
|
class EmbeddingManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
# self.paragraphs_embedding_store = EmbeddingStore(
|
||||||
|
# local_storage["pg_namespace"] + "-paragraph", # 不再共用一个命名空间
|
||||||
|
# EMBEDDING_DATA_DIR_STR,
|
||||||
|
# )
|
||||||
|
# self.entities_embedding_store = EmbeddingStore(
|
||||||
|
# local_storage["pg_namespace"] + "-entity",
|
||||||
|
# EMBEDDING_DATA_DIR_STR,
|
||||||
|
# )
|
||||||
|
# self.relation_embedding_store = EmbeddingStore(
|
||||||
|
# local_storage["pg_namespace"] + "-relation",
|
||||||
|
# EMBEDDING_DATA_DIR_STR,
|
||||||
|
# )
|
||||||
self.paragraphs_embedding_store = EmbeddingStore(
|
self.paragraphs_embedding_store = EmbeddingStore(
|
||||||
local_storage["pg_namespace"], # type: ignore
|
"paragraph", # 不再共用一个命名空间
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
)
|
)
|
||||||
self.entities_embedding_store = EmbeddingStore(
|
self.entities_embedding_store = EmbeddingStore(
|
||||||
local_storage["pg_namespace"], # type: ignore
|
"entity",
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
)
|
)
|
||||||
self.relation_embedding_store = EmbeddingStore(
|
self.relation_embedding_store = EmbeddingStore(
|
||||||
local_storage["pg_namespace"], # type: ignore
|
"relation",
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stored_pg_hashes = set()
|
self.stored_pg_hashes = set()
|
||||||
|
|
||||||
def check_all_embedding_model_consistency(self):
|
def check_all_embedding_model_consistency(self):
|
||||||
|
|
@ -395,3 +408,34 @@ class EmbeddingManager:
|
||||||
self.paragraphs_embedding_store.build_faiss_index()
|
self.paragraphs_embedding_store.build_faiss_index()
|
||||||
self.entities_embedding_store.build_faiss_index()
|
self.entities_embedding_store.build_faiss_index()
|
||||||
self.relation_embedding_store.build_faiss_index()
|
self.relation_embedding_store.build_faiss_index()
|
||||||
|
|
||||||
|
def store_entity_data_set(self, entity_hash_map: dict[str, str]) -> None:
    """Embed every entity that is not yet in the entity store and add it to the store.

    Args:
        entity_hash_map: Mapping of entity hash -> entity text.

    Entities whose store key is already present are skipped without requesting a
    new embedding. This method only updates the in-memory ``store`` mapping; it
    does not call any save method itself.
    """
    entity_store = self.entities_embedding_store
    total = len(entity_hash_map)

    logger.info(f"正在生成 {total} 个实体的嵌入向量...")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        "•",
        TimeElapsedColumn(),
        "<",
        TimeRemainingColumn(),
        transient=False,
    ) as progress:
        task = progress.add_task("嵌入实体", total=total)
        for hash_id, text in entity_hash_map.items():
            # NOTE(review): the key is "<namespace>-entity-<hash>"; confirm this matches
            # the key format used by the other EmbeddingStore code paths and by lookups.
            item_hash = entity_store.namespace + "-entity-" + hash_id
            if item_hash not in entity_store.store:
                embedding = entity_store._get_embedding(text)
                entity_store.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, text)
            # Advance for skipped entries too so the bar always reaches the total.
            progress.update(task, advance=1)

    # This reports the size of the whole entity store, i.e. it also counts entities
    # that were already present before this call.
    logger.info(f"成功嵌入实体数:{len(self.entities_embedding_store.store)}")
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@ from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.tcp_connector import get_tcp_connector
|
from src.common.tcp_connector import get_tcp_connector
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
@ -121,8 +123,10 @@ class LLMRequest:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# print(f"model['provider']: {model['provider']}")
|
# print(f"model['provider']: {model['provider']}")
|
||||||
self.api_key = os.environ[f"{model['provider']}_KEY"]
|
# self.api_key = os.environ[f"{model['provider']}_KEY"] # 这是原来的写法,没法读取变量
|
||||||
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
|
# self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
|
||||||
|
self.api_key = os.environ.get(f"{model['provider']}_KEY") # 改成这种写法,能读取变量
|
||||||
|
self.base_url = os.environ.get(f"{model['provider']}_BASE_URL")
|
||||||
logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL")
|
logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL")
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
logger.error(f"原始 model dict 信息:{model}")
|
logger.error(f"原始 model dict 信息:{model}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue