mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/SnowindMe/MaiBot into dev
commit
52acf4ee0a
|
|
@ -174,6 +174,11 @@ def main():
|
|||
embed_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
|
||||
if "嵌入模型与本地存储不一致" in str(e):
|
||||
logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
|
||||
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||
sys.exit(1)
|
||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||
logger.info("Embedding库加载完成")
|
||||
# 初始化KG
|
||||
|
|
|
|||
|
|
@ -1,8 +0,0 @@
|
|||
from fastapi import FastAPI
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
|
||||
|
||||
app.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
from fastapi import APIRouter
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
|
||||
# from src.config.config import BotConfig
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.api.reload_config import reload_config as reload_config_func
|
||||
from src.common.server import global_server
|
||||
# import uvicorn
|
||||
# import os
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
# maiapi = FastAPI()
|
||||
logger.info("API server started.")
|
||||
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
|
||||
|
||||
router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
|
||||
|
||||
|
||||
@router.post("/config/reload")
|
||||
async def reload_config():
|
||||
return await reload_config_func()
|
||||
|
||||
|
||||
def start_api_server():
|
||||
"""启动API服务器"""
|
||||
global_server.register_router(router, prefix="/api/v1")
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
from fastapi import HTTPException
|
||||
from rich.traceback import install
|
||||
from src.config.config import BotConfig
|
||||
import os
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
async def reload_config():
|
||||
try:
|
||||
from src.config import config as config_module
|
||||
|
||||
bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
|
||||
config_module.global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||
return {"status": "reloaded"}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e
|
||||
|
|
@ -808,6 +808,22 @@ INIT_STYLE_CONFIG = {
|
|||
},
|
||||
}
|
||||
|
||||
API_SERVER_STYLE_CONFIG = {
|
||||
"advanced": {
|
||||
"console_format": (
|
||||
"<white>{time:YYYY-MM-DD HH:mm:ss}</white> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<light-yellow>API服务</light-yellow> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": "<level>{time:MM-DD HH:mm}</level> | <light-green>API服务</light-green> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# 根据SIMPLE_OUTPUT选择配置
|
||||
MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"]
|
||||
|
|
@ -878,6 +894,7 @@ CHAT_MESSAGE_STYLE_CONFIG = (
|
|||
)
|
||||
CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"]
|
||||
INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"]
|
||||
API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"]
|
||||
|
||||
|
||||
def is_registered_module(record: dict) -> bool:
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ from src.common.logger import (
|
|||
CHAT_MESSAGE_STYLE_CONFIG,
|
||||
CHAT_IMAGE_STYLE_CONFIG,
|
||||
INIT_STYLE_CONFIG,
|
||||
API_SERVER_STYLE_CONFIG,
|
||||
)
|
||||
|
||||
# 可根据实际需要补充更多模块配置
|
||||
|
|
@ -86,6 +87,7 @@ MODULE_LOGGER_CONFIGS = {
|
|||
"chat_message": CHAT_MESSAGE_STYLE_CONFIG, # 聊天消息
|
||||
"chat_image": CHAT_IMAGE_STYLE_CONFIG, # 聊天图片
|
||||
"init": INIT_STYLE_CONFIG, # 初始化
|
||||
"api": API_SERVER_STYLE_CONFIG, # API服务器
|
||||
# ...如有更多模块,继续添加...
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from .plugins.remote import heartbeat_thread # noqa: F401
|
|||
from .individuality.individuality import Individuality
|
||||
from .common.server import global_server
|
||||
from rich.traceback import install
|
||||
from .api.main import start_api_server
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
|
@ -54,6 +55,9 @@ class MainSystem:
|
|||
self.llm_stats.start()
|
||||
logger.success("LLM统计功能启动成功")
|
||||
|
||||
# 启动API服务器
|
||||
start_api_server()
|
||||
logger.success("API服务器启动成功")
|
||||
# 初始化表情管理器
|
||||
emoji_manager.initialize()
|
||||
logger.success("表情包管理器初始化成功")
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from fastapi import APIRouter, HTTPException
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 创建APIRouter而不是FastAPI实例
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config():
|
||||
try: # TODO: 实现配置重载
|
||||
# bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
|
||||
# BotConfig.reload_config(config_path=bot_config_path)
|
||||
return {"message": "TODO: 实现配置重载", "status": "unimplemented"}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
import requests
|
||||
|
||||
response = requests.post("http://localhost:8080/api/reload-config")
|
||||
print(response.json())
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -25,9 +26,39 @@ from rich.progress import (
|
|||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
||||
|
||||
# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
|
||||
# 这些字符串的嵌入结果应该是固定的,不能随时间变化
|
||||
EMBEDDING_TEST_STRINGS = [
|
||||
"阿卡伊真的太好玩了,神秘性感大女同等着你",
|
||||
"你怎么知道我arc12.64了",
|
||||
"我是蕾缪乐小姐的狗",
|
||||
"关注Oct谢谢喵",
|
||||
"不是w6我不草",
|
||||
"关注千石可乐谢谢喵",
|
||||
"来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵",
|
||||
"关注墨梓柒谢谢喵",
|
||||
"Ciallo~",
|
||||
"来玩巧克甜恋谢谢喵",
|
||||
"水印",
|
||||
"我也在纠结晚饭,铁锅炒鸡听着就香!",
|
||||
"test你妈喵",
|
||||
]
|
||||
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
|
||||
EMBEDDING_SIM_THRESHOLD = 0.99
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
# 计算余弦相似度
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingStoreItem:
|
||||
|
|
@ -64,6 +95,46 @@ class EmbeddingStore:
|
|||
def _get_embedding(self, s: str) -> List[float]:
|
||||
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
||||
|
||||
def get_test_file_path(self):
|
||||
return EMBEDDING_TEST_FILE
|
||||
|
||||
def save_embedding_test_vectors(self):
|
||||
"""保存测试字符串的嵌入到本地"""
|
||||
test_vectors = {}
|
||||
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
||||
test_vectors[str(idx)] = self._get_embedding(s)
|
||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def load_embedding_test_vectors(self):
|
||||
"""加载本地保存的测试字符串嵌入"""
|
||||
path = self.get_test_file_path()
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def check_embedding_model_consistency(self):
|
||||
"""校验当前模型与本地嵌入模型是否一致"""
|
||||
local_vectors = self.load_embedding_test_vectors()
|
||||
if local_vectors is None:
|
||||
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
||||
local_emb = local_vectors.get(str(idx))
|
||||
if local_emb is None:
|
||||
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
new_emb = self._get_embedding(s)
|
||||
sim = cosine_similarity(local_emb, new_emb)
|
||||
if sim < EMBEDDING_SIM_THRESHOLD:
|
||||
logger.error("嵌入模型一致性校验失败")
|
||||
return False
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||
"""向库中存入字符串"""
|
||||
total = len(strs)
|
||||
|
|
@ -216,6 +287,17 @@ class EmbeddingManager:
|
|||
)
|
||||
self.stored_pg_hashes = set()
|
||||
|
||||
def check_all_embedding_model_consistency(self):
|
||||
"""对所有嵌入库做模型一致性校验"""
|
||||
for store in [
|
||||
self.paragraphs_embedding_store,
|
||||
self.entities_embedding_store,
|
||||
self.relation_embedding_store,
|
||||
]:
|
||||
if not store.check_embedding_model_consistency():
|
||||
return False
|
||||
return True
|
||||
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||
"""将段落编码存入Embedding库"""
|
||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||
|
|
@ -239,6 +321,8 @@ class EmbeddingManager:
|
|||
|
||||
def load_from_file(self):
|
||||
"""从文件加载"""
|
||||
if not self.check_all_embedding_model_consistency():
|
||||
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||
self.paragraphs_embedding_store.load_from_file()
|
||||
self.entities_embedding_store.load_from_file()
|
||||
self.relation_embedding_store.load_from_file()
|
||||
|
|
@ -250,6 +334,8 @@ class EmbeddingManager:
|
|||
raw_paragraphs: Dict[str, str],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
):
|
||||
if not self.check_all_embedding_model_consistency():
|
||||
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||
"""存储新的数据集"""
|
||||
self._store_pg_into_embedding(raw_paragraphs)
|
||||
self._store_ent_into_embedding(triple_list_data)
|
||||
|
|
|
|||
Loading…
Reference in New Issue