mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into PFC-test
commit
099500f8e6
21
bot.py
21
bot.py
|
|
@ -33,6 +33,23 @@ driver = None
|
|||
app = None
|
||||
loop = None
|
||||
|
||||
# shutdown_requested = False # 新增全局变量
|
||||
|
||||
|
||||
async def request_shutdown() -> bool:
|
||||
"""请求关闭程序"""
|
||||
try:
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.run_until_complete(graceful_shutdown())
|
||||
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"请求关闭程序时发生错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def easter_egg():
|
||||
# 彩蛋
|
||||
|
|
@ -252,6 +269,8 @@ if __name__ == "__main__":
|
|||
loop.run_until_complete(graceful_shutdown())
|
||||
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||
# 新增:检测外部请求关闭
|
||||
|
||||
# except Exception as e: # 将主异常捕获移到外层 try...except
|
||||
# logger.error(f"事件循环内发生错误: {str(e)} {str(traceback.format_exc())}")
|
||||
# exit_code = 1
|
||||
|
|
@ -271,5 +290,5 @@ if __name__ == "__main__":
|
|||
loop.close()
|
||||
logger.info("事件循环已关闭")
|
||||
# 在程序退出前暂停,让你有机会看到输出
|
||||
input("按 Enter 键退出...") # <--- 添加这行
|
||||
# input("按 Enter 键退出...") # <--- 添加这行
|
||||
sys.exit(exit_code) # <--- 使用记录的退出码
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from src.heart_flow.heartflow import heartflow
|
||||
from src.heart_flow.sub_heartflow import ChatState
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
|
||||
async def get_all_subheartflow_ids() -> list:
|
||||
|
|
@ -14,3 +17,21 @@ async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatSt
|
|||
if subheartflow:
|
||||
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
|
||||
return False
|
||||
|
||||
|
||||
async def get_subheartflow_cycle_info(subheartflow_id: str, history_len: int) -> dict:
|
||||
"""获取子心流的循环信息"""
|
||||
subheartflow_cycle_info = await heartflow.api_get_subheartflow_cycle_info(subheartflow_id, history_len)
|
||||
logger.debug(f"子心流 {subheartflow_id} 循环信息: {subheartflow_cycle_info}")
|
||||
if subheartflow_cycle_info:
|
||||
return subheartflow_cycle_info
|
||||
else:
|
||||
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
|
||||
return None
|
||||
|
||||
|
||||
async def get_all_states():
|
||||
"""获取所有状态"""
|
||||
all_states = await heartflow.api_get_all_states()
|
||||
logger.debug(f"所有状态: {all_states}")
|
||||
return all_states
|
||||
|
|
|
|||
|
|
@ -1,12 +1,22 @@
|
|||
from fastapi import APIRouter
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
import os
|
||||
import sys
|
||||
|
||||
# from src.heart_flow.heartflow import heartflow
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||
# from src.config.config import BotConfig
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.api.reload_config import reload_config as reload_config_func
|
||||
from src.common.server import global_server
|
||||
from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status
|
||||
from src.api.apiforgui import (
|
||||
get_all_subheartflow_ids,
|
||||
forced_change_subheartflow_status,
|
||||
get_subheartflow_cycle_info,
|
||||
get_all_states,
|
||||
)
|
||||
from src.heart_flow.sub_heartflow import ChatState
|
||||
|
||||
# import uvicorn
|
||||
# import os
|
||||
|
||||
|
|
@ -51,6 +61,42 @@ async def forced_change_subheartflow_status_api(subheartflow_id: str, status: Ch
|
|||
return {"status": "failed"}
|
||||
|
||||
|
||||
@router.get("/stop")
|
||||
async def force_stop_maibot():
|
||||
"""强制停止MAI Bot"""
|
||||
from bot import request_shutdown
|
||||
|
||||
success = await request_shutdown()
|
||||
if success:
|
||||
logger.info("MAI Bot已强制停止")
|
||||
return {"status": "success"}
|
||||
else:
|
||||
logger.error("MAI Bot强制停止失败")
|
||||
return {"status": "failed"}
|
||||
|
||||
|
||||
@router.get("/gui/subheartflow/cycleinfo")
|
||||
async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int):
|
||||
"""获取子心流的循环信息"""
|
||||
cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len)
|
||||
if cycle_info:
|
||||
return {"status": "success", "data": cycle_info}
|
||||
else:
|
||||
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
|
||||
return {"status": "failed", "reason": "subheartflow not found"}
|
||||
|
||||
|
||||
@router.get("/gui/get_all_states")
|
||||
async def get_all_states_api():
|
||||
"""获取所有状态"""
|
||||
all_states = await get_all_states()
|
||||
if all_states:
|
||||
return {"status": "success", "data": all_states}
|
||||
else:
|
||||
logger.warning("获取所有状态失败")
|
||||
return {"status": "failed", "reason": "failed to get all states"}
|
||||
|
||||
|
||||
def start_api_server():
|
||||
"""启动API服务器"""
|
||||
global_server.register_router(router, prefix="/api/v1")
|
||||
|
|
|
|||
|
|
@ -80,7 +80,8 @@ _custom_style_handlers: dict[Tuple[str, str], List[int]] = {} # 记录自定义
|
|||
|
||||
# 获取日志存储根地址
|
||||
current_file_path = Path(__file__).resolve()
|
||||
LOG_ROOT = "logs"
|
||||
ROOT_PATH = os.path.abspath(os.path.join(current_file_path, "..", ".."))
|
||||
LOG_ROOT = str(ROOT_PATH) + "/" + "logs"
|
||||
|
||||
SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower()
|
||||
if SIMPLE_OUTPUT == "true":
|
||||
|
|
|
|||
|
|
@ -67,6 +67,23 @@ class Heartflow:
|
|||
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
|
||||
return await self.subheartflow_manager.force_change_state(subheartflow_id, status)
|
||||
|
||||
async def api_get_all_states(self):
|
||||
"""获取所有状态"""
|
||||
return await self.interest_logger.api_get_all_states()
|
||||
|
||||
async def api_get_subheartflow_cycle_info(self, subheartflow_id: str, history_len: int) -> Optional[dict]:
|
||||
"""获取子心流的循环信息"""
|
||||
subheartflow = await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id)
|
||||
if not subheartflow:
|
||||
logger.warning(f"尝试获取不存在的子心流 {subheartflow_id} 的周期信息")
|
||||
return None
|
||||
heartfc_instance = subheartflow.heart_fc_instance
|
||||
if not heartfc_instance:
|
||||
logger.warning(f"子心流 {subheartflow_id} 没有心流实例,无法获取周期信息")
|
||||
return None
|
||||
|
||||
return heartfc_instance.get_cycle_history(last_n=history_len)
|
||||
|
||||
async def heartflow_start_working(self):
|
||||
"""启动后台任务"""
|
||||
await self.background_task_manager.start_tasks()
|
||||
|
|
|
|||
|
|
@ -158,3 +158,55 @@ class InterestLogger:
|
|||
except Exception as e:
|
||||
logger.error(f"记录状态时发生意外错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def api_get_all_states(self):
|
||||
"""获取主心流和所有子心流的状态。"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
|
||||
# main_mind = self.heartflow.current_mind
|
||||
# 获取 Mai 状态名称
|
||||
mai_state_name = self.heartflow.current_state.get_current_state().name
|
||||
|
||||
all_subflow_states = await self.get_all_subflow_states()
|
||||
|
||||
log_entry_base = {
|
||||
"timestamp": round(current_timestamp, 2),
|
||||
# "main_mind": main_mind,
|
||||
"mai_state": mai_state_name,
|
||||
"subflow_count": len(all_subflow_states),
|
||||
"subflows": [],
|
||||
}
|
||||
|
||||
subflow_details = []
|
||||
items_snapshot = list(all_subflow_states.items())
|
||||
for stream_id, state in items_snapshot:
|
||||
group_name = stream_id
|
||||
try:
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if chat_stream:
|
||||
if chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
elif chat_stream.user_info:
|
||||
group_name = f"私聊_{chat_stream.user_info.user_nickname}"
|
||||
except Exception as e:
|
||||
logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}")
|
||||
|
||||
interest_state = state.get("interest_state", {})
|
||||
|
||||
subflow_entry = {
|
||||
"stream_id": stream_id,
|
||||
"group_name": group_name,
|
||||
"sub_mind": state.get("current_mind", "未知"),
|
||||
"sub_chat_state": state.get("chat_state", "未知"),
|
||||
"interest_level": interest_state.get("interest_level", 0.0),
|
||||
"start_hfc_probability": interest_state.get("start_hfc_probability", 0.0),
|
||||
# "is_above_threshold": interest_state.get("is_above_threshold", False),
|
||||
}
|
||||
subflow_details.append(subflow_entry)
|
||||
|
||||
log_entry_base["subflows"] = subflow_details
|
||||
return subflow_details
|
||||
except Exception as e:
|
||||
logger.error(f"记录状态时发生意外错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
|
|
|||
|
|
@ -28,6 +28,12 @@ from rich.progress import (
|
|||
|
||||
install(extra_lines=3)
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
EMBEDDING_DATA_DIR = (
|
||||
os.path.join(ROOT_PATH, "data", "embedding")
|
||||
if global_config["persistence"]["embedding_data_dir"] is None
|
||||
else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"])
|
||||
)
|
||||
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
|
||||
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
||||
|
||||
# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
|
||||
|
|
@ -288,17 +294,17 @@ class EmbeddingManager:
|
|||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
PG_NAMESPACE,
|
||||
global_config["persistence"]["embedding_data_dir"],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.entities_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
ENT_NAMESPACE,
|
||||
global_config["persistence"]["embedding_data_dir"],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.relation_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
REL_NAMESPACE,
|
||||
global_config["persistence"]["embedding_data_dir"],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.stored_pg_hashes = set()
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ from .lpmmconfig import (
|
|||
|
||||
from .global_logger import logger
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
KG_DIR = (
|
||||
os.path.join(ROOT_PATH, "data/rag")
|
||||
if global_config["persistence"]["rag_data_dir"] is None
|
||||
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
|
||||
)
|
||||
KG_DIR_STR = str(KG_DIR).replace("\\", "/")
|
||||
|
||||
|
||||
class KGManager:
|
||||
def __init__(self):
|
||||
|
|
@ -43,7 +51,7 @@ class KGManager:
|
|||
self.graph = di_graph.DiGraph()
|
||||
|
||||
# 持久化相关
|
||||
self.dir_path = global_config["persistence"]["rag_data_dir"]
|
||||
self.dir_path = KG_DIR_STR
|
||||
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
|
||||
|
|
|
|||
|
|
@ -5,15 +5,13 @@ import platform
|
|||
import os
|
||||
import json
|
||||
import threading
|
||||
from src.common.logger import get_module_logger, LogConfig, REMOTE_STYLE_CONFIG
|
||||
import subprocess
|
||||
|
||||
# from loguru import logger
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
remote_log_config = LogConfig(
|
||||
console_format=REMOTE_STYLE_CONFIG["console_format"],
|
||||
file_format=REMOTE_STYLE_CONFIG["file_format"],
|
||||
)
|
||||
logger = get_module_logger("remote", config=remote_log_config)
|
||||
logger = get_logger("remote")
|
||||
|
||||
# --- 使用向上导航的方式定义路径 ---
|
||||
|
||||
|
|
@ -82,9 +80,76 @@ def get_unique_id():
|
|||
|
||||
# 生成客户端唯一ID
|
||||
def generate_unique_id():
|
||||
# 结合主机名、系统信息和随机UUID生成唯一ID
|
||||
# 基于机器码生成唯一ID,同一台机器上生成的UUID是固定的,只要机器码不变
|
||||
import hashlib
|
||||
|
||||
system_info = platform.system()
|
||||
unique_id = f"{system_info}-{uuid.uuid4()}"
|
||||
machine_code = None
|
||||
|
||||
try:
|
||||
if system_info == "Windows":
|
||||
# 使用wmic命令获取主机UUID(更稳定)
|
||||
result = subprocess.check_output(
|
||||
"wmic csproduct get uuid", shell=True, stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL
|
||||
)
|
||||
lines = result.decode(errors="ignore").splitlines()
|
||||
# 过滤掉空行和表头,只取有效UUID
|
||||
uuids = [line.strip() for line in lines if line.strip() and line.strip().lower() != "uuid"]
|
||||
if uuids:
|
||||
uuid_val = uuids[0]
|
||||
# logger.debug(f"主机UUID: {uuid_val}")
|
||||
# 增加无效值判断
|
||||
if uuid_val and uuid_val.lower() not in ["to be filled by o.e.m.", "none", "", "standard"]:
|
||||
machine_code = uuid_val
|
||||
elif system_info == "Linux":
|
||||
# 优先读取 /etc/machine-id,其次 /var/lib/dbus/machine-id,取第一个非空且内容有效的
|
||||
for path in ["/etc/machine-id", "/var/lib/dbus/machine-id"]:
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
code = f.read().strip()
|
||||
# 只要内容非空且不是全0
|
||||
if code and set(code) != {"0"}:
|
||||
machine_code = code
|
||||
break
|
||||
elif system_info == "Darwin":
|
||||
# macOS: 使用IOPlatformUUID
|
||||
result = subprocess.check_output(
|
||||
"ioreg -rd1 -c IOPlatformExpertDevice | awk '/IOPlatformUUID/'", shell=True
|
||||
)
|
||||
uuid_line = result.decode(errors="ignore")
|
||||
# 解析出 "IOPlatformUUID" = "xxxx-xxxx-xxxx-xxxx"
|
||||
import re
|
||||
|
||||
m = re.search(r'"IOPlatformUUID"\s*=\s*"([^"]+)"', uuid_line)
|
||||
if m:
|
||||
uuid_val = m.group(1)
|
||||
logger.debug(f"IOPlatformUUID: {uuid_val}")
|
||||
if uuid_val and uuid_val.lower() not in ["to be filled by o.e.m.", "none", "", "standard"]:
|
||||
machine_code = uuid_val
|
||||
except Exception as e:
|
||||
logger.debug(f"获取机器码失败: {e}")
|
||||
|
||||
# 如果主板序列号无效,尝试用MAC地址
|
||||
if not machine_code:
|
||||
try:
|
||||
mac = uuid.getnode()
|
||||
if (mac >> 40) % 2 == 0: # 不是本地伪造MAC
|
||||
machine_code = str(mac)
|
||||
except Exception as e:
|
||||
logger.debug(f"获取MAC地址失败: {e}")
|
||||
|
||||
def md5_to_uuid(md5hex):
|
||||
# 将32位md5字符串格式化为8-4-4-4-12的UUID格式
|
||||
return f"{md5hex[0:8]}-{md5hex[8:12]}-{md5hex[12:16]}-{md5hex[16:20]}-{md5hex[20:32]}"
|
||||
|
||||
if machine_code:
|
||||
# print(f"machine_code={machine_code!r}") # 可用于调试
|
||||
md5 = hashlib.md5(machine_code.encode("utf-8")).hexdigest()
|
||||
uuid_str = md5_to_uuid(md5)
|
||||
else:
|
||||
uuid_str = str(uuid.uuid4())
|
||||
|
||||
unique_id = f"{system_info}-{uuid_str}"
|
||||
return unique_id
|
||||
|
||||
|
||||
|
|
@ -175,3 +240,9 @@ def main():
|
|||
|
||||
return heartbeat_thread # 返回线程对象,便于外部控制
|
||||
return None
|
||||
|
||||
|
||||
# --- 测试用例 ---
|
||||
if __name__ == "__main__":
|
||||
print("测试唯一ID生成:")
|
||||
print("唯一ID:", get_unique_id())
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ nonebot-qq="http://127.0.0.1:18002/api/message"
|
|||
[chat] #麦麦的聊天通用设置
|
||||
allow_focus_mode = true # 是否允许专注聊天状态
|
||||
# 是否启用heart_flowC(HFC)模式
|
||||
# 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间),进行主动的观察和回复,并给出回复,比较消耗token
|
||||
# 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间),进行主动的观察和回复,并给出回复,比较消耗token
|
||||
base_normal_chat_num = 8 # 最多允许多少个群进行普通聊天
|
||||
base_focused_chat_num = 5 # 最多允许多少个群进行专注聊天
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue