diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 2a6e09b7..851cc8b3 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -174,6 +174,11 @@ def main(): embed_manager.load_from_file() except Exception as e: logger.error("从文件加载Embedding库时发生错误:{}".format(e)) + if "嵌入模型与本地存储不一致" in str(e): + logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") + logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") + # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") + sys.exit(1) logger.error("如果你是第一次导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG diff --git a/src/api/__init__.py b/src/api/__init__.py index f5bc08a6..e69de29b 100644 --- a/src/api/__init__.py +++ b/src/api/__init__.py @@ -1,8 +0,0 @@ -from fastapi import FastAPI -from strawberry.fastapi import GraphQLRouter - -app = FastAPI() - -graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema - -app.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) diff --git a/src/api/main.py b/src/api/main.py new file mode 100644 index 00000000..be225940 --- /dev/null +++ b/src/api/main.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter +from strawberry.fastapi import GraphQLRouter + +# from src.config.config import BotConfig +from src.common.logger_manager import get_logger +from src.api.reload_config import reload_config as reload_config_func +from src.common.server import global_server +# import uvicorn +# import os + +router = APIRouter() + + +logger = get_logger("api") + +# maiapi = FastAPI() +logger.info("API server started.") +graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema + +router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) + + +@router.post("/config/reload") +async def reload_config(): + return await reload_config_func() + + +def start_api_server(): + """启动API服务器""" + global_server.register_router(router, prefix="/api/v1") diff --git a/src/api/reload_config.py b/src/api/reload_config.py new file mode 100644 index 00000000..d77cb536 --- /dev/null +++ b/src/api/reload_config.py @@ -0,0 +1,19 @@ +from fastapi import HTTPException +from rich.traceback import install +from src.config.config import BotConfig +import os + +install(extra_lines=3) + + +async def reload_config(): + try: + from src.config import config as config_module + + bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") + config_module.global_config = BotConfig.load_config(config_path=bot_config_path) + return {"status": "reloaded"} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e diff --git a/src/common/logger.py b/src/common/logger.py index a82c6d88..88fc427f 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -808,6 +808,22 @@ INIT_STYLE_CONFIG = { }, } +API_SERVER_STYLE_CONFIG = { + "advanced": { + "console_format": ( + "{time:YYYY-MM-DD HH:mm:ss} | " + "{level: <8} | " + "API服务 | " + "{message}" + ), + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}", + }, + "simple": { + "console_format": "{time:MM-DD HH:mm} | API服务 | {message}", + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}", + }, +} + # 根据SIMPLE_OUTPUT选择配置 MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"] @@ -878,6 +894,7 @@ CHAT_MESSAGE_STYLE_CONFIG = ( ) CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"] INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"] +API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"] def is_registered_module(record: dict) -> bool: diff --git a/src/common/logger_manager.py b/src/common/logger_manager.py index 5c553838..4c28f82f 100644 --- a/src/common/logger_manager.py +++ b/src/common/logger_manager.py @@ -41,6 +41,7 @@ from src.common.logger import ( CHAT_MESSAGE_STYLE_CONFIG, CHAT_IMAGE_STYLE_CONFIG, INIT_STYLE_CONFIG, + API_SERVER_STYLE_CONFIG, ) # 可根据实际需要补充更多模块配置 @@ -86,6 +87,7 @@ MODULE_LOGGER_CONFIGS = { "chat_message": CHAT_MESSAGE_STYLE_CONFIG, # 聊天消息 "chat_image": CHAT_IMAGE_STYLE_CONFIG, # 聊天图片 "init": INIT_STYLE_CONFIG, # 初始化 + "api": API_SERVER_STYLE_CONFIG, # API服务器 # ...如有更多模块,继续添加... } diff --git a/src/main.py b/src/main.py index 26a56ca2..be71524e 100644 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ from .plugins.remote import heartbeat_thread # noqa: F401 from .individuality.individuality import Individuality from .common.server import global_server from rich.traceback import install +from .api.main import start_api_server install(extra_lines=3) @@ -54,6 +55,9 @@ class MainSystem: self.llm_stats.start() logger.success("LLM统计功能启动成功") + # 启动API服务器 + start_api_server() + logger.success("API服务器启动成功") # 初始化表情管理器 emoji_manager.initialize() logger.success("表情包管理器初始化成功") diff --git a/src/plugins/config_reload/__init__.py b/src/plugins/config_reload/__init__.py deleted file mode 100644 index 8b137891..00000000 --- a/src/plugins/config_reload/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/plugins/config_reload/api.py b/src/plugins/config_reload/api.py deleted file mode 100644 index 56240b88..00000000 --- a/src/plugins/config_reload/api.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter, HTTPException -from rich.traceback import install - -install(extra_lines=3) - -# 创建APIRouter而不是FastAPI实例 -router = APIRouter() - - -@router.post("/reload-config") -async def reload_config(): - try: # TODO: 实现配置重载 - # bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") - # BotConfig.reload_config(config_path=bot_config_path) - return {"message": "TODO: 实现配置重载", "status": "unimplemented"} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - except Exception as e: - raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e diff --git a/src/plugins/config_reload/test.py b/src/plugins/config_reload/test.py deleted file mode 100644 index fc4fc1e8..00000000 --- a/src/plugins/config_reload/test.py +++ /dev/null @@ -1,4 +0,0 @@ -import requests - -response = requests.post("http://localhost:8080/api/reload-config") -print(response.json()) diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index e734f4e9..5ee92a86 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import json import os +import math from typing import Dict, List, Tuple import numpy as np @@ -25,9 +26,39 @@ from rich.progress import ( ) install(extra_lines=3) - +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 +# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录 +# 这些字符串的嵌入结果应该是固定的,不能随时间变化 +EMBEDDING_TEST_STRINGS = [ + "阿卡伊真的太好玩了,神秘性感大女同等着你", + "你怎么知道我arc12.64了", + "我是蕾缪乐小姐的狗", + "关注Oct谢谢喵", + "不是w6我不草", + "关注千石可乐谢谢喵", + "来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵", + "关注墨梓柒谢谢喵", + "Ciallo~", + "来玩巧克甜恋谢谢喵", + "水印", + "我也在纠结晚饭,铁锅炒鸡听着就香!", + "test你妈喵", +] +EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json") +EMBEDDING_SIM_THRESHOLD = 0.99 + + +def cosine_similarity(a, b): + # 计算余弦相似度 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + @dataclass class EmbeddingStoreItem: @@ -64,6 +95,46 @@ class EmbeddingStore: def _get_embedding(self, s: str) -> List[float]: return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) + def get_test_file_path(self): + return EMBEDDING_TEST_FILE + + def save_embedding_test_vectors(self): + """保存测试字符串的嵌入到本地""" + test_vectors = {} + for idx, s in enumerate(EMBEDDING_TEST_STRINGS): + test_vectors[str(idx)] = self._get_embedding(s) + with open(self.get_test_file_path(), "w", encoding="utf-8") as f: + json.dump(test_vectors, f, ensure_ascii=False, indent=2) + + def load_embedding_test_vectors(self): + """加载本地保存的测试字符串嵌入""" + path = self.get_test_file_path() + if not os.path.exists(path): + return None + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + def check_embedding_model_consistency(self): + """校验当前模型与本地嵌入模型是否一致""" + local_vectors = self.load_embedding_test_vectors() + if local_vectors is None: + logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") + self.save_embedding_test_vectors() + return True + for idx, s in enumerate(EMBEDDING_TEST_STRINGS): + local_emb = local_vectors.get(str(idx)) + if local_emb is None: + logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") + self.save_embedding_test_vectors() + return True + new_emb = self._get_embedding(s) + sim = cosine_similarity(local_emb, new_emb) + if sim < EMBEDDING_SIM_THRESHOLD: + logger.error("嵌入模型一致性校验失败") + return False + logger.info("嵌入模型一致性校验通过。") + return True + def batch_insert_strs(self, strs: List[str], times: int) -> None: """向库中存入字符串""" total = len(strs) @@ -216,6 +287,17 @@ class EmbeddingManager: ) self.stored_pg_hashes = set() + def check_all_embedding_model_consistency(self): + """对所有嵌入库做模型一致性校验""" + for store in [ + self.paragraphs_embedding_store, + self.entities_embedding_store, + self.relation_embedding_store, + ]: + if not store.check_embedding_model_consistency(): + return False + return True + def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): """将段落编码存入Embedding库""" self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) @@ -239,6 +321,8 @@ class EmbeddingManager: def load_from_file(self): """从文件加载""" + if not self.check_all_embedding_model_consistency(): + raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") self.paragraphs_embedding_store.load_from_file() self.entities_embedding_store.load_from_file() self.relation_embedding_store.load_from_file() @@ -250,6 +334,8 @@ class EmbeddingManager: raw_paragraphs: Dict[str, str], triple_list_data: Dict[str, List[List[str]]], ): + if not self.check_all_embedding_model_consistency(): + raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") """存储新的数据集""" self._store_pg_into_embedding(raw_paragraphs) self._store_ent_into_embedding(triple_list_data)