diff --git a/scripts/import_openie.py b/scripts/import_openie.py
index 2a6e09b7..851cc8b3 100644
--- a/scripts/import_openie.py
+++ b/scripts/import_openie.py
@@ -174,6 +174,11 @@ def main():
embed_manager.load_from_file()
except Exception as e:
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
+ if "嵌入模型与本地存储不一致" in str(e):
+ logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
+ logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
+ # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
+ sys.exit(1)
logger.error("如果你是第一次导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
diff --git a/src/api/__init__.py b/src/api/__init__.py
index f5bc08a6..e69de29b 100644
--- a/src/api/__init__.py
+++ b/src/api/__init__.py
@@ -1,8 +0,0 @@
-from fastapi import FastAPI
-from strawberry.fastapi import GraphQLRouter
-
-app = FastAPI()
-
-graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
-
-app.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
diff --git a/src/api/main.py b/src/api/main.py
new file mode 100644
index 00000000..be225940
--- /dev/null
+++ b/src/api/main.py
@@ -0,0 +1,30 @@
+from fastapi import APIRouter
+from strawberry.fastapi import GraphQLRouter
+
+# from src.config.config import BotConfig
+from src.common.logger_manager import get_logger
+from src.api.reload_config import reload_config as reload_config_func
+from src.common.server import global_server
+# import uvicorn
+# import os
+
+router = APIRouter()
+
+
+logger = get_logger("api")
+
+# maiapi = FastAPI()
+logger.info("API server started.")
+graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
+
+router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
+
+
+@router.post("/config/reload")
+async def reload_config():
+ return await reload_config_func()
+
+
+def start_api_server():
+ """启动API服务器"""
+ global_server.register_router(router, prefix="/api/v1")
diff --git a/src/api/reload_config.py b/src/api/reload_config.py
new file mode 100644
index 00000000..d77cb536
--- /dev/null
+++ b/src/api/reload_config.py
@@ -0,0 +1,19 @@
+from fastapi import HTTPException
+from rich.traceback import install
+from src.config.config import BotConfig
+import os
+
+install(extra_lines=3)
+
+
+async def reload_config():
+ try:
+ from src.config import config as config_module
+
+ bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
+ config_module.global_config = BotConfig.load_config(config_path=bot_config_path)
+ return {"status": "reloaded"}
+ except FileNotFoundError as e:
+ raise HTTPException(status_code=404, detail=str(e)) from e
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e
diff --git a/src/common/logger.py b/src/common/logger.py
index a82c6d88..88fc427f 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -808,6 +808,22 @@ INIT_STYLE_CONFIG = {
},
}
+API_SERVER_STYLE_CONFIG = {
+ "advanced": {
+ "console_format": (
+ "{time:YYYY-MM-DD HH:mm:ss} | "
+ "{level: <8} | "
+ "API服务 | "
+ "{message}"
+ ),
+ "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}",
+ },
+ "simple": {
+ "console_format": "{time:MM-DD HH:mm} | API服务 | {message}",
+ "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}",
+ },
+}
+
# 根据SIMPLE_OUTPUT选择配置
MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"]
@@ -878,6 +894,7 @@ CHAT_MESSAGE_STYLE_CONFIG = (
)
CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"]
INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"]
+API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"]
def is_registered_module(record: dict) -> bool:
diff --git a/src/common/logger_manager.py b/src/common/logger_manager.py
index 5c553838..4c28f82f 100644
--- a/src/common/logger_manager.py
+++ b/src/common/logger_manager.py
@@ -41,6 +41,7 @@ from src.common.logger import (
CHAT_MESSAGE_STYLE_CONFIG,
CHAT_IMAGE_STYLE_CONFIG,
INIT_STYLE_CONFIG,
+ API_SERVER_STYLE_CONFIG,
)
# 可根据实际需要补充更多模块配置
@@ -86,6 +87,7 @@ MODULE_LOGGER_CONFIGS = {
"chat_message": CHAT_MESSAGE_STYLE_CONFIG, # 聊天消息
"chat_image": CHAT_IMAGE_STYLE_CONFIG, # 聊天图片
"init": INIT_STYLE_CONFIG, # 初始化
+ "api": API_SERVER_STYLE_CONFIG, # API服务器
# ...如有更多模块,继续添加...
}
diff --git a/src/main.py b/src/main.py
index 26a56ca2..be71524e 100644
--- a/src/main.py
+++ b/src/main.py
@@ -18,6 +18,7 @@ from .plugins.remote import heartbeat_thread # noqa: F401
from .individuality.individuality import Individuality
from .common.server import global_server
from rich.traceback import install
+from .api.main import start_api_server
install(extra_lines=3)
@@ -54,6 +55,9 @@ class MainSystem:
self.llm_stats.start()
logger.success("LLM统计功能启动成功")
+ # 启动API服务器
+ start_api_server()
+ logger.success("API服务器启动成功")
# 初始化表情管理器
emoji_manager.initialize()
logger.success("表情包管理器初始化成功")
diff --git a/src/plugins/config_reload/__init__.py b/src/plugins/config_reload/__init__.py
deleted file mode 100644
index 8b137891..00000000
--- a/src/plugins/config_reload/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/plugins/config_reload/api.py b/src/plugins/config_reload/api.py
deleted file mode 100644
index 56240b88..00000000
--- a/src/plugins/config_reload/api.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from fastapi import APIRouter, HTTPException
-from rich.traceback import install
-
-install(extra_lines=3)
-
-# 创建APIRouter而不是FastAPI实例
-router = APIRouter()
-
-
-@router.post("/reload-config")
-async def reload_config():
- try: # TODO: 实现配置重载
- # bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
- # BotConfig.reload_config(config_path=bot_config_path)
- return {"message": "TODO: 实现配置重载", "status": "unimplemented"}
- except FileNotFoundError as e:
- raise HTTPException(status_code=404, detail=str(e)) from e
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e
diff --git a/src/plugins/config_reload/test.py b/src/plugins/config_reload/test.py
deleted file mode 100644
index fc4fc1e8..00000000
--- a/src/plugins/config_reload/test.py
+++ /dev/null
@@ -1,4 +0,0 @@
-import requests
-
-response = requests.post("http://localhost:8080/api/reload-config")
-print(response.json())
diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py
index e734f4e9..5ee92a86 100644
--- a/src/plugins/knowledge/src/embedding_store.py
+++ b/src/plugins/knowledge/src/embedding_store.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
import json
import os
+import math
from typing import Dict, List, Tuple
import numpy as np
@@ -25,9 +26,39 @@ from rich.progress import (
)
install(extra_lines=3)
-
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
+# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
+# 这些字符串的嵌入结果应该是固定的,不能随时间变化
+EMBEDDING_TEST_STRINGS = [
+ "阿卡伊真的太好玩了,神秘性感大女同等着你",
+ "你怎么知道我arc12.64了",
+ "我是蕾缪乐小姐的狗",
+ "关注Oct谢谢喵",
+ "不是w6我不草",
+ "关注千石可乐谢谢喵",
+ "来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵",
+ "关注墨梓柒谢谢喵",
+ "Ciallo~",
+ "来玩巧克甜恋谢谢喵",
+ "水印",
+ "我也在纠结晚饭,铁锅炒鸡听着就香!",
+ "test你妈喵",
+]
+EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
+EMBEDDING_SIM_THRESHOLD = 0.99
+
+
+def cosine_similarity(a, b):
+ # 计算余弦相似度
+ dot = sum(x * y for x, y in zip(a, b))
+ norm_a = math.sqrt(sum(x * x for x in a))
+ norm_b = math.sqrt(sum(x * x for x in b))
+ if norm_a == 0 or norm_b == 0:
+ return 0.0
+ return dot / (norm_a * norm_b)
+
@dataclass
class EmbeddingStoreItem:
@@ -64,6 +95,46 @@ class EmbeddingStore:
def _get_embedding(self, s: str) -> List[float]:
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
+ def get_test_file_path(self):
+ return EMBEDDING_TEST_FILE
+
+ def save_embedding_test_vectors(self):
+ """保存测试字符串的嵌入到本地"""
+ test_vectors = {}
+ for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
+ test_vectors[str(idx)] = self._get_embedding(s)
+ with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
+ json.dump(test_vectors, f, ensure_ascii=False, indent=2)
+
+ def load_embedding_test_vectors(self):
+ """加载本地保存的测试字符串嵌入"""
+ path = self.get_test_file_path()
+ if not os.path.exists(path):
+ return None
+ with open(path, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+ def check_embedding_model_consistency(self):
+ """校验当前模型与本地嵌入模型是否一致"""
+ local_vectors = self.load_embedding_test_vectors()
+ if local_vectors is None:
+ logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
+ self.save_embedding_test_vectors()
+ return True
+ for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
+ local_emb = local_vectors.get(str(idx))
+ if local_emb is None:
+ logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
+ self.save_embedding_test_vectors()
+ return True
+ new_emb = self._get_embedding(s)
+ sim = cosine_similarity(local_emb, new_emb)
+ if sim < EMBEDDING_SIM_THRESHOLD:
+ logger.error("嵌入模型一致性校验失败")
+ return False
+ logger.info("嵌入模型一致性校验通过。")
+ return True
+
def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""向库中存入字符串"""
total = len(strs)
@@ -216,6 +287,17 @@ class EmbeddingManager:
)
self.stored_pg_hashes = set()
+ def check_all_embedding_model_consistency(self):
+ """对所有嵌入库做模型一致性校验"""
+ for store in [
+ self.paragraphs_embedding_store,
+ self.entities_embedding_store,
+ self.relation_embedding_store,
+ ]:
+ if not store.check_embedding_model_consistency():
+ return False
+ return True
+
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
@@ -239,6 +321,8 @@ class EmbeddingManager:
def load_from_file(self):
"""从文件加载"""
+ if not self.check_all_embedding_model_consistency():
+ raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
self.paragraphs_embedding_store.load_from_file()
self.entities_embedding_store.load_from_file()
self.relation_embedding_store.load_from_file()
@@ -250,6 +334,8 @@ class EmbeddingManager:
raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]],
):
+ if not self.check_all_embedding_model_consistency():
+ raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
"""存储新的数据集"""
self._store_pg_into_embedding(raw_paragraphs)
self._store_ent_into_embedding(triple_list_data)