pull/1001/head
SnowindMe 2025-05-03 00:29:01 +08:00
commit e7dc30204d
44 changed files with 436 additions and 194 deletions

View File

@ -2,26 +2,68 @@
CHCP 65001 > nul
setlocal enabledelayedexpansion
REM 查找venv虚拟环境
set "venv_path=%~dp0venv\Scripts\activate.bat"
if not exist "%venv_path%" (
echo 错误: 未找到虚拟环境请确保venv目录存在
pause
exit /b 1
)
echo 你需要选择启动方式,输入字母来选择:
echo V = 不知道什么意思就输入 V
echo C = 输入 C 使用 Conda 环境
echo.
choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V
REM 激活虚拟环境
call "%venv_path%"
if %ERRORLEVEL% neq 0 (
echo 错误: 虚拟环境激活失败
pause
exit /b 1
)
set "ENV_TYPE="
if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA"
if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV"
if "%ENV_TYPE%" == "CONDA" goto activate_conda
if "%ENV_TYPE%" == "VENV" goto activate_venv
REM 如果 choice 超时或返回意外值,默认使用 venv
echo WARN: Invalid selection or timeout from choice. Defaulting to VENV.
set "ENV_TYPE=VENV"
goto activate_venv
:activate_conda
set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: "
if not defined CONDA_ENV_NAME (
echo 错误: 未输入 Conda 环境名称.
pause
exit /b 1
)
echo 选择: Conda '!CONDA_ENV_NAME!'
REM 激活Conda环境
call conda activate !CONDA_ENV_NAME!
if !ERRORLEVEL! neq 0 (
echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在.
pause
exit /b 1
)
goto env_activated
:activate_venv
echo Selected: venv (default or selected)
REM 查找venv虚拟环境
set "venv_path=%~dp0venv\Scripts\activate.bat"
if not exist "%venv_path%" (
echo Error: venv not found. Ensure the venv directory exists alongside the script.
pause
exit /b 1
)
REM 激活虚拟环境
call "%venv_path%"
if %ERRORLEVEL% neq 0 (
echo Error: Failed to activate venv virtual environment.
pause
exit /b 1
)
goto env_activated
:env_activated
echo Environment activated successfully!
REM --- 后续脚本执行 ---
REM 运行预处理脚本
python "%~dp0scripts\raw_data_preprocessor.py"
if %ERRORLEVEL% neq 0 (
echo 错误: raw_data_preprocessor.py 执行失败
echo Error: raw_data_preprocessor.py execution failed.
pause
exit /b 1
)
@ -29,7 +71,7 @@ if %ERRORLEVEL% neq 0 (
REM 运行信息提取脚本
python "%~dp0scripts\info_extraction.py"
if %ERRORLEVEL% neq 0 (
echo 错误: info_extraction.py 执行失败
echo Error: info_extraction.py execution failed.
pause
exit /b 1
)
@ -37,10 +79,10 @@ if %ERRORLEVEL% neq 0 (
REM 运行OpenIE导入脚本
python "%~dp0scripts\import_openie.py"
if %ERRORLEVEL% neq 0 (
echo 错误: import_openie.py 执行失败
echo Error: import_openie.py execution failed.
pause
exit /b 1
)
echo 所有处理步骤完成!
echo All processing steps completed!
pause

2
bot.py
View File

@ -15,7 +15,7 @@ from src.common.crash_logger import install_crash_handler
from src.main import MainSystem
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("main")

View File

@ -6,6 +6,7 @@
import sys
import os
from time import sleep
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -19,9 +20,14 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
# 添加项目根目录到 sys.path
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = (
global_config["persistence"]["openie_data_path"]
if global_config["persistence"]["openie_data_path"]
else os.path.join(ROOT_PATH, "data/openie")
)
logger = get_module_logger("LPMM知识库-OpenIE导入")
logger = get_module_logger("OpenIE导入")
def hash_deduplicate(
@ -66,8 +72,45 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
entity_list_data = openie_data.extract_entity_dict()
# 索引的三元组列表
triple_list_data = openie_data.extract_triple_dict()
# print(openie_data.docs)
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
logger.error("OpenIE数据存在异常")
logger.error(f"原始段落数量:{len(raw_paragraphs)}")
logger.error(f"实体列表数量:{len(entity_list_data)}")
logger.error(f"三元组列表数量:{len(triple_list_data)}")
logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
logger.error("或者一段中只有符号的情况")
# 新增检查docs中每条数据的完整性
logger.error("系统将于2秒后开始检查数据完整性")
sleep(2)
found_missing = False
for doc in getattr(openie_data, "docs", []):
idx = doc.get("idx", "<无idx>")
passage = doc.get("passage", "<无passage>")
missing = []
# 检查字段是否存在且非空
if "passage" not in doc or not doc.get("passage"):
missing.append("passage")
if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
missing.append("名词列表缺失")
elif len(doc.get("extracted_entities", [])) == 0:
missing.append("名词列表为空")
if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
missing.append("主谓宾三元组缺失")
elif len(doc.get("extracted_triples", [])) == 0:
missing.append("主谓宾三元组为空")
# 输出所有doc的idx
# print(f"检查: idx={idx}")
if missing:
found_missing = True
logger.error("\n")
logger.error("数据缺失:")
logger.error(f"对应哈希值:{idx}")
logger.error(f"对应文段内容内容:{passage}")
logger.error(f"非法原因:{', '.join(missing)}")
if not found_missing:
print("所有数据均完整,没有发现缺失字段。")
return False
# 将索引换为对应段落的hash值
logger.info("正在进行段落去重与重索引")
@ -131,6 +174,7 @@ def main():
embed_manager.load_from_file()
except Exception as e:
logger.error("从文件加载Embedding库时发生错误{}".format(e))
logger.error("如果你是第一次导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
kg_manager = KGManager()
@ -139,6 +183,7 @@ def main():
kg_manager.load_from_file()
except Exception as e:
logger.error("从文件加载KG时发生错误{}".format(e))
logger.error("如果你是第一次导入知识,请忽略此错误")
logger.info("KG加载完成")
logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")
@ -163,4 +208,5 @@ def main():
if __name__ == "__main__":
# logger.info(f"111111111111111111111111{ROOT_PATH}")
main()

View File

@ -4,11 +4,13 @@ import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import glob
import datetime
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# 添加项目根目录到 sys.path
import tqdm
from rich.progress import Progress # 替换为 rich 进度条
from src.common.logger import get_module_logger
from src.plugins.knowledge.src.lpmmconfig import global_config
@ -16,10 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str
from src.plugins.knowledge.src.llm_client import LLMClient
from src.plugins.knowledge.src.open_ie import OpenIE
from src.plugins.knowledge.src.raw_processing import load_raw_data
from rich.progress import (
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
logger = get_module_logger("LPMM知识库-信息提取")
TEMP_DIR = "./temp"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
IMPORTED_DATA_PATH = (
global_config["persistence"]["raw_data_path"]
if global_config["persistence"]["raw_data_path"]
else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
)
OPENIE_OUTPUT_DIR = (
global_config["persistence"]["openie_data_path"]
if global_config["persistence"]["openie_data_path"]
else os.path.join(ROOT_PATH, "data/openie")
)
# 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock()
@ -70,8 +93,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
# 如果保存失败,确保不会留下损坏的文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
# 设置shutdown_event以终止程序
shutdown_event.set()
sys.exit(0)
return None, pg_hash
return doc_item, None
@ -79,7 +101,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
def signal_handler(_signum, _frame):
"""处理Ctrl+C信号"""
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
shutdown_event.set()
sys.exit(0)
def main():
@ -87,8 +109,8 @@ def main():
signal.signal(signal.SIGINT, signal_handler)
# 新增用户确认提示
print("=== 重要操作确认 ===")
print("实体提取操作将会花费较多资金和时间,建议在空闲时段执行。")
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
print("举例600万字全剧情提取选用deepseek v3 0324消耗约40元约3小时。")
print("建议使用硅基流动的非Pro模型")
print("或者使用可以用赠金抵扣的Pro模型")
@ -110,33 +132,61 @@ def main():
global_config["llm_providers"][key]["api_key"],
)
logger.info("正在加载原始数据")
sha256_list, raw_datas = load_raw_data()
logger.info("原始数据加载完成\n")
# 检查 openie 输出目录
if not os.path.exists(OPENIE_OUTPUT_DIR):
os.makedirs(OPENIE_OUTPUT_DIR)
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
# 创建临时目录
if not os.path.exists(f"{TEMP_DIR}"):
os.makedirs(f"{TEMP_DIR}")
# 确保 TEMP_DIR 目录存在
if not os.path.exists(TEMP_DIR):
os.makedirs(TEMP_DIR)
logger.info(f"已创建缓存目录: {TEMP_DIR}")
# 遍历IMPORTED_DATA_PATH下所有json文件
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
if not imported_files:
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
sys.exit(1)
all_sha256_list = []
all_raw_datas = []
for imported_file in imported_files:
logger.info(f"正在处理文件: {imported_file}")
try:
sha256_list, raw_datas = load_raw_data(imported_file)
except Exception as e:
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
continue
all_sha256_list.extend(sha256_list)
all_raw_datas.extend(raw_datas)
failed_sha256 = []
open_ie_doc = []
# 创建线程池最大线程数为50
workers = global_config["info_extraction"]["workers"]
with ThreadPoolExecutor(max_workers=workers) as executor:
# 提交所有任务到线程池
future_to_hash = {
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
for pg_hash, raw_data in zip(sha256_list, raw_datas)
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
}
# 使用tqdm显示进度
with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar:
# 处理完成的任务
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("正在进行提取:", total=len(future_to_hash))
try:
for future in as_completed(future_to_hash):
if shutdown_event.is_set():
# 取消所有未完成的任务
for f in future_to_hash:
if not f.done():
f.cancel()
@ -149,26 +199,38 @@ def main():
elif doc_item:
with open_ie_doc_lock:
open_ie_doc.append(doc_item)
pbar.update(1)
progress.update(task, advance=1)
except KeyboardInterrupt:
# 如果在这里捕获到KeyboardInterrupt说明signal_handler可能没有正常工作
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
shutdown_event.set()
# 取消所有未完成的任务
for f in future_to_hash:
if not f.done():
f.cancel()
# 保存信息提取结果
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
openie_obj = OpenIE(
open_ie_doc,
round(sum_phrase_chars / num_phrases, 4),
round(sum_phrase_words / num_phrases, 4),
)
OpenIE.save(openie_obj)
# 合并所有文件的提取结果并保存
if open_ie_doc:
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
openie_obj = OpenIE(
open_ie_doc,
round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
)
# 输出文件名格式MM-DD-HH-ss-openie.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-openie.json")
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
f,
ensure_ascii=False,
indent=4,
)
logger.info(f"信息提取结果已保存到: {output_path}")
else:
logger.warning("没有可保存的信息提取结果")
logger.info("--------信息提取完成--------")
logger.info(f"提取失败的文段SHA256{failed_sha256}")

View File

@ -2,18 +2,22 @@ import json
import os
from pathlib import Path
import sys # 新增系统模块导入
import datetime # 新增导入
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_module_logger
logger = get_module_logger("LPMM数据库-原始数据处理")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
# 添加项目根目录到 sys.path
def check_and_create_dirs():
"""检查并创建必要的目录"""
required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
for dir_path in required_dirs:
if not os.path.exists(dir_path):
@ -44,11 +48,10 @@ def process_text_file(file_path):
def main():
# 新增用户确认提示
print("=== 重要操作确认 ===")
print("如果你并非第一次导入知识")
print("请先删除data/import.json文件备份data/openie.json文件")
print("在进行知识库导入之前")
print("请修改config/lpmm_config.toml中的配置项")
print("=== 数据预处理脚本 ===")
print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。")
print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。")
print("请确保原始数据已放置在正确的目录中。")
confirm = input("确认继续执行?(y/n): ").strip().lower()
if confirm != "y":
logger.error("操作已取消")
@ -58,17 +61,17 @@ def main():
# 检查并创建必要的目录
check_and_create_dirs()
# 检查输出文件是否存在
if os.path.exists("data/import.json"):
logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
sys.exit(1)
# # 检查输出文件是否存在
# if os.path.exists(RAW_DATA_PATH):
# logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
# sys.exit(1)
if os.path.exists("data/openie.json"):
logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
sys.exit(1)
# if os.path.exists(RAW_DATA_PATH):
# logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
# sys.exit(1)
# 获取所有原始文本文件
raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
if not raw_files:
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
sys.exit(1)
@ -80,8 +83,10 @@ def main():
paragraphs = process_text_file(file)
all_paragraphs.extend(paragraphs)
# 保存合并后的结果
output_path = "data/import.json"
# 保存合并后的结果到 IMPORTED_DATA_PATH文件名格式为 MM-DD-HH-ss-imported-data.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
@ -89,4 +94,6 @@ def main():
if __name__ == "__main__":
print(f"Raw Data Path: {RAW_DATA_PATH}")
print(f"Imported Data Path: {IMPORTED_DATA_PATH}")
main()

View File

@ -3,7 +3,7 @@ from pymongo import MongoClient
from pymongo.database import Database
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
_client = None
_db = None

View File

@ -4,7 +4,7 @@ from typing import Callable, Any
from .logger import logger, add_custom_style_handler
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
def use_log_style(

View File

@ -4,7 +4,7 @@ from uvicorn import Config, Server as UvicornServer
import os
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
class Server:

View File

@ -16,7 +16,7 @@ from packaging.specifiers import SpecifierSet, InvalidSpecifier
from src.common.logger_manager import get_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
# 配置主程序日志格式

View File

@ -6,7 +6,7 @@ import os
from src.common.logger_manager import get_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("base_tool")

View File

@ -4,7 +4,7 @@ from .identity import Identity
import random
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
class Individuality:

View File

@ -8,7 +8,7 @@ import requests
from src.common.logger import get_module_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_module_logger("offline_llm")

View File

@ -19,7 +19,7 @@ from .individuality.individuality import Individuality
from .common.server import global_server
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("main")

View File

@ -9,7 +9,7 @@ from .chat_states import NotificationManager, create_new_message_notification, c
from .message_storage import MongoDBMessageStorage
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_module_logger("chat_observer")

View File

@ -25,7 +25,7 @@ from .waiter import Waiter
import traceback
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("pfc")

View File

@ -10,7 +10,7 @@ from ..storage.storage import MessageStorage
from ...config.config import global_config
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_module_logger("message_sender")

View File

@ -10,7 +10,7 @@ from .observation_info import ObservationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
if TYPE_CHECKING:
pass

View File

@ -11,7 +11,7 @@ from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("chat_stream")

View File

@ -11,7 +11,7 @@ from .utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("chat_message")

View File

@ -15,7 +15,7 @@ from .utils import truncate_message, calculate_typing_time, count_messages_betwe
from src.common.logger_manager import get_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("sender")
@ -219,9 +219,9 @@ class MessageManager:
# print(f"message.reply:{message.reply}")
# --- 条件应用 set_reply 逻辑 ---
logger.debug(
f"[message.apply_set_reply_logic:{message.apply_set_reply_logic},message.is_head:{message.is_head},thinking_messages_count:{thinking_messages_count},thinking_messages_length:{thinking_messages_length},message.is_private_message():{message.is_private_message()}]"
)
# logger.debug(
# f"[message.apply_set_reply_logic:{message.apply_set_reply_logic},message.is_head:{message.is_head},thinking_messages_count:{thinking_messages_count},thinking_messages_length:{thinking_messages_length},message.is_private_message():{message.is_private_message()}]"
# )
if (
message.apply_set_reply_logic # 检查标记
and message.is_head

View File

@ -15,7 +15,7 @@ from ..models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("chat_image")

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
# 创建APIRouter而不是FastAPI实例
router = APIRouter()

View File

@ -17,7 +17,7 @@ from ..models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("emoji")
@ -289,7 +289,6 @@ def _to_emoji_objects(data):
except Exception as e:
logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}")
load_errors += 1
return emoji_objects, load_errors
return emoji_objects, load_errors

View File

@ -29,7 +29,7 @@ from src.plugins.moods.moods import MoodManager
from src.heart_flow.utils_chat import get_chat_type_and_target_info
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒

View File

@ -11,7 +11,7 @@ from src.common.logger_manager import get_logger
from src.plugins.chat.utils import calculate_typing_time
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_logger("sender")

View File

@ -178,8 +178,11 @@ class NormalChat:
"""更新关系情绪"""
ori_response = ",".join(response_set)
stance, emotion = await self.gpt._get_emotion_tags(ori_response, message.processed_plain_text)
user_info = message.message_info.user_info
platform = user_info.platform
await relationship_manager.calculate_update_relationship_value(
chat_stream=self.chat_stream,
user_info,
platform,
label=emotion,
stance=stance, # 使用 self.chat_stream
)

View File

@ -26,6 +26,7 @@ try:
embed_manager.load_from_file()
except Exception as e:
logger.error("从文件加载Embedding库时发生错误{}".format(e))
logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
kg_manager = KGManager()
@ -34,6 +35,7 @@ try:
kg_manager.load_from_file()
except Exception as e:
logger.error("从文件加载KG时发生错误{}".format(e))
logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成")
logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")

View File

@ -13,8 +13,20 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
@dataclass
@ -52,20 +64,35 @@ class EmbeddingStore:
def _get_embedding(self, s: str) -> List[float]:
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
def batch_insert_strs(self, strs: List[str]) -> None:
def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""向库中存入字符串"""
# 逐项处理
for s in tqdm.tqdm(strs, desc="存入嵌入库", unit="items"):
# 计算hash去重
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash in self.store:
continue
total = len(strs)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
for s in strs:
# 计算hash去重
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash in self.store:
progress.update(task, advance=1)
continue
# 获取embedding
embedding = self._get_embedding(s)
# 获取embedding
embedding = self._get_embedding(s)
# 存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
# 存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
progress.update(task, advance=1)
def save_to_file(self) -> None:
"""保存到文件"""
@ -191,7 +218,7 @@ class EmbeddingManager:
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()))
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将实体编码存入Embedding库"""
@ -200,7 +227,7 @@ class EmbeddingManager:
for triple in triple_list:
entities.add(triple[0])
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities))
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将关系编码存入Embedding库"""
@ -208,7 +235,7 @@ class EmbeddingManager:
for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples))
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
def load_from_file(self):
"""从文件加载"""

View File

@ -5,7 +5,16 @@ from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import tqdm
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from quick_algo import di_graph, pagerank
@ -132,41 +141,56 @@ class KGManager:
ent_hash_list = list(ent_hash_list)
synonym_hash_set = set()
synonym_result = dict()
# 对每个实体节点,查找其相似的实体节点,建立扩展连接
for ent_hash in tqdm.tqdm(ent_hash_list):
if ent_hash in synonym_hash_set:
# 避免同一批次内重复添加
continue
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
assert isinstance(ent, EmbeddingStoreItem)
if ent is None:
continue
# 查询相似实体
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
)
res_ent = [] # Debug
for res_ent_hash, similarity in similar_ents:
if res_ent_hash == ent_hash:
# 避免自连接
# rich 进度条
total = len(ent_hash_list)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("同义词连接", total=total)
for ent_hash in ent_hash_list:
if ent_hash in synonym_hash_set:
progress.update(task, advance=1)
continue
if similarity < global_config["rag"]["params"]["synonym_threshold"]:
# 相似度阈值
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
assert isinstance(ent, EmbeddingStoreItem)
if ent is None:
progress.update(task, advance=1)
continue
node_to_node[(res_ent_hash, ent_hash)] = similarity
node_to_node[(ent_hash, res_ent_hash)] = similarity
synonym_hash_set.add(res_ent_hash)
new_edge_cnt += 1
res_ent.append(
(
embedding_manager.entities_embedding_store.store[res_ent_hash].str,
similarity,
)
) # Debug
synonym_result[ent.str] = res_ent
# 查询相似实体
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
)
res_ent = [] # Debug
for res_ent_hash, similarity in similar_ents:
if res_ent_hash == ent_hash:
# 避免自连接
continue
if similarity < global_config["rag"]["params"]["synonym_threshold"]:
# 相似度阈值
continue
node_to_node[(res_ent_hash, ent_hash)] = similarity
node_to_node[(ent_hash, res_ent_hash)] = similarity
synonym_hash_set.add(res_ent_hash)
new_edge_cnt += 1
res_ent.append(
(
embedding_manager.entities_embedding_store.store[res_ent_hash].str,
similarity,
)
) # Debug
synonym_result[ent.str] = res_ent
progress.update(task, advance=1)
for k, v in synonym_result.items():
print(f'"{k}"的相似实体为:{v}')

View File

@ -1,7 +1,8 @@
import os
import toml
import sys
import argparse
# import argparse
from .global_logger import logger
PG_NAMESPACE = "paragraph"
@ -37,7 +38,8 @@ def _load_config(config, config_file_path):
# Check if all top-level keys from default config exist in the file config
for key in config.keys():
if key not in file_config:
print(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。")
logger.critical(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。")
logger.critical("请通过template/lpmm_config_template.toml文件进行更新")
sys.exit(1)
if "llm_providers" in file_config:
@ -68,16 +70,11 @@ def _load_config(config, config_file_path):
logger.info(f"从文件中读取配置: {config_file_path}")
parser = argparse.ArgumentParser(description="Configurations for the pipeline")
parser.add_argument(
"--config_path",
type=str,
default="lpmm_config.toml",
help="Path to the configuration file",
)
global_config = dict(
{
"lpmm": {
"version": "0.1.0",
},
"llm_providers": {
"localhost": {
"base_url": "https://api.siliconflow.cn/v1",
@ -136,8 +133,8 @@ global_config = dict(
)
# _load_config(global_config, parser.parse_args().config_path)
file_path = os.path.abspath(__file__)
dir_path = os.path.dirname(file_path)
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
config_path = os.path.join(root_path, "config", "lpmm_config.toml")
# file_path = os.path.abspath(__file__)
# dir_path = os.path.dirname(file_path)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
_load_config(global_config, config_path)

View File

@ -1,9 +1,13 @@
import json
import os
import glob
from typing import Any, Dict, List
from .lpmmconfig import INVALID_ENTITY, global_config
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
def _filter_invalid_entities(entities: List[str]) -> List[str]:
"""过滤无效的实体"""
@ -74,12 +78,22 @@ class OpenIE:
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
@staticmethod
def _from_dict(data):
"""从字典中获取OpenIE对象"""
def _from_dict(data_list):
"""从多个字典合并OpenIE对象"""
# data_list: List[dict]
all_docs = []
for data in data_list:
all_docs.extend(data.get("docs", []))
# 重新计算统计
sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs])
avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0
avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0
return OpenIE(
docs=data["docs"],
avg_ent_chars=data["avg_ent_chars"],
avg_ent_words=data["avg_ent_words"],
docs=all_docs,
avg_ent_chars=avg_ent_chars,
avg_ent_words=avg_ent_words,
)
def _to_dict(self):
@ -92,12 +106,20 @@ class OpenIE:
@staticmethod
def load() -> "OpenIE":
"""从文件中加载OpenIE数据"""
with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
data = json.loads(f.read())
openie_data = OpenIE._from_dict(data)
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
if not os.path.exists(openie_dir):
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
data_list = []
for file in json_files:
with open(file, "r", encoding="utf-8") as f:
data = json.load(f)
data_list.append(data)
if not data_list:
# print(f"111111111111111111111Root Path : \n{ROOT_PATH}")
raise Exception(f"未在 {openie_dir} 找到任何OpenIE json文件")
openie_data = OpenIE._from_dict(data_list)
return openie_data
@staticmethod
@ -132,3 +154,8 @@ class OpenIE:
"""提取原始段落"""
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
return raw_paragraph_dict
if __name__ == "__main__":
# 测试代码
print(ROOT_PATH)

View File

@ -6,21 +6,25 @@ from .lpmmconfig import global_config
from .utils.hash import get_sha256
def load_raw_data() -> tuple[list[str], list[str]]:
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
"""加载原始数据文件
读取原始数据文件将原始数据加载到内存中
Args:
path: 可选指定要读取的json文件绝对路径
Returns:
- raw_data: 原始数据字典
- md5_set: 原始数据的SHA256集合
- raw_data: 原始数据列表
- sha256_list: 原始数据的SHA256集合
"""
# 读取import.json文件
if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
# 读取指定路径或默认路径的json文件
json_path = path if path else global_config["persistence"]["raw_data_path"]
if os.path.exists(json_path):
with open(json_path, "r", encoding="utf-8") as f:
import_json = json.loads(f.read())
else:
raise Exception("原始数据文件读取失败")
raise Exception(f"原始数据文件读取失败: {json_path}")
# import_json内容示例
# import_json = [
# "The capital of China is Beijing. The capital of France is Paris.",

View File

@ -22,7 +22,7 @@ from ..chat.utils import translate_timestamp_to_human_readable
from .memory_config import MemoryConfig
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
def calculate_information_content(text):

View File

@ -10,7 +10,7 @@ from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.config.config import global_config
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
async def test_memory_system():

View File

@ -11,7 +11,7 @@ from Hippocampus import Hippocampus # 海马体和记忆图
from dotenv import load_dotenv
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
"""

View File

@ -8,7 +8,7 @@ import requests
from src.common.logger import get_module_logger
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_module_logger("offline_llm")

View File

@ -3,7 +3,7 @@ from scipy import stats
from datetime import datetime, timedelta
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
class DistributionVisualizer:

View File

@ -16,7 +16,7 @@ from ...common.database import db
from ...config.config import global_config
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_module_logger("model_utils")

View File

@ -5,6 +5,7 @@ from bson.decimal128 import Decimal128
from .person_info import person_info_manager
import time
import random
from maim_message import UserInfo
# import re
# import traceback
@ -102,7 +103,7 @@ class RelationshipManager:
# await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar)
await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar)
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> tuple:
async def calculate_update_relationship_value(self, user_info: UserInfo, platform: str, label: str, stance: str):
"""计算并变更关系值
新的关系值变更计算方式
将关系值限定在-1000到1000
@ -134,11 +135,11 @@ class RelationshipManager:
"困惑": 0.5,
}
person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
person_id = person_info_manager.get_person_id(platform, user_info.user_id)
data = {
"platform": chat_stream.user_info.platform,
"user_id": chat_stream.user_info.user_id,
"nickname": chat_stream.user_info.user_nickname,
"platform": platform,
"user_id": user_info.user_id,
"nickname": user_info.user_nickname,
"konw_time": int(time.time()),
}
old_value = await person_info_manager.get_value(person_id, "relationship_value")
@ -178,7 +179,7 @@ class RelationshipManager:
level_num = self.calculate_level_num(old_value + value)
relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
logger.info(
f"用户: {chat_stream.user_info.user_nickname}"
f"用户: {user_info.user_nickname}"
f"当前关系: {relationship_level[level_num]}, "
f"关系值: {old_value:.2f}, "
f"当前立场情感: {stance}-{label}, "
@ -187,8 +188,6 @@ class RelationshipManager:
await person_info_manager.update_one_field(person_id, "relationship_value", old_value + value, data)
return chat_stream.user_info.user_nickname, value, relationship_level[level_num]
async def calculate_update_relationship_value_with_reason(
self, chat_stream: ChatStream, label: str, stance: str, reason: str
) -> tuple:

View File

@ -7,7 +7,7 @@ from src.common.logger import get_module_logger
# import traceback
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
logger = get_module_logger("prompt_build")

View File

@ -4,7 +4,7 @@ from typing import Optional, Dict, Callable
import asyncio
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
"""
# 更好的计时器

View File

@ -10,7 +10,7 @@ from typing import Dict, Optional
import asyncio
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
"""
基类方法概览

View File

@ -9,7 +9,7 @@ from rich.console import Console
from rich.table import Table
from rich.traceback import install
install(show_locals=True, extra_lines=3)
install(extra_lines=3)
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))

View File

@ -1,3 +1,6 @@
[lpmm]
version = "0.1.0"
# LLM API 服务提供商,可配置多个
[[llm_providers]]
name = "localhost"
@ -51,7 +54,7 @@ res_top_k = 3 # 最终提供的文段TopK
[persistence]
# 持久化配置(存储中间数据,防止重复计算)
data_root_path = "data" # 数据根目录
raw_data_path = "data/import.json" # 原始数据路径
openie_data_path = "data/openie.json" # OpenIE数据路径
raw_data_path = "data/imported_lpmm_data" # 原始数据路径
openie_data_path = "data/openie" # OpenIE数据路径
embedding_data_dir = "data/embedding" # 嵌入数据目录
rag_data_dir = "data/rag" # RAG数据目录