From 79b2cb45a9a389e3ccbc67b45e22293a06cc99e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Tue, 6 May 2025 21:02:08 +0800 Subject: [PATCH 1/8] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E5=94=AF?= =?UTF-8?q?=E4=B8=80ID=E7=94=9F=E6=88=90=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=A4=9A=E5=B9=B3=E5=8F=B0=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E6=9C=BA=E5=99=A8=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/remote/remote.py | 85 ++++++++++++++++++++++++++++++++---- 1 file changed, 76 insertions(+), 9 deletions(-) diff --git a/src/plugins/remote/remote.py b/src/plugins/remote/remote.py index 5d880271..e749447a 100644 --- a/src/plugins/remote/remote.py +++ b/src/plugins/remote/remote.py @@ -5,15 +5,12 @@ import platform import os import json import threading -from src.common.logger import get_module_logger, LogConfig, REMOTE_STYLE_CONFIG +import subprocess +# from loguru import logger +from src.common.logger_manager import get_logger from src.config.config import global_config - -remote_log_config = LogConfig( - console_format=REMOTE_STYLE_CONFIG["console_format"], - file_format=REMOTE_STYLE_CONFIG["file_format"], -) -logger = get_module_logger("remote", config=remote_log_config) +logger = get_logger("remote") # --- 使用向上导航的方式定义路径 --- @@ -82,9 +79,74 @@ def get_unique_id(): # 生成客户端唯一ID def generate_unique_id(): - # 结合主机名、系统信息和随机UUID生成唯一ID + # 基于机器码生成唯一ID,同一台机器上生成的UUID是固定的,只要机器码不变 + import hashlib system_info = platform.system() - unique_id = f"{system_info}-{uuid.uuid4()}" + machine_code = None + + try: + if system_info == "Windows": + # 使用wmic命令获取主机UUID(更稳定) + result = subprocess.check_output( + 'wmic csproduct get uuid', shell=True, stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL + ) + lines = result.decode(errors="ignore").splitlines() + # 过滤掉空行和表头,只取有效UUID + uuids = [line.strip() for line in lines if line.strip() and line.strip().lower() != "uuid"] + if uuids: + uuid_val = uuids[0] + # logger.debug(f"主机UUID: {uuid_val}") + # 增加无效值判断 + if uuid_val and uuid_val.lower() not in ["to be filled by o.e.m.", "none", "", "standard"]: + machine_code = uuid_val + elif system_info == "Linux": + # 优先读取 /etc/machine-id,其次 /var/lib/dbus/machine-id,取第一个非空且内容有效的 + for path in ["/etc/machine-id", "/var/lib/dbus/machine-id"]: + if os.path.exists(path): + with open(path, "r") as f: + code = f.read().strip() + # 只要内容非空且不是全0 + if code and set(code) != {"0"}: + machine_code = code + break + elif system_info == "Darwin": + # macOS: 使用IOPlatformUUID + result = subprocess.check_output( + "ioreg -rd1 -c IOPlatformExpertDevice | awk '/IOPlatformUUID/'", shell=True + ) + uuid_line = result.decode(errors="ignore") + # 解析出 "IOPlatformUUID" = "xxxx-xxxx-xxxx-xxxx" + import re + m = re.search(r'"IOPlatformUUID"\s*=\s*"([^"]+)"', uuid_line) + if m: + uuid_val = m.group(1) + logger.debug(f"IOPlatformUUID: {uuid_val}") + if uuid_val and uuid_val.lower() not in ["to be filled by o.e.m.", "none", "", "standard"]: + machine_code = uuid_val + except Exception as e: + logger.debug(f"获取机器码失败: {e}") + + # 如果主板序列号无效,尝试用MAC地址 + if not machine_code: + try: + mac = uuid.getnode() + if (mac >> 40) % 2 == 0: # 不是本地伪造MAC + machine_code = str(mac) + except Exception as e: + logger.debug(f"获取MAC地址失败: {e}") + + def md5_to_uuid(md5hex): + # 将32位md5字符串格式化为8-4-4-4-12的UUID格式 + return f"{md5hex[0:8]}-{md5hex[8:12]}-{md5hex[12:16]}-{md5hex[16:20]}-{md5hex[20:32]}" + + if machine_code: + # print(f"machine_code={machine_code!r}") # 可用于调试 + md5 = hashlib.md5(machine_code.encode("utf-8")).hexdigest() + uuid_str = md5_to_uuid(md5) + else: + uuid_str = str(uuid.uuid4()) + + unique_id = f"{system_info}-{uuid_str}" return unique_id @@ -175,3 +237,8 @@ def main(): return heartbeat_thread # 返回线程对象,便于外部控制 return None + +# --- 测试用例 --- +if __name__ == "__main__": + print("测试唯一ID生成:") + print("唯一ID:", get_unique_id()) From ca55d646e13e35a61c9c366da5119c2fbd3333d8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 6 May 2025 13:02:29 +0000 Subject: [PATCH 2/8] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/remote/remote.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/plugins/remote/remote.py b/src/plugins/remote/remote.py index e749447a..68b02396 100644 --- a/src/plugins/remote/remote.py +++ b/src/plugins/remote/remote.py @@ -6,6 +6,7 @@ import os import json import threading import subprocess + # from loguru import logger from src.common.logger_manager import get_logger from src.config.config import global_config @@ -81,6 +82,7 @@ def get_unique_id(): def generate_unique_id(): # 基于机器码生成唯一ID,同一台机器上生成的UUID是固定的,只要机器码不变 import hashlib + system_info = platform.system() machine_code = None @@ -88,7 +90,7 @@ def generate_unique_id(): if system_info == "Windows": # 使用wmic命令获取主机UUID(更稳定) result = subprocess.check_output( - 'wmic csproduct get uuid', shell=True, stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL + "wmic csproduct get uuid", shell=True, stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL ) lines = result.decode(errors="ignore").splitlines() # 过滤掉空行和表头,只取有效UUID @@ -117,6 +119,7 @@ def generate_unique_id(): uuid_line = result.decode(errors="ignore") # 解析出 "IOPlatformUUID" = "xxxx-xxxx-xxxx-xxxx" import re + m = re.search(r'"IOPlatformUUID"\s*=\s*"([^"]+)"', uuid_line) if m: uuid_val = m.group(1) @@ -238,6 +241,7 @@ def main(): return heartbeat_thread # 返回线程对象,便于外部控制 return None + # --- 测试用例 --- if __name__ == "__main__": print("测试唯一ID生成:") From 1e2cdeeea536bdee8212a7bb8de837fef940c2a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 7 May 2025 00:21:04 +0800 Subject: [PATCH 3/8] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=BC=BA?= =?UTF-8?q?=E5=88=B6=E5=81=9C=E6=AD=A2MAI=20Bot=E7=9A=84API=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3(=E5=8D=8A=E6=88=90=E5=93=81)=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=B5=8C=E5=85=A5=E6=95=B0=E6=8D=AE=E7=9B=AE=E5=BD=95?= =?UTF-8?q?=E8=B7=AF=E5=BE=84=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 23 +++++++++++++++++++- src/api/apiforgui.py | 2 ++ src/api/main.py | 18 +++++++++++++-- src/plugins/knowledge/src/embedding_store.py | 8 ++++--- src/plugins/knowledge/src/kg_manager.py | 7 +++--- 5 files changed, 49 insertions(+), 9 deletions(-) diff --git a/bot.py b/bot.py index 01041629..e3c3cb5d 100644 --- a/bot.py +++ b/bot.py @@ -33,6 +33,22 @@ driver = None app = None loop = None +# shutdown_requested = False # 新增全局变量 + +async def request_shutdown() -> bool: + """请求关闭程序""" + try: + if loop and not loop.is_closed(): + try: + loop.run_until_complete(graceful_shutdown()) + except Exception as ge: # 捕捉优雅关闭时可能发生的错误 + logger.error(f"优雅关闭时发生错误: {ge}") + return False + return True + except Exception as e: + logger.error(f"请求关闭程序时发生错误: {e}") + return False + def easter_egg(): # 彩蛋 @@ -230,6 +246,9 @@ def raw_main(): return MainSystem() + + + if __name__ == "__main__": exit_code = 0 # 用于记录程序最终的退出状态 try: @@ -252,6 +271,8 @@ if __name__ == "__main__": loop.run_until_complete(graceful_shutdown()) except Exception as ge: # 捕捉优雅关闭时可能发生的错误 logger.error(f"优雅关闭时发生错误: {ge}") + # 新增:检测外部请求关闭 + # except Exception as e: # 将主异常捕获移到外层 try...except # logger.error(f"事件循环内发生错误: {str(e)} {str(traceback.format_exc())}") # exit_code = 1 @@ -271,5 +292,5 @@ if __name__ == "__main__": loop.close() logger.info("事件循环已关闭") # 在程序退出前暂停,让你有机会看到输出 - input("按 Enter 键退出...") # <--- 添加这行 + # input("按 Enter 键退出...") # <--- 添加这行 sys.exit(exit_code) # <--- 使用记录的退出码 diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py index 75ef2f8d..7e2460b0 100644 --- a/src/api/apiforgui.py +++ b/src/api/apiforgui.py @@ -1,5 +1,7 @@ from src.heart_flow.heartflow import heartflow from src.heart_flow.sub_heartflow import ChatState +from src.common.logger_manager import get_logger +logger = get_logger("api") async def get_all_subheartflow_ids() -> list: diff --git a/src/api/main.py b/src/api/main.py index 6d7e3c1e..a39dafd5 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -1,12 +1,15 @@ from fastapi import APIRouter from strawberry.fastapi import GraphQLRouter - +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) # from src.config.config import BotConfig from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func from src.common.server import global_server -from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status +from src.api.apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status from src.heart_flow.sub_heartflow import ChatState + # import uvicorn # import os @@ -50,6 +53,17 @@ async def forced_change_subheartflow_status_api(subheartflow_id: str, status: Ch logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败") return {"status": "failed"} +@router.get("/stop") +async def force_stop_maibot(): + """强制停止MAI Bot""" + from bot import request_shutdown + success = await request_shutdown() + if success: + logger.info("MAI Bot已强制停止") + return {"status": "success"} + else: + logger.error("MAI Bot强制停止失败") + return {"status": "failed"} def start_api_server(): """启动API服务器""" diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index d1eb7f90..2a27c539 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -28,6 +28,8 @@ from rich.progress import ( install(extra_lines=3) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) +EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") if global_config["persistence"]["embedding_data_dir"] is None else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"]) +EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 # 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录 @@ -288,17 +290,17 @@ class EmbeddingManager: self.paragraphs_embedding_store = EmbeddingStore( llm_client, PG_NAMESPACE, - global_config["persistence"]["embedding_data_dir"], + EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( llm_client, ENT_NAMESPACE, - global_config["persistence"]["embedding_data_dir"], + EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( llm_client, REL_NAMESPACE, - global_config["persistence"]["embedding_data_dir"], + EMBEDDING_DATA_DIR_STR, ) self.stored_pg_hashes = set() diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index fd922af4..19403f9b 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -30,8 +30,9 @@ from .lpmmconfig import ( ) from .global_logger import logger - - +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) +KG_DIR = os.path.join(ROOT_PATH, "data/rag") if global_config["persistence"]["rag_data_dir"] is None else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"]) +KG_DIR_STR = str(KG_DIR).replace("\\", "/") class KGManager: def __init__(self): # 会被保存的字段 @@ -43,7 +44,7 @@ class KGManager: self.graph = di_graph.DiGraph() # 持久化相关 - self.dir_path = global_config["persistence"]["rag_data_dir"] + self.dir_path = KG_DIR_STR self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml" self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" From afbe4f280ead33169a082b1375423c67f30b93ca Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 6 May 2025 16:21:18 +0000 Subject: [PATCH 4/8] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 16 +++++++--------- src/api/apiforgui.py | 1 + src/api/main.py | 4 ++++ src/plugins/knowledge/src/embedding_store.py | 6 +++++- src/plugins/knowledge/src/kg_manager.py | 9 ++++++++- 5 files changed, 25 insertions(+), 11 deletions(-) diff --git a/bot.py b/bot.py index e3c3cb5d..8cecff75 100644 --- a/bot.py +++ b/bot.py @@ -35,15 +35,16 @@ loop = None # shutdown_requested = False # 新增全局变量 + async def request_shutdown() -> bool: """请求关闭程序""" try: if loop and not loop.is_closed(): - try: - loop.run_until_complete(graceful_shutdown()) - except Exception as ge: # 捕捉优雅关闭时可能发生的错误 - logger.error(f"优雅关闭时发生错误: {ge}") - return False + try: + loop.run_until_complete(graceful_shutdown()) + except Exception as ge: # 捕捉优雅关闭时可能发生的错误 + logger.error(f"优雅关闭时发生错误: {ge}") + return False return True except Exception as e: logger.error(f"请求关闭程序时发生错误: {e}") @@ -246,9 +247,6 @@ def raw_main(): return MainSystem() - - - if __name__ == "__main__": exit_code = 0 # 用于记录程序最终的退出状态 try: @@ -272,7 +270,7 @@ if __name__ == "__main__": except Exception as ge: # 捕捉优雅关闭时可能发生的错误 logger.error(f"优雅关闭时发生错误: {ge}") # 新增:检测外部请求关闭 - + # except Exception as e: # 将主异常捕获移到外层 try...except # logger.error(f"事件循环内发生错误: {str(e)} {str(traceback.format_exc())}") # exit_code = 1 diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py index 7e2460b0..a8027c48 100644 --- a/src/api/apiforgui.py +++ b/src/api/apiforgui.py @@ -1,6 +1,7 @@ from src.heart_flow.heartflow import heartflow from src.heart_flow.sub_heartflow import ChatState from src.common.logger_manager import get_logger + logger = get_logger("api") diff --git a/src/api/main.py b/src/api/main.py index a39dafd5..1f47a57c 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -2,6 +2,7 @@ from fastapi import APIRouter from strawberry.fastapi import GraphQLRouter import os import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) # from src.config.config import BotConfig from src.common.logger_manager import get_logger @@ -53,10 +54,12 @@ async def forced_change_subheartflow_status_api(subheartflow_id: str, status: Ch logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败") return {"status": "failed"} + @router.get("/stop") async def force_stop_maibot(): """强制停止MAI Bot""" from bot import request_shutdown + success = await request_shutdown() if success: logger.info("MAI Bot已强制停止") @@ -65,6 +68,7 @@ async def force_stop_maibot(): logger.error("MAI Bot强制停止失败") return {"status": "failed"} + def start_api_server(): """启动API服务器""" global_server.register_router(router, prefix="/api/v1") diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 2a27c539..cf139ad3 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -28,7 +28,11 @@ from rich.progress import ( install(extra_lines=3) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) -EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") if global_config["persistence"]["embedding_data_dir"] is None else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"]) +EMBEDDING_DATA_DIR = ( + os.path.join(ROOT_PATH, "data", "embedding") + if global_config["persistence"]["embedding_data_dir"] is None + else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"]) +) EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index 19403f9b..ad5df092 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -30,9 +30,16 @@ from .lpmmconfig import ( ) from .global_logger import logger + ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) -KG_DIR = os.path.join(ROOT_PATH, "data/rag") if global_config["persistence"]["rag_data_dir"] is None else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"]) +KG_DIR = ( + os.path.join(ROOT_PATH, "data/rag") + if global_config["persistence"]["rag_data_dir"] is None + else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"]) +) KG_DIR_STR = str(KG_DIR).replace("\\", "/") + + class KGManager: def __init__(self): # 会被保存的字段 From 54eaff8cf21fe7916df3f4ad8b694d360a0a3dd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 7 May 2025 00:27:48 +0800 Subject: [PATCH 5/8] =?UTF-8?q?fix:=20=E4=BF=AE=E6=AD=A3=E6=97=A5=E5=BF=97?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E5=9C=B0=E5=9D=80=EF=BC=8C=E6=94=B9=E4=B8=BA?= =?UTF-8?q?=E7=BB=9D=E5=AF=B9=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/logger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/common/logger.py b/src/common/logger.py index bf82cffa..318d9b37 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -80,7 +80,8 @@ _custom_style_handlers: dict[Tuple[str, str], List[int]] = {} # 记录自定义 # 获取日志存储根地址 current_file_path = Path(__file__).resolve() -LOG_ROOT = "logs" +ROOT_PATH = os.path.abspath(os.path.join(current_file_path, "..", "..")) +LOG_ROOT = str(ROOT_PATH) + "/" + "logs" SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower() if SIMPLE_OUTPUT == "true": From 5ebab2f3f6dc51628bd21bf895511f369885ce03 Mon Sep 17 00:00:00 2001 From: KeepingRunning <1599949878@qq.com> Date: Wed, 7 May 2025 08:18:05 +0800 Subject: [PATCH 6/8] =?UTF-8?q?fix:=20=E5=B0=86=E5=B7=A6=E5=8D=8A=E8=A7=92?= =?UTF-8?q?=E6=8B=AC=E5=8F=B7=E6=94=B9=E4=B8=BA=E5=85=A8=E8=A7=92=E6=8B=AC?= =?UTF-8?q?=E5=8F=B7=EF=BC=8C=E4=BF=9D=E6=8C=81=E6=B3=A8=E9=87=8A=E5=B7=A6?= =?UTF-8?q?=E5=8F=B3=E6=8B=AC=E5=8F=B7=E5=8C=B9=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- template/bot_config_template.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 52280783..8eab299c 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -68,7 +68,7 @@ nonebot-qq="http://127.0.0.1:18002/api/message" [chat] #麦麦的聊天通用设置 allow_focus_mode = true # 是否允许专注聊天状态 # 是否启用heart_flowC(HFC)模式 -# 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间),进行主动的观察和回复,并给出回复,比较消耗token +# 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间),进行主动的观察和回复,并给出回复,比较消耗token base_normal_chat_num = 8 # 最多允许多少个群进行普通聊天 base_focused_chat_num = 5 # 最多允许多少个群进行专注聊天 From 162dc49acd76e3abeec86c38054a6310de80f4ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 7 May 2025 22:08:16 +0800 Subject: [PATCH 7/8] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E5=AD=90=E5=BF=83=E6=B5=81=E5=BE=AA=E7=8E=AF=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E5=92=8C=E6=89=80=E6=9C=89=E7=8A=B6=E6=80=81=E7=9A=84?= =?UTF-8?q?API=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/apiforgui.py | 17 ++++++++++ src/api/main.py | 28 +++++++++++++++-- src/heart_flow/heartflow.py | 18 +++++++++++ src/heart_flow/interest_logger.py | 52 +++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 2 deletions(-) diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py index a8027c48..1860aef7 100644 --- a/src/api/apiforgui.py +++ b/src/api/apiforgui.py @@ -17,3 +17,20 @@ async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatSt if subheartflow: return await heartflow.force_change_subheartflow_status(subheartflow_id, status) return False + +async def get_subheartflow_cycle_info(subheartflow_id: str, history_len: int) -> dict: + """获取子心流的循环信息""" + subheartflow_cycle_info = await heartflow.api_get_subheartflow_cycle_info(subheartflow_id, history_len) + logger.debug(f"子心流 {subheartflow_id} 循环信息: {subheartflow_cycle_info}") + if subheartflow_cycle_info: + return subheartflow_cycle_info + else: + logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") + return None + + +async def get_all_states(): + """获取所有状态""" + all_states = await heartflow.api_get_all_states() + logger.debug(f"所有状态: {all_states}") + return all_states diff --git a/src/api/main.py b/src/api/main.py index 1f47a57c..f5d299d8 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -2,13 +2,18 @@ from fastapi import APIRouter from strawberry.fastapi import GraphQLRouter import os import sys - +# from src.heart_flow.heartflow import heartflow sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) # from src.config.config import BotConfig from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func from src.common.server import global_server -from src.api.apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status +from src.api.apiforgui import ( + get_all_subheartflow_ids, + forced_change_subheartflow_status, + get_subheartflow_cycle_info, + get_all_states, +) from src.heart_flow.sub_heartflow import ChatState # import uvicorn @@ -67,7 +72,26 @@ async def force_stop_maibot(): else: logger.error("MAI Bot强制停止失败") return {"status": "failed"} + +@router.get("/gui/subheartflow/cycleinfo") +async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int): + """获取子心流的循环信息""" + cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len) + if cycle_info: + return {"status": "success", "data": cycle_info} + else: + logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") + return {"status": "failed", "reason": "subheartflow not found"} +@router.get("/gui/get_all_states") +async def get_all_states_api(): + """获取所有状态""" + all_states = await get_all_states() + if all_states: + return {"status": "success", "data": all_states} + else: + logger.warning("获取所有状态失败") + return {"status": "failed", "reason": "failed to get all states"} def start_api_server(): """启动API服务器""" diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index 894247ce..dd58f5cd 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -66,6 +66,24 @@ class Heartflow: """强制改变子心流的状态""" # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 return await self.subheartflow_manager.force_change_state(subheartflow_id, status) + + async def api_get_all_states(self): + """获取所有状态""" + return await self.interest_logger.api_get_all_states() + + + async def api_get_subheartflow_cycle_info(self, subheartflow_id: str, history_len: int) -> Optional[dict]: + """获取子心流的循环信息""" + subheartflow = await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id) + if not subheartflow: + logger.warning(f"尝试获取不存在的子心流 {subheartflow_id} 的周期信息") + return None + heartfc_instance = subheartflow.heart_fc_instance + if not heartfc_instance: + logger.warning(f"子心流 {subheartflow_id} 没有心流实例,无法获取周期信息") + return None + + return heartfc_instance.get_cycle_history(last_n=history_len) async def heartflow_start_working(self): """启动后台任务""" diff --git a/src/heart_flow/interest_logger.py b/src/heart_flow/interest_logger.py index 1fe289b8..9b562156 100644 --- a/src/heart_flow/interest_logger.py +++ b/src/heart_flow/interest_logger.py @@ -158,3 +158,55 @@ class InterestLogger: except Exception as e: logger.error(f"记录状态时发生意外错误: {e}") logger.error(traceback.format_exc()) + + async def api_get_all_states(self): + """获取主心流和所有子心流的状态。""" + try: + current_timestamp = time.time() + + # main_mind = self.heartflow.current_mind + # 获取 Mai 状态名称 + mai_state_name = self.heartflow.current_state.get_current_state().name + + all_subflow_states = await self.get_all_subflow_states() + + log_entry_base = { + "timestamp": round(current_timestamp, 2), + # "main_mind": main_mind, + "mai_state": mai_state_name, + "subflow_count": len(all_subflow_states), + "subflows": [], + } + + subflow_details = [] + items_snapshot = list(all_subflow_states.items()) + for stream_id, state in items_snapshot: + group_name = stream_id + try: + chat_stream = chat_manager.get_stream(stream_id) + if chat_stream: + if chat_stream.group_info: + group_name = chat_stream.group_info.group_name + elif chat_stream.user_info: + group_name = f"私聊_{chat_stream.user_info.user_nickname}" + except Exception as e: + logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}") + + interest_state = state.get("interest_state", {}) + + subflow_entry = { + "stream_id": stream_id, + "group_name": group_name, + "sub_mind": state.get("current_mind", "未知"), + "sub_chat_state": state.get("chat_state", "未知"), + "interest_level": interest_state.get("interest_level", 0.0), + "start_hfc_probability": interest_state.get("start_hfc_probability", 0.0), + # "is_above_threshold": interest_state.get("is_above_threshold", False), + } + subflow_details.append(subflow_entry) + + log_entry_base["subflows"] = subflow_details + return subflow_details + except Exception as e: + logger.error(f"记录状态时发生意外错误: {e}") + logger.error(traceback.format_exc()) \ No newline at end of file From 61f6bf3e7e277b9855a741bbfab2e9ada3b4e111 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 7 May 2025 14:08:29 +0000 Subject: [PATCH 8/8] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/apiforgui.py | 5 +++-- src/api/main.py | 8 ++++++-- src/heart_flow/heartflow.py | 5 ++--- src/heart_flow/interest_logger.py | 2 +- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py index 1860aef7..a266f8e8 100644 --- a/src/api/apiforgui.py +++ b/src/api/apiforgui.py @@ -18,16 +18,17 @@ async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatSt return await heartflow.force_change_subheartflow_status(subheartflow_id, status) return False + async def get_subheartflow_cycle_info(subheartflow_id: str, history_len: int) -> dict: """获取子心流的循环信息""" - subheartflow_cycle_info = await heartflow.api_get_subheartflow_cycle_info(subheartflow_id, history_len) + subheartflow_cycle_info = await heartflow.api_get_subheartflow_cycle_info(subheartflow_id, history_len) logger.debug(f"子心流 {subheartflow_id} 循环信息: {subheartflow_cycle_info}") if subheartflow_cycle_info: return subheartflow_cycle_info else: logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") return None - + async def get_all_states(): """获取所有状态""" diff --git a/src/api/main.py b/src/api/main.py index f5d299d8..4378ff1e 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -2,6 +2,7 @@ from fastapi import APIRouter from strawberry.fastapi import GraphQLRouter import os import sys + # from src.heart_flow.heartflow import heartflow sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) # from src.config.config import BotConfig @@ -9,7 +10,7 @@ from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func from src.common.server import global_server from src.api.apiforgui import ( - get_all_subheartflow_ids, + get_all_subheartflow_ids, forced_change_subheartflow_status, get_subheartflow_cycle_info, get_all_states, @@ -72,7 +73,8 @@ async def force_stop_maibot(): else: logger.error("MAI Bot强制停止失败") return {"status": "failed"} - + + @router.get("/gui/subheartflow/cycleinfo") async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int): """获取子心流的循环信息""" @@ -83,6 +85,7 @@ async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") return {"status": "failed", "reason": "subheartflow not found"} + @router.get("/gui/get_all_states") async def get_all_states_api(): """获取所有状态""" @@ -93,6 +96,7 @@ async def get_all_states_api(): logger.warning("获取所有状态失败") return {"status": "failed", "reason": "failed to get all states"} + def start_api_server(): """启动API服务器""" global_server.register_router(router, prefix="/api/v1") diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index dd58f5cd..2cf7d365 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -66,12 +66,11 @@ class Heartflow: """强制改变子心流的状态""" # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 return await self.subheartflow_manager.force_change_state(subheartflow_id, status) - + async def api_get_all_states(self): """获取所有状态""" return await self.interest_logger.api_get_all_states() - async def api_get_subheartflow_cycle_info(self, subheartflow_id: str, history_len: int) -> Optional[dict]: """获取子心流的循环信息""" subheartflow = await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id) @@ -82,7 +81,7 @@ class Heartflow: if not heartfc_instance: logger.warning(f"子心流 {subheartflow_id} 没有心流实例,无法获取周期信息") return None - + return heartfc_instance.get_cycle_history(last_n=history_len) async def heartflow_start_working(self): diff --git a/src/heart_flow/interest_logger.py b/src/heart_flow/interest_logger.py index 9b562156..fb33a6f6 100644 --- a/src/heart_flow/interest_logger.py +++ b/src/heart_flow/interest_logger.py @@ -209,4 +209,4 @@ class InterestLogger: return subflow_details except Exception as e: logger.error(f"记录状态时发生意外错误: {e}") - logger.error(traceback.format_exc()) \ No newline at end of file + logger.error(traceback.format_exc())