mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/SnowindMe/MaiBot into dev
commit
f0bddd6dd0
21
bot.py
21
bot.py
|
|
@ -33,6 +33,23 @@ driver = None
|
|||
app = None
|
||||
loop = None
|
||||
|
||||
# shutdown_requested = False # 新增全局变量
|
||||
|
||||
|
||||
async def request_shutdown() -> bool:
|
||||
"""请求关闭程序"""
|
||||
try:
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.run_until_complete(graceful_shutdown())
|
||||
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"请求关闭程序时发生错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def easter_egg():
|
||||
# 彩蛋
|
||||
|
|
@ -252,6 +269,8 @@ if __name__ == "__main__":
|
|||
loop.run_until_complete(graceful_shutdown())
|
||||
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||
# 新增:检测外部请求关闭
|
||||
|
||||
# except Exception as e: # 将主异常捕获移到外层 try...except
|
||||
# logger.error(f"事件循环内发生错误: {str(e)} {str(traceback.format_exc())}")
|
||||
# exit_code = 1
|
||||
|
|
@ -271,5 +290,5 @@ if __name__ == "__main__":
|
|||
loop.close()
|
||||
logger.info("事件循环已关闭")
|
||||
# 在程序退出前暂停,让你有机会看到输出
|
||||
input("按 Enter 键退出...") # <--- 添加这行
|
||||
# input("按 Enter 键退出...") # <--- 添加这行
|
||||
sys.exit(exit_code) # <--- 使用记录的退出码
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from src.heart_flow.heartflow import heartflow
|
||||
from src.heart_flow.sub_heartflow import ChatState
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
|
||||
async def get_all_subheartflow_ids() -> list:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from fastapi import APIRouter
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||
# from src.config.config import BotConfig
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.api.reload_config import reload_config as reload_config_func
|
||||
from src.common.server import global_server
|
||||
from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status
|
||||
from src.api.apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status
|
||||
from src.heart_flow.sub_heartflow import ChatState
|
||||
|
||||
# import uvicorn
|
||||
# import os
|
||||
|
||||
|
|
@ -51,6 +55,20 @@ async def forced_change_subheartflow_status_api(subheartflow_id: str, status: Ch
|
|||
return {"status": "failed"}
|
||||
|
||||
|
||||
@router.get("/stop")
|
||||
async def force_stop_maibot():
|
||||
"""强制停止MAI Bot"""
|
||||
from bot import request_shutdown
|
||||
|
||||
success = await request_shutdown()
|
||||
if success:
|
||||
logger.info("MAI Bot已强制停止")
|
||||
return {"status": "success"}
|
||||
else:
|
||||
logger.error("MAI Bot强制停止失败")
|
||||
return {"status": "failed"}
|
||||
|
||||
|
||||
def start_api_server():
|
||||
"""启动API服务器"""
|
||||
global_server.register_router(router, prefix="/api/v1")
|
||||
|
|
|
|||
|
|
@ -80,7 +80,8 @@ _custom_style_handlers: dict[Tuple[str, str], List[int]] = {} # 记录自定义
|
|||
|
||||
# 获取日志存储根地址
|
||||
current_file_path = Path(__file__).resolve()
|
||||
LOG_ROOT = "logs"
|
||||
ROOT_PATH = os.path.abspath(os.path.join(current_file_path, "..", ".."))
|
||||
LOG_ROOT = str(ROOT_PATH) + "/" + "logs"
|
||||
|
||||
SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower()
|
||||
if SIMPLE_OUTPUT == "true":
|
||||
|
|
|
|||
|
|
@ -28,6 +28,12 @@ from rich.progress import (
|
|||
|
||||
install(extra_lines=3)
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
EMBEDDING_DATA_DIR = (
|
||||
os.path.join(ROOT_PATH, "data", "embedding")
|
||||
if global_config["persistence"]["embedding_data_dir"] is None
|
||||
else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"])
|
||||
)
|
||||
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
|
||||
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
||||
|
||||
# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
|
||||
|
|
@ -288,17 +294,17 @@ class EmbeddingManager:
|
|||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
PG_NAMESPACE,
|
||||
global_config["persistence"]["embedding_data_dir"],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.entities_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
ENT_NAMESPACE,
|
||||
global_config["persistence"]["embedding_data_dir"],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.relation_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
REL_NAMESPACE,
|
||||
global_config["persistence"]["embedding_data_dir"],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.stored_pg_hashes = set()
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ from .lpmmconfig import (
|
|||
|
||||
from .global_logger import logger
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
KG_DIR = (
|
||||
os.path.join(ROOT_PATH, "data/rag")
|
||||
if global_config["persistence"]["rag_data_dir"] is None
|
||||
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
|
||||
)
|
||||
KG_DIR_STR = str(KG_DIR).replace("\\", "/")
|
||||
|
||||
|
||||
class KGManager:
|
||||
def __init__(self):
|
||||
|
|
@ -43,7 +51,7 @@ class KGManager:
|
|||
self.graph = di_graph.DiGraph()
|
||||
|
||||
# 持久化相关
|
||||
self.dir_path = global_config["persistence"]["rag_data_dir"]
|
||||
self.dir_path = KG_DIR_STR
|
||||
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
|
||||
|
|
|
|||
Loading…
Reference in New Issue