From edda83453886ab2389af3d66e1aee31c2a0b319f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 2 May 2025 12:01:05 +0800 Subject: [PATCH 01/11] =?UTF-8?q?fix:=20=E7=A7=BB=E9=99=A4Traceback?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E4=B8=AD=E7=9A=84=20show=5Flocals=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E4=BB=A5=E7=AE=80=E5=8C=96=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E8=BF=BD=E8=B8=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 2 +- src/common/database.py | 2 +- src/common/log_decorators.py | 2 +- src/common/server.py | 2 +- src/config/config.py | 2 +- src/do_tool/tool_can_use/base_tool.py | 2 +- src/individuality/individuality.py | 2 +- src/individuality/offline_llm.py | 2 +- src/main.py | 2 +- src/plugins/PFC/chat_observer.py | 2 +- src/plugins/PFC/conversation.py | 2 +- src/plugins/PFC/message_sender.py | 2 +- src/plugins/PFC/pfc.py | 2 +- src/plugins/chat/chat_stream.py | 2 +- src/plugins/chat/message.py | 2 +- src/plugins/chat/message_sender.py | 2 +- src/plugins/chat/utils_image.py | 2 +- src/plugins/config_reload/api.py | 2 +- src/plugins/emoji_system/emoji_manager.py | 2 +- src/plugins/heartFC_chat/heartFC_chat.py | 2 +- src/plugins/heartFC_chat/heartFC_sender.py | 2 +- src/plugins/knowledge/src/embedding_store.py | 2 +- src/plugins/memory_system/Hippocampus.py | 2 +- src/plugins/memory_system/debug_memory.py | 2 +- src/plugins/memory_system/manually_alter_memory.py | 2 +- src/plugins/memory_system/offline_llm.py | 2 +- src/plugins/memory_system/sample_distribution.py | 2 +- src/plugins/models/utils_model.py | 2 +- src/plugins/utils/prompt_builder.py | 2 +- src/plugins/utils/timer_calculator.py | 2 +- src/plugins/willing/willing_manager.py | 2 +- src/plugins/zhishi/knowledge_library.py | 2 +- 32 files changed, 32 insertions(+), 32 deletions(-) diff --git a/bot.py b/bot.py index 5d811d4e..41847a01 100644 --- a/bot.py +++ b/bot.py @@ -15,7 +15,7 @@ from src.common.crash_logger import install_crash_handler from src.main import MainSystem from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("main") diff --git a/src/common/database.py b/src/common/database.py index 66a2dc16..752f746d 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -3,7 +3,7 @@ from pymongo import MongoClient from pymongo.database import Database from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) _client = None _db = None diff --git a/src/common/log_decorators.py b/src/common/log_decorators.py index a57fae79..414ba923 100644 --- a/src/common/log_decorators.py +++ b/src/common/log_decorators.py @@ -4,7 +4,7 @@ from typing import Callable, Any from .logger import logger, add_custom_style_handler from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) def use_log_style( diff --git a/src/common/server.py b/src/common/server.py index c080e28a..ff6106a7 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -4,7 +4,7 @@ from uvicorn import Config, Server as UvicornServer import os from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) class Server: diff --git a/src/config/config.py b/src/config/config.py index a067633b..28d947ef 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -16,7 +16,7 @@ from packaging.specifiers import SpecifierSet, InvalidSpecifier from src.common.logger_manager import get_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) # 配置主程序日志格式 diff --git a/src/do_tool/tool_can_use/base_tool.py b/src/do_tool/tool_can_use/base_tool.py index 88da036d..b0f04ffe 100644 --- a/src/do_tool/tool_can_use/base_tool.py +++ b/src/do_tool/tool_can_use/base_tool.py @@ -6,7 +6,7 @@ import os from src.common.logger_manager import get_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("base_tool") diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 963fae0e..38131ea1 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -4,7 +4,7 @@ from .identity import Identity import random from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) class Individuality: diff --git a/src/individuality/offline_llm.py b/src/individuality/offline_llm.py index 0e1a446c..cc956001 100644 --- a/src/individuality/offline_llm.py +++ b/src/individuality/offline_llm.py @@ -8,7 +8,7 @@ import requests from src.common.logger import get_module_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_module_logger("offline_llm") diff --git a/src/main.py b/src/main.py index 3de3e880..26a56ca2 100644 --- a/src/main.py +++ b/src/main.py @@ -19,7 +19,7 @@ from .individuality.individuality import Individuality from .common.server import global_server from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("main") diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index 34b66316..22cbf27d 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -9,7 +9,7 @@ from .chat_states import NotificationManager, create_new_message_notification, c from .message_storage import MongoDBMessageStorage from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_module_logger("chat_observer") diff --git a/src/plugins/PFC/conversation.py b/src/plugins/PFC/conversation.py index 925fd7b5..0bc4cae8 100644 --- a/src/plugins/PFC/conversation.py +++ b/src/plugins/PFC/conversation.py @@ -25,7 +25,7 @@ from .waiter import Waiter import traceback from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("pfc") diff --git a/src/plugins/PFC/message_sender.py b/src/plugins/PFC/message_sender.py index f1085768..12c2143e 100644 --- a/src/plugins/PFC/message_sender.py +++ b/src/plugins/PFC/message_sender.py @@ -10,7 +10,7 @@ from ..storage.storage import MessageStorage from ...config.config import global_config from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_module_logger("message_sender") diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py index 50f7bf4c..b17ee21d 100644 --- a/src/plugins/PFC/pfc.py +++ b/src/plugins/PFC/pfc.py @@ -10,7 +10,7 @@ from .observation_info import ObservationInfo from src.plugins.utils.chat_message_builder import build_readable_messages from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) if TYPE_CHECKING: pass diff --git a/src/plugins/chat/chat_stream.py b/src/plugins/chat/chat_stream.py index a949247c..53ebd502 100644 --- a/src/plugins/chat/chat_stream.py +++ b/src/plugins/chat/chat_stream.py @@ -11,7 +11,7 @@ from maim_message import GroupInfo, UserInfo from src.common.logger_manager import get_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("chat_stream") diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 354082e1..b9c15288 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -11,7 +11,7 @@ from .utils_image import image_manager from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("chat_message") diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index 8bfee44b..b65ae895 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -15,7 +15,7 @@ from .utils import truncate_message, calculate_typing_time, count_messages_betwe from src.common.logger_manager import get_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("sender") diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 1f734502..5508ad23 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -15,7 +15,7 @@ from ..models.utils_model import LLMRequest from src.common.logger_manager import get_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("chat_image") diff --git a/src/plugins/config_reload/api.py b/src/plugins/config_reload/api.py index ee0a5454..56240b88 100644 --- a/src/plugins/config_reload/api.py +++ b/src/plugins/config_reload/api.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, HTTPException from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) # 创建APIRouter而不是FastAPI实例 router = APIRouter() diff --git a/src/plugins/emoji_system/emoji_manager.py b/src/plugins/emoji_system/emoji_manager.py index 24266c08..86dab9d9 100644 --- a/src/plugins/emoji_system/emoji_manager.py +++ b/src/plugins/emoji_system/emoji_manager.py @@ -17,7 +17,7 @@ from ..models.utils_model import LLMRequest from src.common.logger_manager import get_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("emoji") diff --git a/src/plugins/heartFC_chat/heartFC_chat.py b/src/plugins/heartFC_chat/heartFC_chat.py index 712b6af5..28c17d9a 100644 --- a/src/plugins/heartFC_chat/heartFC_chat.py +++ b/src/plugins/heartFC_chat/heartFC_chat.py @@ -29,7 +29,7 @@ from src.plugins.moods.moods import MoodManager from src.heart_flow.utils_chat import get_chat_type_and_target_info from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒 diff --git a/src/plugins/heartFC_chat/heartFC_sender.py b/src/plugins/heartFC_chat/heartFC_sender.py index 6fab5d62..b193ae44 100644 --- a/src/plugins/heartFC_chat/heartFC_sender.py +++ b/src/plugins/heartFC_chat/heartFC_sender.py @@ -11,7 +11,7 @@ from src.common.logger_manager import get_logger from src.plugins.chat.utils import calculate_typing_time from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_logger("sender") diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 72c6c7b5..8e0d116b 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -14,7 +14,7 @@ from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) @dataclass diff --git a/src/plugins/memory_system/Hippocampus.py b/src/plugins/memory_system/Hippocampus.py index 11ba8f40..24d320f7 100644 --- a/src/plugins/memory_system/Hippocampus.py +++ b/src/plugins/memory_system/Hippocampus.py @@ -22,7 +22,7 @@ from ..chat.utils import translate_timestamp_to_human_readable from .memory_config import MemoryConfig from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) def calculate_information_content(text): diff --git a/src/plugins/memory_system/debug_memory.py b/src/plugins/memory_system/debug_memory.py index ae767c85..8f79c6a8 100644 --- a/src/plugins/memory_system/debug_memory.py +++ b/src/plugins/memory_system/debug_memory.py @@ -10,7 +10,7 @@ from src.plugins.memory_system.Hippocampus import HippocampusManager from src.config.config import global_config from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) async def test_memory_system(): diff --git a/src/plugins/memory_system/manually_alter_memory.py b/src/plugins/memory_system/manually_alter_memory.py index 10a75738..ce5abbba 100644 --- a/src/plugins/memory_system/manually_alter_memory.py +++ b/src/plugins/memory_system/manually_alter_memory.py @@ -11,7 +11,7 @@ from Hippocampus import Hippocampus # 海马体和记忆图 from dotenv import load_dotenv from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) """ diff --git a/src/plugins/memory_system/offline_llm.py b/src/plugins/memory_system/offline_llm.py index 335a76d3..d4862ad3 100644 --- a/src/plugins/memory_system/offline_llm.py +++ b/src/plugins/memory_system/offline_llm.py @@ -8,7 +8,7 @@ import requests from src.common.logger import get_module_logger from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_module_logger("offline_llm") diff --git a/src/plugins/memory_system/sample_distribution.py b/src/plugins/memory_system/sample_distribution.py index 76796728..b3b84eb4 100644 --- a/src/plugins/memory_system/sample_distribution.py +++ b/src/plugins/memory_system/sample_distribution.py @@ -3,7 +3,7 @@ from scipy import stats from datetime import datetime, timedelta from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) class DistributionVisualizer: diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 7c7fe713..8ee21956 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -16,7 +16,7 @@ from ...common.database import db from ...config.config import global_config from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_module_logger("model_utils") diff --git a/src/plugins/utils/prompt_builder.py b/src/plugins/utils/prompt_builder.py index c4555a55..4a226a02 100644 --- a/src/plugins/utils/prompt_builder.py +++ b/src/plugins/utils/prompt_builder.py @@ -7,7 +7,7 @@ from src.common.logger import get_module_logger # import traceback from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) logger = get_module_logger("prompt_build") diff --git a/src/plugins/utils/timer_calculator.py b/src/plugins/utils/timer_calculator.py index d66f21cc..af8058a5 100644 --- a/src/plugins/utils/timer_calculator.py +++ b/src/plugins/utils/timer_calculator.py @@ -4,7 +4,7 @@ from typing import Optional, Dict, Callable import asyncio from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) """ # 更好的计时器 diff --git a/src/plugins/willing/willing_manager.py b/src/plugins/willing/willing_manager.py index a5884da2..ba1e3e09 100644 --- a/src/plugins/willing/willing_manager.py +++ b/src/plugins/willing/willing_manager.py @@ -10,7 +10,7 @@ from typing import Dict, Optional import asyncio from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) """ 基类方法概览: diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py index 26af3bda..6fa1d3e1 100644 --- a/src/plugins/zhishi/knowledge_library.py +++ b/src/plugins/zhishi/knowledge_library.py @@ -9,7 +9,7 @@ from rich.console import Console from rich.table import Table from rich.traceback import install -install(show_locals=True, extra_lines=3) +install(extra_lines=3) # 添加项目根目录到 Python 路径 root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) From 03961b71a259f52fb79e6b47565c6fed1db00fe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 2 May 2025 13:42:28 +0800 Subject: [PATCH 02/11] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E8=B7=AF=E5=BE=84=E9=85=8D=E7=BD=AE=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA=E6=95=B0=E6=8D=AE=E5=A4=84=E7=90=86=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96=E9=94=99=E8=AF=AF=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 6 +- scripts/info_extraction.py | 103 +++++++++++++------ scripts/raw_data_preprocessor.py | 30 ++++-- src/plugins/knowledge/knowledge_lib.py | 2 + src/plugins/knowledge/src/embedding_store.py | 45 +++++--- src/plugins/knowledge/src/kg_manager.py | 79 ++++++++------ src/plugins/knowledge/src/open_ie.py | 48 +++++++-- src/plugins/knowledge/src/raw_processing.py | 18 ++-- template/lpmm_config_template.toml | 4 +- 9 files changed, 226 insertions(+), 109 deletions(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 595f22ec..25a1a877 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -19,7 +19,8 @@ from src.plugins.knowledge.src.utils.hash import get_sha256 # 添加项目根目录到 sys.path - +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +OPENIE_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie") logger = get_module_logger("LPMM知识库-OpenIE导入") @@ -131,6 +132,7 @@ def main(): embed_manager.load_from_file() except Exception as e: logger.error("从文件加载Embedding库时发生错误:{}".format(e)) + logger.error("如果你是第一次导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG kg_manager = KGManager() @@ -139,6 +141,7 @@ def main(): kg_manager.load_from_file() except Exception as e: logger.error("从文件加载KG时发生错误:{}".format(e)) + logger.error("如果你是第一次导入知识,请忽略此错误") logger.info("KG加载完成") logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") @@ -163,4 +166,5 @@ def main(): if __name__ == "__main__": + # logger.info(f"111111111111111111111111{ROOT_PATH}") main() diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index 65c4082b..00f7a2a2 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -4,11 +4,13 @@ import signal from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock, Event import sys +import glob +import datetime sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) # 添加项目根目录到 sys.path -import tqdm +from rich.progress import Progress # 替换为 rich 进度条 from src.common.logger import get_module_logger from src.plugins.knowledge.src.lpmmconfig import global_config @@ -16,10 +18,15 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str from src.plugins.knowledge.src.llm_client import LLMClient from src.plugins.knowledge.src.open_ie import OpenIE from src.plugins.knowledge.src.raw_processing import load_raw_data +from rich.progress import BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn logger = get_module_logger("LPMM知识库-信息提取") -TEMP_DIR = "./temp" + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +TEMP_DIR = os.path.join(ROOT_PATH, "temp") +IMPORTED_DATA_PATH = global_config["persistence"]["raw_data_path"] if global_config["persistence"]["raw_data_path"] else os.path.join(ROOT_PATH, "data/imported_lpmm_data") +OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie") # 创建一个线程安全的锁,用于保护文件操作和共享数据 file_lock = Lock() @@ -70,8 +77,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list): # 如果保存失败,确保不会留下损坏的文件 if os.path.exists(temp_file_path): os.remove(temp_file_path) - # 设置shutdown_event以终止程序 - shutdown_event.set() + sys.exit(0) return None, pg_hash return doc_item, None @@ -79,7 +85,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list): def signal_handler(_signum, _frame): """处理Ctrl+C信号""" logger.info("\n接收到中断信号,正在优雅地关闭程序...") - shutdown_event.set() + sys.exit(0) def main(): @@ -110,33 +116,61 @@ def main(): global_config["llm_providers"][key]["api_key"], ) - logger.info("正在加载原始数据") - sha256_list, raw_datas = load_raw_data() - logger.info("原始数据加载完成\n") + # 检查 openie 输出目录 + if not os.path.exists(OPENIE_OUTPUT_DIR): + os.makedirs(OPENIE_OUTPUT_DIR) + logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}") - # 创建临时目录 - if not os.path.exists(f"{TEMP_DIR}"): - os.makedirs(f"{TEMP_DIR}") + # 确保 TEMP_DIR 目录存在 + if not os.path.exists(TEMP_DIR): + os.makedirs(TEMP_DIR) + logger.info(f"已创建缓存目录: {TEMP_DIR}") + + # 遍历IMPORTED_DATA_PATH下所有json文件 + imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json"))) + if not imported_files: + logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件") + sys.exit(1) + + all_sha256_list = [] + all_raw_datas = [] + + for imported_file in imported_files: + logger.info(f"正在处理文件: {imported_file}") + try: + sha256_list, raw_datas = load_raw_data(imported_file) + except Exception as e: + logger.error(f"读取文件失败: {imported_file}, 错误: {e}") + continue + all_sha256_list.extend(sha256_list) + all_raw_datas.extend(raw_datas) failed_sha256 = [] open_ie_doc = [] - # 创建线程池,最大线程数为50 workers = global_config["info_extraction"]["workers"] with ThreadPoolExecutor(max_workers=workers) as executor: - # 提交所有任务到线程池 future_to_hash = { executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash - for pg_hash, raw_data in zip(sha256_list, raw_datas) + for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas) } - # 使用tqdm显示进度 - with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar: - # 处理完成的任务 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task("正在进行提取:", total=len(future_to_hash)) try: for future in as_completed(future_to_hash): if shutdown_event.is_set(): - # 取消所有未完成的任务 for f in future_to_hash: if not f.done(): f.cancel() @@ -149,26 +183,33 @@ def main(): elif doc_item: with open_ie_doc_lock: open_ie_doc.append(doc_item) - pbar.update(1) + progress.update(task, advance=1) except KeyboardInterrupt: - # 如果在这里捕获到KeyboardInterrupt,说明signal_handler可能没有正常工作 logger.info("\n接收到中断信号,正在优雅地关闭程序...") shutdown_event.set() - # 取消所有未完成的任务 for f in future_to_hash: if not f.done(): f.cancel() - # 保存信息提取结果 - sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) - sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) - num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc]) - openie_obj = OpenIE( - open_ie_doc, - round(sum_phrase_chars / num_phrases, 4), - round(sum_phrase_words / num_phrases, 4), - ) - OpenIE.save(openie_obj) + # 合并所有文件的提取结果并保存 + if open_ie_doc: + sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc]) + openie_obj = OpenIE( + open_ie_doc, + round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0, + round(sum_phrase_words / num_phrases, 4) if num_phrases else 0, + ) + # 输出文件名格式:MM-DD-HH-ss-openie.json + now = datetime.datetime.now() + filename = now.strftime("%m-%d-%H-%S-openie.json") + output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, f, ensure_ascii=False, indent=4) + logger.info(f"信息提取结果已保存到: {output_path}") + else: + logger.warning("没有可保存的信息提取结果") logger.info("--------信息提取完成--------") logger.info(f"提取失败的文段SHA256:{failed_sha256}") diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index 2fc30352..d808fb0e 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -2,18 +2,22 @@ import json import os from pathlib import Path import sys # 新增系统模块导入 +import datetime # 新增导入 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import get_module_logger logger = get_module_logger("LPMM数据库-原始数据处理") +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") +IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") # 添加项目根目录到 sys.path def check_and_create_dirs(): """检查并创建必要的目录""" - required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"] + required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH] for dir_path in required_dirs: if not os.path.exists(dir_path): @@ -58,17 +62,17 @@ def main(): # 检查并创建必要的目录 check_and_create_dirs() - # 检查输出文件是否存在 - if os.path.exists("data/import.json"): - logger.error("错误: data/import.json 已存在,请先处理或删除该文件") - sys.exit(1) + # # 检查输出文件是否存在 + # if os.path.exists(RAW_DATA_PATH): + # logger.error("错误: data/import.json 已存在,请先处理或删除该文件") + # sys.exit(1) - if os.path.exists("data/openie.json"): - logger.error("错误: data/openie.json 已存在,请先处理或删除该文件") - sys.exit(1) + # if os.path.exists(RAW_DATA_PATH): + # logger.error("错误: data/openie.json 已存在,请先处理或删除该文件") + # sys.exit(1) # 获取所有原始文本文件 - raw_files = list(Path("data/lpmm_raw_data").glob("*.txt")) + raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) if not raw_files: logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") sys.exit(1) @@ -80,8 +84,10 @@ def main(): paragraphs = process_text_file(file) all_paragraphs.extend(paragraphs) - # 保存合并后的结果 - output_path = "data/import.json" + # 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json + now = datetime.datetime.now() + filename = now.strftime("%m-%d-%H-%S-imported-data.json") + output_path = os.path.join(IMPORTED_DATA_PATH, filename) with open(output_path, "w", encoding="utf-8") as f: json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) @@ -89,4 +95,6 @@ def main(): if __name__ == "__main__": + print(f"Raw Data Path: {RAW_DATA_PATH}") + print(f"Imported Data Path: {IMPORTED_DATA_PATH}") main() diff --git a/src/plugins/knowledge/knowledge_lib.py b/src/plugins/knowledge/knowledge_lib.py index c0d2fe61..df82970a 100644 --- a/src/plugins/knowledge/knowledge_lib.py +++ b/src/plugins/knowledge/knowledge_lib.py @@ -26,6 +26,7 @@ try: embed_manager.load_from_file() except Exception as e: logger.error("从文件加载Embedding库时发生错误:{}".format(e)) + logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG kg_manager = KGManager() @@ -34,6 +35,7 @@ try: kg_manager.load_from_file() except Exception as e: logger.error("从文件加载KG时发生错误:{}".format(e)) + logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("KG加载完成") logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 8e0d116b..d2791ca4 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -13,9 +13,11 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install +from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn install(extra_lines=3) +TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 @dataclass class EmbeddingStoreItem: @@ -52,20 +54,35 @@ class EmbeddingStore: def _get_embedding(self, s: str) -> List[float]: return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) - def batch_insert_strs(self, strs: List[str]) -> None: + def batch_insert_strs(self, strs: List[str], times: int) -> None: """向库中存入字符串""" - # 逐项处理 - for s in tqdm.tqdm(strs, desc="存入嵌入库", unit="items"): - # 计算hash去重 - item_hash = self.namespace + "-" + get_sha256(s) - if item_hash in self.store: - continue + total = len(strs) + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) + for s in strs: + # 计算hash去重 + item_hash = self.namespace + "-" + get_sha256(s) + if item_hash in self.store: + progress.update(task, advance=1) + continue - # 获取embedding - embedding = self._get_embedding(s) + # 获取embedding + embedding = self._get_embedding(s) - # 存入 - self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) + # 存入 + self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) + progress.update(task, advance=1) def save_to_file(self) -> None: """保存到文件""" @@ -191,7 +208,7 @@ class EmbeddingManager: def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): """将段落编码存入Embedding库""" - self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values())) + self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()),times=1) def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): """将实体编码存入Embedding库""" @@ -200,7 +217,7 @@ class EmbeddingManager: for triple in triple_list: entities.add(triple[0]) entities.add(triple[2]) - self.entities_embedding_store.batch_insert_strs(list(entities)) + self.entities_embedding_store.batch_insert_strs(list(entities),times=2) def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): """将关系编码存入Embedding库""" @@ -208,7 +225,7 @@ class EmbeddingManager: for triples in triple_list_data.values(): graph_triples.extend([tuple(t) for t in triples]) graph_triples = list(set(graph_triples)) - self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples]) + self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples],times=3) def load_from_file(self): """从文件加载""" diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index 71ce65ef..ccaf7aa8 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -5,7 +5,7 @@ from typing import Dict, List, Tuple import numpy as np import pandas as pd -import tqdm +from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn from quick_algo import di_graph, pagerank @@ -132,41 +132,56 @@ class KGManager: ent_hash_list = list(ent_hash_list) synonym_hash_set = set() - synonym_result = dict() - # 对每个实体节点,查找其相似的实体节点,建立扩展连接 - for ent_hash in tqdm.tqdm(ent_hash_list): - if ent_hash in synonym_hash_set: - # 避免同一批次内重复添加 - continue - ent = embedding_manager.entities_embedding_store.store.get(ent_hash) - assert isinstance(ent, EmbeddingStoreItem) - if ent is None: - continue - # 查询相似实体 - similar_ents = embedding_manager.entities_embedding_store.search_top_k( - ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"] - ) - res_ent = [] # Debug - for res_ent_hash, similarity in similar_ents: - if res_ent_hash == ent_hash: - # 避免自连接 + # rich 进度条 + total = len(ent_hash_list) + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task("同义词连接", total=total) + for ent_hash in ent_hash_list: + if ent_hash in synonym_hash_set: + progress.update(task, advance=1) continue - if similarity < global_config["rag"]["params"]["synonym_threshold"]: - # 相似度阈值 + ent = embedding_manager.entities_embedding_store.store.get(ent_hash) + assert isinstance(ent, EmbeddingStoreItem) + if ent is None: + progress.update(task, advance=1) continue - node_to_node[(res_ent_hash, ent_hash)] = similarity - node_to_node[(ent_hash, res_ent_hash)] = similarity - synonym_hash_set.add(res_ent_hash) - new_edge_cnt += 1 - res_ent.append( - ( - embedding_manager.entities_embedding_store.store[res_ent_hash].str, - similarity, - ) - ) # Debug - synonym_result[ent.str] = res_ent + # 查询相似实体 + similar_ents = embedding_manager.entities_embedding_store.search_top_k( + ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"] + ) + res_ent = [] # Debug + for res_ent_hash, similarity in similar_ents: + if res_ent_hash == ent_hash: + # 避免自连接 + continue + if similarity < global_config["rag"]["params"]["synonym_threshold"]: + # 相似度阈值 + continue + node_to_node[(res_ent_hash, ent_hash)] = similarity + node_to_node[(ent_hash, res_ent_hash)] = similarity + synonym_hash_set.add(res_ent_hash) + new_edge_cnt += 1 + res_ent.append( + ( + embedding_manager.entities_embedding_store.store[res_ent_hash].str, + similarity, + ) + ) # Debug + synonym_result[ent.str] = res_ent + progress.update(task, advance=1) for k, v in synonym_result.items(): print(f'"{k}"的相似实体为:{v}') diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py index 5fe163bb..ea84af4a 100644 --- a/src/plugins/knowledge/src/open_ie.py +++ b/src/plugins/knowledge/src/open_ie.py @@ -1,9 +1,13 @@ import json +import os +import glob from typing import Any, Dict, List from .lpmmconfig import INVALID_ENTITY, global_config +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + def _filter_invalid_entities(entities: List[str]) -> List[str]: """过滤无效的实体""" @@ -74,12 +78,22 @@ class OpenIE: doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"]) @staticmethod - def _from_dict(data): - """从字典中获取OpenIE对象""" + def _from_dict(data_list): + """从多个字典合并OpenIE对象""" + # data_list: List[dict] + all_docs = [] + for data in data_list: + all_docs.extend(data.get("docs", [])) + # 重新计算统计 + sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]]) + sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]]) + num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs]) + avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0 + avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0 return OpenIE( - docs=data["docs"], - avg_ent_chars=data["avg_ent_chars"], - avg_ent_words=data["avg_ent_words"], + docs=all_docs, + avg_ent_chars=avg_ent_chars, + avg_ent_words=avg_ent_words, ) def _to_dict(self): @@ -92,12 +106,20 @@ class OpenIE: @staticmethod def load() -> "OpenIE": - """从文件中加载OpenIE数据""" - with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f: - data = json.loads(f.read()) - - openie_data = OpenIE._from_dict(data) - + """从OPENIE_DIR下所有json文件合并加载OpenIE数据""" + openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"]) + if not os.path.exists(openie_dir): + raise Exception(f"OpenIE数据目录不存在: {openie_dir}") + json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json"))) + data_list = [] + for file in json_files: + with open(file, "r", encoding="utf-8") as f: + data = json.load(f) + data_list.append(data) + if not data_list: + # print(f"111111111111111111111Root Path : \n{ROOT_PATH}") + raise Exception(f"未在 {openie_dir} 找到任何OpenIE json文件") + openie_data = OpenIE._from_dict(data_list) return openie_data @staticmethod @@ -132,3 +154,7 @@ class OpenIE: """提取原始段落""" raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}) return raw_paragraph_dict + +if __name__ == "__main__": + # 测试代码 + print(ROOT_PATH) diff --git a/src/plugins/knowledge/src/raw_processing.py b/src/plugins/knowledge/src/raw_processing.py index 91e681c7..a333ef99 100644 --- a/src/plugins/knowledge/src/raw_processing.py +++ b/src/plugins/knowledge/src/raw_processing.py @@ -6,21 +6,25 @@ from .lpmmconfig import global_config from .utils.hash import get_sha256 -def load_raw_data() -> tuple[list[str], list[str]]: +def load_raw_data(path: str = None) -> tuple[list[str], list[str]]: """加载原始数据文件 读取原始数据文件,将原始数据加载到内存中 + Args: + path: 可选,指定要读取的json文件绝对路径 + Returns: - - raw_data: 原始数据字典 - - md5_set: 原始数据的SHA256集合 + - raw_data: 原始数据列表 + - sha256_list: 原始数据的SHA256集合 """ - # 读取import.json文件 - if os.path.exists(global_config["persistence"]["raw_data_path"]) is True: - with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f: + # 读取指定路径或默认路径的json文件 + json_path = path if path else global_config["persistence"]["raw_data_path"] + if os.path.exists(json_path): + with open(json_path, "r", encoding="utf-8") as f: import_json = json.loads(f.read()) else: - raise Exception("原始数据文件读取失败") + raise Exception(f"原始数据文件读取失败: {json_path}") # import_json内容示例: # import_json = [ # "The capital of China is Beijing. The capital of France is Paris.", diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml index 43785e79..8563b7ca 100644 --- a/template/lpmm_config_template.toml +++ b/template/lpmm_config_template.toml @@ -51,7 +51,7 @@ res_top_k = 3 # 最终提供的文段TopK [persistence] # 持久化配置(存储中间数据,防止重复计算) data_root_path = "data" # 数据根目录 -raw_data_path = "data/import.json" # 原始数据路径 -openie_data_path = "data/openie.json" # OpenIE数据路径 +raw_data_path = "data/imported_lpmm_data" # 原始数据路径 +openie_data_path = "data/openie" # OpenIE数据路径 embedding_data_dir = "data/embedding" # 嵌入数据目录 rag_data_dir = "data/rag" # RAG数据目录 From b117e876870e0743b06efa4000b6aba01f0a0588 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 2 May 2025 05:42:41 +0000 Subject: [PATCH 03/11] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 6 +++- scripts/info_extraction.py | 29 +++++++++++++++++--- src/plugins/knowledge/src/embedding_store.py | 18 +++++++++--- src/plugins/knowledge/src/kg_manager.py | 11 +++++++- src/plugins/knowledge/src/open_ie.py | 3 +- 5 files changed, 56 insertions(+), 11 deletions(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 25a1a877..dd4b50ec 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -20,7 +20,11 @@ from src.plugins.knowledge.src.utils.hash import get_sha256 # 添加项目根目录到 sys.path ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -OPENIE_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie") +OPENIE_DIR = ( + global_config["persistence"]["openie_data_path"] + if global_config["persistence"]["openie_data_path"] + else os.path.join(ROOT_PATH, "data/openie") +) logger = get_module_logger("LPMM知识库-OpenIE导入") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index 00f7a2a2..ee0d789a 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -18,15 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str from src.plugins.knowledge.src.llm_client import LLMClient from src.plugins.knowledge.src.open_ie import OpenIE from src.plugins.knowledge.src.raw_processing import load_raw_data -from rich.progress import BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn +from rich.progress import ( + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) logger = get_module_logger("LPMM知识库-信息提取") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") -IMPORTED_DATA_PATH = global_config["persistence"]["raw_data_path"] if global_config["persistence"]["raw_data_path"] else os.path.join(ROOT_PATH, "data/imported_lpmm_data") -OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie") +IMPORTED_DATA_PATH = ( + global_config["persistence"]["raw_data_path"] + if global_config["persistence"]["raw_data_path"] + else os.path.join(ROOT_PATH, "data/imported_lpmm_data") +) +OPENIE_OUTPUT_DIR = ( + global_config["persistence"]["openie_data_path"] + if global_config["persistence"]["openie_data_path"] + else os.path.join(ROOT_PATH, "data/openie") +) # 创建一个线程安全的锁,用于保护文件操作和共享数据 file_lock = Lock() @@ -206,7 +222,12 @@ def main(): filename = now.strftime("%m-%d-%H-%S-openie.json") output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) with open(output_path, "w", encoding="utf-8") as f: - json.dump(openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, f, ensure_ascii=False, indent=4) + json.dump( + openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, + f, + ensure_ascii=False, + indent=4, + ) logger.info(f"信息提取结果已保存到: {output_path}") else: logger.warning("没有可保存的信息提取结果") diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index d2791ca4..e734f4e9 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -13,12 +13,22 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install -from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn +from rich.progress import ( + Progress, + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) install(extra_lines=3) TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 + @dataclass class EmbeddingStoreItem: """嵌入库中的项""" @@ -208,7 +218,7 @@ class EmbeddingManager: def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): """将段落编码存入Embedding库""" - self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()),times=1) + self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): """将实体编码存入Embedding库""" @@ -217,7 +227,7 @@ class EmbeddingManager: for triple in triple_list: entities.add(triple[0]) entities.add(triple[2]) - self.entities_embedding_store.batch_insert_strs(list(entities),times=2) + self.entities_embedding_store.batch_insert_strs(list(entities), times=2) def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): """将关系编码存入Embedding库""" @@ -225,7 +235,7 @@ class EmbeddingManager: for triples in triple_list_data.values(): graph_triples.extend([tuple(t) for t in triples]) graph_triples = list(set(graph_triples)) - self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples],times=3) + self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3) def load_from_file(self): """从文件加载""" diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index ccaf7aa8..fd922af4 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -5,7 +5,16 @@ from typing import Dict, List, Tuple import numpy as np import pandas as pd -from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn +from rich.progress import ( + Progress, + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) from quick_algo import di_graph, pagerank diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py index ea84af4a..75fd1854 100644 --- a/src/plugins/knowledge/src/open_ie.py +++ b/src/plugins/knowledge/src/open_ie.py @@ -154,7 +154,8 @@ class OpenIE: """提取原始段落""" raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}) return raw_paragraph_dict - + + if __name__ == "__main__": # 测试代码 print(ROOT_PATH) From 81e5c1bb8bdc14b3e42b5fb84b1a49bb245409d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 2 May 2025 15:45:42 +0800 Subject: [PATCH 04/11] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BAOpenIE=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=AE=8C=E6=95=B4=E6=80=A7=E6=A3=80=E6=9F=A5=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E9=94=99=E8=AF=AF=E6=97=A5=E5=BF=97=E8=BE=93?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index dd4b50ec..2a6e09b7 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -6,6 +6,7 @@ import sys import os +from time import sleep sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -26,7 +27,7 @@ OPENIE_DIR = ( else os.path.join(ROOT_PATH, "data/openie") ) -logger = get_module_logger("LPMM知识库-OpenIE导入") +logger = get_module_logger("OpenIE导入") def hash_deduplicate( @@ -71,8 +72,45 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k entity_list_data = openie_data.extract_entity_dict() # 索引的三元组列表 triple_list_data = openie_data.extract_triple_dict() + # print(openie_data.docs) if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): logger.error("OpenIE数据存在异常") + logger.error(f"原始段落数量:{len(raw_paragraphs)}") + logger.error(f"实体列表数量:{len(entity_list_data)}") + logger.error(f"三元组列表数量:{len(triple_list_data)}") + logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致") + logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况") + logger.error("或者一段中只有符号的情况") + # 新增:检查docs中每条数据的完整性 + logger.error("系统将于2秒后开始检查数据完整性") + sleep(2) + found_missing = False + for doc in getattr(openie_data, "docs", []): + idx = doc.get("idx", "<无idx>") + passage = doc.get("passage", "<无passage>") + missing = [] + # 检查字段是否存在且非空 + if "passage" not in doc or not doc.get("passage"): + missing.append("passage") + if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list): + missing.append("名词列表缺失") + elif len(doc.get("extracted_entities", [])) == 0: + missing.append("名词列表为空") + if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list): + missing.append("主谓宾三元组缺失") + elif len(doc.get("extracted_triples", [])) == 0: + missing.append("主谓宾三元组为空") + # 输出所有doc的idx + # print(f"检查: idx={idx}") + if missing: + found_missing = True + logger.error("\n") + logger.error("数据缺失:") + logger.error(f"对应哈希值:{idx}") + logger.error(f"对应文段内容内容:{passage}") + logger.error(f"非法原因:{', '.join(missing)}") + if not found_missing: + print("所有数据均完整,没有发现缺失字段。") return False # 将索引换为对应段落的hash值 logger.info("正在进行段落去重与重索引") From 502f509630e4f073ff6b8d3db582d34cc2cd3ba3 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 2 May 2025 19:15:15 +0800 Subject: [PATCH 05/11] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8Drelationshi?= =?UTF-8?q?p=E5=8A=A0=E9=94=99=E4=BA=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/message_sender.py | 6 +++--- src/plugins/heartFC_chat/normal_chat.py | 5 ++++- src/plugins/person_info/relationship_manager.py | 15 +++++++-------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index 61e2dd49..c50c7aad 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -216,9 +216,9 @@ class MessageManager: # print(f"message.reply:{message.reply}") # --- 条件应用 set_reply 逻辑 --- - logger.debug( - f"[message.apply_set_reply_logic:{message.apply_set_reply_logic},message.is_head:{message.is_head},thinking_messages_count:{thinking_messages_count},thinking_messages_length:{thinking_messages_length},message.is_private_message():{message.is_private_message()}]" - ) + # logger.debug( + # f"[message.apply_set_reply_logic:{message.apply_set_reply_logic},message.is_head:{message.is_head},thinking_messages_count:{thinking_messages_count},thinking_messages_length:{thinking_messages_length},message.is_private_message():{message.is_private_message()}]" + # ) if ( message.apply_set_reply_logic # 检查标记 and message.is_head diff --git a/src/plugins/heartFC_chat/normal_chat.py b/src/plugins/heartFC_chat/normal_chat.py index 70568f83..1c1372c5 100644 --- a/src/plugins/heartFC_chat/normal_chat.py +++ b/src/plugins/heartFC_chat/normal_chat.py @@ -178,8 +178,11 @@ class NormalChat: """更新关系情绪""" ori_response = ",".join(response_set) stance, emotion = await self.gpt._get_emotion_tags(ori_response, message.processed_plain_text) + user_info = message.message_info.user_info + platform = user_info.platform await relationship_manager.calculate_update_relationship_value( - chat_stream=self.chat_stream, + user_info, + platform, label=emotion, stance=stance, # 使用 self.chat_stream ) diff --git a/src/plugins/person_info/relationship_manager.py b/src/plugins/person_info/relationship_manager.py index fc8cf548..3c264b05 100644 --- a/src/plugins/person_info/relationship_manager.py +++ b/src/plugins/person_info/relationship_manager.py @@ -5,6 +5,7 @@ from bson.decimal128 import Decimal128 from .person_info import person_info_manager import time import random +from maim_message import UserInfo, Seg # import re # import traceback @@ -102,7 +103,7 @@ class RelationshipManager: # await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar) await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar) - async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> tuple: + async def calculate_update_relationship_value(self, user_info: UserInfo, platform: str, label: str, stance: str): """计算并变更关系值 新的关系值变更计算方式: 将关系值限定在-1000到1000 @@ -134,11 +135,11 @@ class RelationshipManager: "困惑": 0.5, } - person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id) + person_id = person_info_manager.get_person_id(platform, user_info.user_id) data = { - "platform": chat_stream.user_info.platform, - "user_id": chat_stream.user_info.user_id, - "nickname": chat_stream.user_info.user_nickname, + "platform": platform, + "user_id": user_info.user_id, + "nickname": user_info.user_nickname, "konw_time": int(time.time()), } old_value = await person_info_manager.get_value(person_id, "relationship_value") @@ -178,7 +179,7 @@ class RelationshipManager: level_num = self.calculate_level_num(old_value + value) relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"] logger.info( - f"用户: {chat_stream.user_info.user_nickname}" + f"用户: {user_info.user_nickname}" f"当前关系: {relationship_level[level_num]}, " f"关系值: {old_value:.2f}, " f"当前立场情感: {stance}-{label}, " @@ -187,8 +188,6 @@ class RelationshipManager: await person_info_manager.update_one_field(person_id, "relationship_value", old_value + value, data) - return chat_stream.user_info.user_nickname, value, relationship_level[level_num] - async def calculate_update_relationship_value_with_reason( self, chat_stream: ChatStream, label: str, stance: str, reason: str ) -> tuple: From a859f9238fa0df18023317276dcce3c1f0d1da23 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 2 May 2025 19:17:59 +0800 Subject: [PATCH 06/11] =?UTF-8?q?better=EF=BC=9A=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ![新版麦麦开始学习.bat | 65 +++++++++++++++++++++++--------- scripts/info_extraction.py | 4 +- scripts/raw_data_preprocessor.py | 2 +- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/![新版麦麦开始学习.bat b/![新版麦麦开始学习.bat index 41fc9836..b95bad00 100644 --- a/![新版麦麦开始学习.bat +++ b/![新版麦麦开始学习.bat @@ -2,26 +2,57 @@ CHCP 65001 > nul setlocal enabledelayedexpansion -REM 查找venv虚拟环境 -set "venv_path=%~dp0venv\Scripts\activate.bat" -if not exist "%venv_path%" ( - echo 错误: 未找到虚拟环境,请确保venv目录存在 - pause - exit /b 1 +echo 你需要选择启动方式,输入字母来选择: +echo V = 不知道什么意思就输入 V +echo C = 输入 C 使用 Conda 环境 +echo. +choice /C CV /N /M "在下方输入字母并回车 (C/V)?" /T 10 /D V + +set "ENV_TYPE=" +if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA" +if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV" + +if "%ENV_TYPE%" == "CONDA" ( + set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: " + if not defined CONDA_ENV_NAME ( + echo 错误: 未输入 Conda 环境名称. + pause + exit /b 1 + ) + echo 选择: Conda '%CONDA_ENV_NAME%' + REM 激活Conda环境 + call conda activate %CONDA_ENV_NAME% + if %ERRORLEVEL% neq 0 ( + echo 错误: Conda环境 '%CONDA_ENV_NAME%' 激活失败. 请确保Conda已安装并正确配置, 且 '%CONDA_ENV_NAME%' 环境存在. + pause + exit /b 1 + ) +) else ( + echo Selected: venv (default) + REM 查找venv虚拟环境 + set "venv_path=%~dp0venv\Scripts\activate.bat" + if not exist "%venv_path%" ( + echo Error: venv not found. Ensure the venv directory exists alongside the script. + pause + exit /b 1 + ) + REM 激活虚拟环境 + call "%venv_path%" + if %ERRORLEVEL% neq 0 ( + echo Error: Failed to activate venv virtual environment. + pause + exit /b 1 + ) ) -REM 激活虚拟环境 -call "%venv_path%" -if %ERRORLEVEL% neq 0 ( - echo 错误: 虚拟环境激活失败 - pause - exit /b 1 -) +echo Environment activated successfully! + +REM --- 后续脚本执行 --- REM 运行预处理脚本 python "%~dp0scripts\raw_data_preprocessor.py" if %ERRORLEVEL% neq 0 ( - echo 错误: raw_data_preprocessor.py 执行失败 + echo Error: raw_data_preprocessor.py execution failed. pause exit /b 1 ) @@ -29,7 +60,7 @@ if %ERRORLEVEL% neq 0 ( REM 运行信息提取脚本 python "%~dp0scripts\info_extraction.py" if %ERRORLEVEL% neq 0 ( - echo 错误: info_extraction.py 执行失败 + echo Error: info_extraction.py execution failed. pause exit /b 1 ) @@ -37,10 +68,10 @@ if %ERRORLEVEL% neq 0 ( REM 运行OpenIE导入脚本 python "%~dp0scripts\import_openie.py" if %ERRORLEVEL% neq 0 ( - echo 错误: import_openie.py 执行失败 + echo Error: import_openie.py execution failed. pause exit /b 1 ) -echo 所有处理步骤完成! +echo All processing steps completed! pause \ No newline at end of file diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index 65c4082b..9e079070 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -87,8 +87,8 @@ def main(): signal.signal(signal.SIGINT, signal_handler) # 新增用户确认提示 - print("=== 重要操作确认 ===") - print("实体提取操作将会花费较多资金和时间,建议在空闲时段执行。") + print("=== 重要操作确认,请认真阅读以下内容哦 ===") + print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。") print("建议使用硅基流动的非Pro模型") print("或者使用可以用赠金抵扣的Pro模型") diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index 2fc30352..33d51153 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -44,7 +44,7 @@ def process_text_file(file_path): def main(): # 新增用户确认提示 - print("=== 重要操作确认 ===") + print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("如果你并非第一次导入知识") print("请先删除data/import.json文件,备份data/openie.json文件") print("在进行知识库导入之前") From 021ac90ead332195eb69d2cf98e556a6bcc7c3a3 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 2 May 2025 19:46:04 +0800 Subject: [PATCH 07/11] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E6=94=B9bat?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E6=8F=8F=E8=BF=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ![新版麦麦开始学习.bat | 29 ++++++++++++++++++++--------- scripts/raw_data_preprocessor.py | 9 ++++----- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/![新版麦麦开始学习.bat b/![新版麦麦开始学习.bat index b95bad00..eacaa2eb 100644 --- a/![新版麦麦开始学习.bat +++ b/![新版麦麦开始学习.bat @@ -6,29 +6,39 @@ echo 你需要选择启动方式,输入字母来选择: echo V = 不知道什么意思就输入 V echo C = 输入 C 使用 Conda 环境 echo. -choice /C CV /N /M "在下方输入字母并回车 (C/V)?" /T 10 /D V +choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V set "ENV_TYPE=" if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA" if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV" -if "%ENV_TYPE%" == "CONDA" ( +if "%ENV_TYPE%" == "CONDA" goto activate_conda +if "%ENV_TYPE%" == "VENV" goto activate_venv + +REM 如果 choice 超时或返回意外值,默认使用 venv +echo WARN: Invalid selection or timeout from choice. Defaulting to VENV. +set "ENV_TYPE=VENV" +goto activate_venv + +:activate_conda set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: " if not defined CONDA_ENV_NAME ( echo 错误: 未输入 Conda 环境名称. pause exit /b 1 ) - echo 选择: Conda '%CONDA_ENV_NAME%' + echo 选择: Conda '!CONDA_ENV_NAME!' REM 激活Conda环境 - call conda activate %CONDA_ENV_NAME% - if %ERRORLEVEL% neq 0 ( - echo 错误: Conda环境 '%CONDA_ENV_NAME%' 激活失败. 请确保Conda已安装并正确配置, 且 '%CONDA_ENV_NAME%' 环境存在. + call conda activate !CONDA_ENV_NAME! + if !ERRORLEVEL! neq 0 ( + echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在. pause exit /b 1 ) -) else ( - echo Selected: venv (default) + goto env_activated + +:activate_venv + echo Selected: venv (default or selected) REM 查找venv虚拟环境 set "venv_path=%~dp0venv\Scripts\activate.bat" if not exist "%venv_path%" ( @@ -43,8 +53,9 @@ if "%ENV_TYPE%" == "CONDA" ( pause exit /b 1 ) -) + goto env_activated +:env_activated echo Environment activated successfully! REM --- 后续脚本执行 --- diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index 056cf572..c87c30ca 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -48,11 +48,10 @@ def process_text_file(file_path): def main(): # 新增用户确认提示 - print("=== 重要操作确认,请认真阅读以下内容哦 ===") - print("如果你并非第一次导入知识") - print("请先删除data/import.json文件,备份data/openie.json文件") - print("在进行知识库导入之前") - print("请修改config/lpmm_config.toml中的配置项") + print("=== 数据预处理脚本 ===") + print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。") + print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。") + print("请确保原始数据已放置在正确的目录中。") confirm = input("确认继续执行?(y/n): ").strip().lower() if confirm != "y": logger.error("操作已取消") From acbf5c974bfb173f216eb4cfe681dd0748c95342 Mon Sep 17 00:00:00 2001 From: Bakadax Date: Fri, 2 May 2025 20:38:34 +0800 Subject: [PATCH 08/11] =?UTF-8?q?=E8=A1=A8=E6=83=85=E5=8C=85=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/emoji_system/emoji_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/plugins/emoji_system/emoji_manager.py b/src/plugins/emoji_system/emoji_manager.py index 86dab9d9..d105e0b8 100644 --- a/src/plugins/emoji_system/emoji_manager.py +++ b/src/plugins/emoji_system/emoji_manager.py @@ -289,7 +289,6 @@ def _to_emoji_objects(data): except Exception as e: logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}") load_errors += 1 - return emoji_objects, load_errors return emoji_objects, load_errors From 6747e1d44177e419ecde9a51a41380fdcd6543d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 2 May 2025 20:42:58 +0800 Subject: [PATCH 09/11] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=A8=A1=E6=9D=BF=EF=BC=8C=E6=B7=BB=E5=8A=A0lpmm?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E4=BF=A1=E6=81=AF=E5=B9=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E6=97=A5=E5=BF=97=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/knowledge/src/lpmmconfig.py | 24 ++++++++++-------------- template/lpmm_config_template.toml | 3 +++ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/plugins/knowledge/src/lpmmconfig.py b/src/plugins/knowledge/src/lpmmconfig.py index 753562f4..040bdedc 100644 --- a/src/plugins/knowledge/src/lpmmconfig.py +++ b/src/plugins/knowledge/src/lpmmconfig.py @@ -1,7 +1,7 @@ import os import toml import sys -import argparse +# import argparse from .global_logger import logger PG_NAMESPACE = "paragraph" @@ -37,7 +37,8 @@ def _load_config(config, config_file_path): # Check if all top-level keys from default config exist in the file config for key in config.keys(): if key not in file_config: - print(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。") + logger.critical(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。") + logger.critical("请通过template/lpmm_config_template.toml文件进行更新") sys.exit(1) if "llm_providers" in file_config: @@ -68,16 +69,11 @@ def _load_config(config, config_file_path): logger.info(f"从文件中读取配置: {config_file_path}") -parser = argparse.ArgumentParser(description="Configurations for the pipeline") -parser.add_argument( - "--config_path", - type=str, - default="lpmm_config.toml", - help="Path to the configuration file", -) - global_config = dict( { + "lpmm":{ + "version": "0.1.0", + }, "llm_providers": { "localhost": { "base_url": "https://api.siliconflow.cn/v1", @@ -136,8 +132,8 @@ global_config = dict( ) # _load_config(global_config, parser.parse_args().config_path) -file_path = os.path.abspath(__file__) -dir_path = os.path.dirname(file_path) -root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir) -config_path = os.path.join(root_path, "config", "lpmm_config.toml") +# file_path = os.path.abspath(__file__) +# dir_path = os.path.dirname(file_path) +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) +config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml") _load_config(global_config, config_path) diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml index 8563b7ca..aae664d5 100644 --- a/template/lpmm_config_template.toml +++ b/template/lpmm_config_template.toml @@ -1,3 +1,6 @@ +[lpmm] +version = "0.1.0" + # LLM API 服务提供商,可配置多个 [[llm_providers]] name = "localhost" From 5d1c880fb9e99f7a8a7d78ecbc6242205b46b971 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 2 May 2025 20:45:06 +0800 Subject: [PATCH 10/11] fix: Ruff --- src/plugins/person_info/relationship_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/person_info/relationship_manager.py b/src/plugins/person_info/relationship_manager.py index 3c264b05..862f2398 100644 --- a/src/plugins/person_info/relationship_manager.py +++ b/src/plugins/person_info/relationship_manager.py @@ -5,7 +5,7 @@ from bson.decimal128 import Decimal128 from .person_info import person_info_manager import time import random -from maim_message import UserInfo, Seg +from maim_message import UserInfo # import re # import traceback From 4f9fbe78728adc2d2afd1f815de3f958901536e7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 2 May 2025 12:45:19 +0000 Subject: [PATCH 11/11] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/knowledge/src/lpmmconfig.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/plugins/knowledge/src/lpmmconfig.py b/src/plugins/knowledge/src/lpmmconfig.py index 040bdedc..387a7b29 100644 --- a/src/plugins/knowledge/src/lpmmconfig.py +++ b/src/plugins/knowledge/src/lpmmconfig.py @@ -1,6 +1,7 @@ import os import toml import sys + # import argparse from .global_logger import logger @@ -71,7 +72,7 @@ def _load_config(config, config_file_path): global_config = dict( { - "lpmm":{ + "lpmm": { "version": "0.1.0", }, "llm_providers": {