From bf98434d3b289c83cb9f2db98da3d821fa414996 Mon Sep 17 00:00:00 2001 From: looom <42137636+xmexg@users.noreply.github.com> Date: Sun, 27 Jul 2025 00:28:48 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E7=9A=84=E6=8B=86=E5=88=86=E6=96=87?= =?UTF-8?q?=E6=9C=AC=E3=80=81=E5=8A=A0=E8=BD=BDenv=E5=8F=98=E9=87=8F?= =?UTF-8?q?=E3=80=81=E5=AF=BC=E5=85=A5openie3=E4=B8=AA=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 20 ++++- scripts/raw_data_preprocessor.py | 29 ++++++- scripts/raw_data_preprocessor_old.py | 108 ++++++++++++++++++++++++++ src/chat/knowledge/embedding_store.py | 52 ++++++++++++- src/llm_models/utils_model.py | 8 +- 5 files changed, 209 insertions(+), 8 deletions(-) create mode 100644 scripts/raw_data_preprocessor_old.py diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 63a4d985..73951aaa 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -200,8 +200,26 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k logger.info("正在重新构建向量索引") embed_manager.rebuild_faiss_index() logger.info("向量索引构建完成") + # 对实体进行嵌入并存储 embed_manager.save_to_file() - logger.info("Embedding完成") + logger.info("段落Embedding完成") + + # 新增实体Embedding + all_entities = set() + for entity_list in entity_list_data.values(): + all_entities.update(entity_list) + + entity_hash_map = { + get_sha256(entity): entity + for entity in all_entities + } + + if entity_hash_map: + logger.info(f"开始对 {len(entity_hash_map)} 个实体生成向量嵌入") + embed_manager.store_entity_data_set(entity_hash_map) + embed_manager.save_to_file() + logger.info("实体嵌入完成") + # 构建新段落的RAG logger.info("开始构建RAG") kg_manager.build_kg(triple_list_data, embed_manager) diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index 42a99133..55e116dd 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -72,4 +72,31 @@ def load_raw_data() -> tuple[list[str], list[str]]: raw_data.append(item) logger.info(f"共读取到{len(raw_data)}条数据") - return sha256_list, raw_data \ No newline at end of file + return sha256_list, raw_data + +# chatgpt +import json +from datetime import datetime + +if __name__ == "__main__": + sha256_list, raw_data = load_raw_data() + + # 构造导出路径 + output_dir = os.path.join(ROOT_PATH, "data/imported_lpmm_data") + os.makedirs(output_dir, exist_ok=True) + + now = datetime.now().strftime("%m-%d-%H-%M") + output_path = os.path.join(output_dir, f"{now}-imported-data.json") + + # 写入 JSON 文件 + # with open(output_path, "w", encoding="utf-8") as f: + # json.dump({ + # "sha256_list": sha256_list, + # "raw_data": raw_data + # }, f, ensure_ascii=False, indent=2) + # 上面那些是AI写的,我看了下旧版生成的文件格式,实际上只保存[raw_data]就行,后面的info_extraction.py和import_openie.py会正常运行 + # 下面这段只保存 raw_data,和旧版保存的文件格式一致 + with open(output_path, "w", encoding="utf-8") as f: + json.dump(raw_data, f, ensure_ascii=False, indent=2) + + logger.info(f"数据处理完成,已写入:{output_path}") \ No newline at end of file diff --git a/scripts/raw_data_preprocessor_old.py b/scripts/raw_data_preprocessor_old.py new file mode 100644 index 00000000..1fa80d3f --- /dev/null +++ b/scripts/raw_data_preprocessor_old.py @@ -0,0 +1,108 @@ +import json +import os +from pathlib import Path +import sys # 新增系统模块导入 +import datetime # 新增导入 + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from src.common.logger import get_logger +from src.chat.knowledge.lpmmconfig import global_config + +logger = get_logger("lpmm") +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") +# 新增:确保 RAW_DATA_PATH 存在 +if not os.path.exists(RAW_DATA_PATH): + os.makedirs(RAW_DATA_PATH, exist_ok=True) + logger.info(f"已创建目录: {RAW_DATA_PATH}") + +if global_config.get("persistence", {}).get("raw_data_path") is not None: + IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"]) +else: + IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") + +# 添加项目根目录到 sys.path + + +def check_and_create_dirs(): + """检查并创建必要的目录""" + required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH] + + for dir_path in required_dirs: + if not os.path.exists(dir_path): + os.makedirs(dir_path) + logger.info(f"已创建目录: {dir_path}") + + +def process_text_file(file_path): + """处理单个文本文件,返回段落列表""" + with open(file_path, "r", encoding="utf-8") as f: + raw = f.read() + + paragraphs = [] + paragraph = "" + for line in raw.split("\n"): + if line.strip() == "": + if paragraph != "": + paragraphs.append(paragraph.strip()) + paragraph = "" + else: + paragraph += line + "\n" + + if paragraph != "": + paragraphs.append(paragraph.strip()) + + return paragraphs + + +def main(): + # 新增用户确认提示 + print("=== 数据预处理脚本 ===") + print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。") + print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。") + print("请确保原始数据已放置在正确的目录中。") + confirm = input("确认继续执行?(y/n): ").strip().lower() + if confirm != "y": + logger.info("操作已取消") + sys.exit(1) + print("\n" + "=" * 40 + "\n") + + # 检查并创建必要的目录 + check_and_create_dirs() + + # # 检查输出文件是否存在 + # if os.path.exists(RAW_DATA_PATH): + # logger.error("错误: data/import.json 已存在,请先处理或删除该文件") + # sys.exit(1) + + # if os.path.exists(RAW_DATA_PATH): + # logger.error("错误: data/openie.json 已存在,请先处理或删除该文件") + # sys.exit(1) + + # 获取所有原始文本文件 + raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) + if not raw_files: + logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") + sys.exit(1) + + # 处理所有文件 + all_paragraphs = [] + for file in raw_files: + logger.info(f"正在处理文件: {file.name}") + paragraphs = process_text_file(file) + all_paragraphs.extend(paragraphs) + + # 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json + now = datetime.datetime.now() + filename = now.strftime("%m-%d-%H-%S-imported-data.json") + output_path = os.path.join(IMPORTED_DATA_PATH, filename) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) + + logger.info(f"处理完成,结果已保存到: {output_path}") + + +if __name__ == "__main__": + logger.info(f"原始数据路径: {RAW_DATA_PATH}") + logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}") + main() \ No newline at end of file diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index d732683a..784e59a5 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -29,7 +29,7 @@ from rich.progress import ( from src.manager.local_store_manager import local_storage from src.chat.utils.utils import get_embedding from src.config.config import global_config - +from tqdm import tqdm install(extra_lines=3) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) @@ -317,18 +317,31 @@ class EmbeddingStore: class EmbeddingManager: def __init__(self): + # self.paragraphs_embedding_store = EmbeddingStore( + # local_storage["pg_namespace"] + "-paragraph", # 不再共用一个命名空间 + # EMBEDDING_DATA_DIR_STR, + # ) + # self.entities_embedding_store = EmbeddingStore( + # local_storage["pg_namespace"] + "-entity", + # EMBEDDING_DATA_DIR_STR, + # ) + # self.relation_embedding_store = EmbeddingStore( + # local_storage["pg_namespace"] + "-relation", + # EMBEDDING_DATA_DIR_STR, + # ) self.paragraphs_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "paragraph", # 不再共用一个命名空间 EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "entity", EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "relation", EMBEDDING_DATA_DIR_STR, ) + self.stored_pg_hashes = set() def check_all_embedding_model_consistency(self): @@ -395,3 +408,34 @@ class EmbeddingManager: self.paragraphs_embedding_store.build_faiss_index() self.entities_embedding_store.build_faiss_index() self.relation_embedding_store.build_faiss_index() + + def store_entity_data_set(self, entity_hash_map: dict[str, str]) -> None: + """为每个实体生成嵌入并保存,entity_hash_map: {hash: text}""" + texts = list(entity_hash_map.values()) + hashes = list(entity_hash_map.keys()) + + logger.info(f"正在生成 {len(texts)} 个实体的嵌入向量...") + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task("嵌入实体", total=len(texts)) + for hash_id, text in zip(hashes, texts): + item_hash = self.entities_embedding_store.namespace + "-entity-" + hash_id + if item_hash in self.entities_embedding_store.store: + progress.update(task, advance=1) + continue + embedding = self.entities_embedding_store._get_embedding(text) + self.entities_embedding_store.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, text) + progress.update(task, advance=1) + + logger.info(f"成功嵌入实体数:{len(self.entities_embedding_store.store)}") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 9aca329e..918e2232 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -16,6 +16,8 @@ from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模 from src.config.config import global_config from src.common.tcp_connector import get_tcp_connector from rich.traceback import install +from dotenv import load_dotenv + install(extra_lines=3) @@ -121,8 +123,10 @@ class LLMRequest: try: # print(f"model['provider']: {model['provider']}") - self.api_key = os.environ[f"{model['provider']}_KEY"] - self.base_url = os.environ[f"{model['provider']}_BASE_URL"] + # self.api_key = os.environ[f"{model['provider']}_KEY"] # 这是原来的写法,没法读取变量 + # self.base_url = os.environ[f"{model['provider']}_BASE_URL"] + self.api_key = os.environ.get(f"{model['provider']}_KEY") # 改成这种写法,能读取变量 + self.base_url = os.environ.get(f"{model['provider']}_BASE_URL") logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL") except AttributeError as e: logger.error(f"原始 model dict 信息:{model}")