fix: 修复导入知识库的拆分文本、加载 env 变量、导入 openie 共 3 个问题

pull/1142/head
looom 2025-07-27 00:28:48 +08:00
parent c65f9a572b
commit bf98434d3b
No known key found for this signature in database
GPG Key ID: 08E6AEA38FD37CA3
5 changed files with 209 additions and 8 deletions

View File

@ -200,8 +200,26 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
logger.info("正在重新构建向量索引") logger.info("正在重新构建向量索引")
embed_manager.rebuild_faiss_index() embed_manager.rebuild_faiss_index()
logger.info("向量索引构建完成") logger.info("向量索引构建完成")
# 对实体进行嵌入并存储
embed_manager.save_to_file() embed_manager.save_to_file()
logger.info("Embedding完成") logger.info("段落Embedding完成")
# 新增实体Embedding
all_entities = set()
for entity_list in entity_list_data.values():
all_entities.update(entity_list)
entity_hash_map = {
get_sha256(entity): entity
for entity in all_entities
}
if entity_hash_map:
logger.info(f"开始对 {len(entity_hash_map)} 个实体生成向量嵌入")
embed_manager.store_entity_data_set(entity_hash_map)
embed_manager.save_to_file()
logger.info("实体嵌入完成")
# 构建新段落的RAG # 构建新段落的RAG
logger.info("开始构建RAG") logger.info("开始构建RAG")
kg_manager.build_kg(triple_list_data, embed_manager) kg_manager.build_kg(triple_list_data, embed_manager)

View File

@ -73,3 +73,30 @@ def load_raw_data() -> tuple[list[str], list[str]]:
logger.info(f"共读取到{len(raw_data)}条数据") logger.info(f"共读取到{len(raw_data)}条数据")
return sha256_list, raw_data return sha256_list, raw_data
# chatgpt
import json
from datetime import datetime
if __name__ == "__main__":
sha256_list, raw_data = load_raw_data()
# 构造导出路径
output_dir = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
os.makedirs(output_dir, exist_ok=True)
now = datetime.now().strftime("%m-%d-%H-%M")
output_path = os.path.join(output_dir, f"{now}-imported-data.json")
# 写入 JSON 文件
# with open(output_path, "w", encoding="utf-8") as f:
# json.dump({
# "sha256_list": sha256_list,
# "raw_data": raw_data
# }, f, ensure_ascii=False, indent=2)
# 上面那些是AI写的我看了下旧版生成的文件格式实际上只保存[raw_data]就行后面的info_extraction.py和import_openie.py会正常运行
# 下面这段只保存 raw_data和旧版保存的文件格式一致
with open(output_path, "w", encoding="utf-8") as f:
json.dump(raw_data, f, ensure_ascii=False, indent=2)
logger.info(f"数据处理完成,已写入:{output_path}")

View File

@ -0,0 +1,108 @@
import json
import os
from pathlib import Path
import sys # 新增系统模块导入
import datetime # 新增导入
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger
from src.chat.knowledge.lpmmconfig import global_config
logger = get_logger("lpmm")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# 新增:确保 RAW_DATA_PATH 存在
if not os.path.exists(RAW_DATA_PATH):
os.makedirs(RAW_DATA_PATH, exist_ok=True)
logger.info(f"已创建目录: {RAW_DATA_PATH}")
if global_config.get("persistence", {}).get("raw_data_path") is not None:
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
else:
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
# 添加项目根目录到 sys.path
def check_and_create_dirs():
"""检查并创建必要的目录"""
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
for dir_path in required_dirs:
if not os.path.exists(dir_path):
os.makedirs(dir_path)
logger.info(f"已创建目录: {dir_path}")
def process_text_file(file_path):
"""处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f:
raw = f.read()
paragraphs = []
paragraph = ""
for line in raw.split("\n"):
if line.strip() == "":
if paragraph != "":
paragraphs.append(paragraph.strip())
paragraph = ""
else:
paragraph += line + "\n"
if paragraph != "":
paragraphs.append(paragraph.strip())
return paragraphs
def main():
# 新增用户确认提示
print("=== 数据预处理脚本 ===")
print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。")
print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。")
print("请确保原始数据已放置在正确的目录中。")
confirm = input("确认继续执行?(y/n): ").strip().lower()
if confirm != "y":
logger.info("操作已取消")
sys.exit(1)
print("\n" + "=" * 40 + "\n")
# 检查并创建必要的目录
check_and_create_dirs()
# # 检查输出文件是否存在
# if os.path.exists(RAW_DATA_PATH):
# logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
# sys.exit(1)
# if os.path.exists(RAW_DATA_PATH):
# logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
# sys.exit(1)
# 获取所有原始文本文件
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
if not raw_files:
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
sys.exit(1)
# 处理所有文件
all_paragraphs = []
for file in raw_files:
logger.info(f"正在处理文件: {file.name}")
paragraphs = process_text_file(file)
all_paragraphs.extend(paragraphs)
# 保存合并后的结果到 IMPORTED_DATA_PATH文件名格式为 MM-DD-HH-ss-imported-data.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
logger.info(f"处理完成,结果已保存到: {output_path}")
if __name__ == "__main__":
logger.info(f"原始数据路径: {RAW_DATA_PATH}")
logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
main()

View File

@ -29,7 +29,7 @@ from rich.progress import (
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.config.config import global_config from src.config.config import global_config
from tqdm import tqdm
install(extra_lines=3) install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
@ -317,18 +317,31 @@ class EmbeddingStore:
class EmbeddingManager: class EmbeddingManager:
def __init__(self): def __init__(self):
# self.paragraphs_embedding_store = EmbeddingStore(
# local_storage["pg_namespace"] + "-paragraph", # 不再共用一个命名空间
# EMBEDDING_DATA_DIR_STR,
# )
# self.entities_embedding_store = EmbeddingStore(
# local_storage["pg_namespace"] + "-entity",
# EMBEDDING_DATA_DIR_STR,
# )
# self.relation_embedding_store = EmbeddingStore(
# local_storage["pg_namespace"] + "-relation",
# EMBEDDING_DATA_DIR_STR,
# )
self.paragraphs_embedding_store = EmbeddingStore( self.paragraphs_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "paragraph", # 不再共用一个命名空间
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.entities_embedding_store = EmbeddingStore( self.entities_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "entity",
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.relation_embedding_store = EmbeddingStore( self.relation_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "relation",
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self): def check_all_embedding_model_consistency(self):
@ -395,3 +408,34 @@ class EmbeddingManager:
self.paragraphs_embedding_store.build_faiss_index() self.paragraphs_embedding_store.build_faiss_index()
self.entities_embedding_store.build_faiss_index() self.entities_embedding_store.build_faiss_index()
self.relation_embedding_store.build_faiss_index() self.relation_embedding_store.build_faiss_index()
def store_entity_data_set(self, entity_hash_map: dict[str, str]) -> None:
"""为每个实体生成嵌入并保存entity_hash_map: {hash: text}"""
texts = list(entity_hash_map.values())
hashes = list(entity_hash_map.keys())
logger.info(f"正在生成 {len(texts)} 个实体的嵌入向量...")
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("嵌入实体", total=len(texts))
for hash_id, text in zip(hashes, texts):
item_hash = self.entities_embedding_store.namespace + "-entity-" + hash_id
if item_hash in self.entities_embedding_store.store:
progress.update(task, advance=1)
continue
embedding = self.entities_embedding_store._get_embedding(text)
self.entities_embedding_store.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, text)
progress.update(task, advance=1)
logger.info(f"成功嵌入实体数:{len(self.entities_embedding_store.store)}")

View File

@ -16,6 +16,8 @@ from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模
from src.config.config import global_config from src.config.config import global_config
from src.common.tcp_connector import get_tcp_connector from src.common.tcp_connector import get_tcp_connector
from rich.traceback import install from rich.traceback import install
from dotenv import load_dotenv
install(extra_lines=3) install(extra_lines=3)
@ -121,8 +123,10 @@ class LLMRequest:
try: try:
# print(f"model['provider']: {model['provider']}") # print(f"model['provider']: {model['provider']}")
self.api_key = os.environ[f"{model['provider']}_KEY"] # self.api_key = os.environ[f"{model['provider']}_KEY"] # 这是原来的写法,没法读取变量
self.base_url = os.environ[f"{model['provider']}_BASE_URL"] # self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
self.api_key = os.environ.get(f"{model['provider']}_KEY") # 改成这种写法,能读取变量
self.base_url = os.environ.get(f"{model['provider']}_BASE_URL")
logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL") logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL")
except AttributeError as e: except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}") logger.error(f"原始 model dict 信息:{model}")