From 3935ce817ea14dd648c2e0c66d17cf8dafbaa037 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sat, 29 Nov 2025 14:38:42 +0800 Subject: [PATCH] Ruff Fix & format --- bot.py | 7 +- src/chat/brain_chat/brain_chat.py | 5 +- src/chat/heart_flow/heartFC_chat.py | 7 +- .../message_receive/uni_message_sender.py | 31 +-- src/chat/planner_actions/planner.py | 33 ++- src/chat/utils/chat_message_builder.py | 2 +- src/chat/utils/utils_image.py | 16 +- src/express/expression_learner.py | 6 +- src/express/expression_reflector.py | 81 +++---- src/express/reflect_tracker.py | 68 +++--- .../chat_history_summarizer.py | 20 +- src/jargon/jargon_explainer.py | 37 ++-- src/jargon/jargon_miner.py | 28 ++- src/jargon/jargon_utils.py | 26 +-- src/llm_models/utils_model.py | 6 +- src/main.py | 4 +- src/memory_system/memory_retrieval.py | 9 +- .../retrieval_tools/query_chat_history.py | 16 +- src/plugin_system/base/config_types.py | 55 ++--- src/plugin_system/base/plugin_base.py | 24 +- src/webui/chat_routes.py | 207 ++++++++++-------- src/webui/config_routes.py | 90 ++++---- src/webui/config_schema.py | 12 +- src/webui/emoji_routes.py | 99 +++++---- src/webui/git_mirror_service.py | 4 +- src/webui/knowledge_routes.py | 182 +++++++-------- src/webui/model_routes.py | 113 +++++----- src/webui/plugin_routes.py | 148 ++++++------- src/webui/routers/system.py | 4 +- src/webui/webui_server.py | 14 +- test_edge.py | 8 +- 31 files changed, 678 insertions(+), 684 deletions(-) diff --git a/bot.py b/bot.py index b47ccfc2..cecffde6 100644 --- a/bot.py +++ b/bot.py @@ -78,6 +78,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression # 关闭 WebUI 服务器 try: from src.webui.webui_server import get_webui_server + webui_server = get_webui_server() if webui_server and webui_server._server: await webui_server.shutdown() @@ -236,15 +237,15 @@ if __name__ == "__main__": except KeyboardInterrupt: logger.warning("收到中断信号,正在优雅关闭...") - + # 取消主任务 - if 'main_tasks' in locals() and main_tasks and not main_tasks.done(): + if "main_tasks" in locals() and main_tasks and not main_tasks.done(): main_tasks.cancel() try: loop.run_until_complete(main_tasks) except asyncio.CancelledError: pass - + # 执行优雅关闭 if loop and not loop.is_closed(): try: diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 6eede670..8702248f 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -235,13 +235,13 @@ class BrainChatting: if recent_messages_list is None: recent_messages_list = [] _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError - + # ------------------------------------------------------------------------- # ReflectTracker Check # 在每次回复前检查一次上下文,看是否有反思问题得到了解答 # ------------------------------------------------------------------------- from src.express.reflect_tracker import reflect_tracker_manager - + tracker = reflect_tracker_manager.get_tracker(self.stream_id) if tracker: resolved = await tracker.trigger_tracker() @@ -254,6 +254,7 @@ class BrainChatting: # 检查是否需要提问表达反思 # ------------------------------------------------------------------------- from src.express.expression_reflector import expression_reflector_manager + reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id) asyncio.create_task(reflector.check_and_ask()) diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index bb9c6d76..1af21a4f 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ 
b/src/chat/heart_flow/heartFC_chat.py @@ -400,7 +400,7 @@ class HeartFChatting: # ReflectTracker Check # 在每次回复前检查一次上下文,看是否有反思问题得到了解答 # ------------------------------------------------------------------------- - + reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id) await reflector.check_and_ask() tracker = reflect_tracker_manager.get_tracker(self.stream_id) @@ -410,7 +410,6 @@ class HeartFChatting: reflect_tracker_manager.remove_tracker(self.stream_id) logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.") - start_time = time.time() async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) @@ -427,7 +426,9 @@ class HeartFChatting: # asyncio.create_task(self.chat_history_summarizer.process()) cycle_timers, thinking_id = self.start_cycle() - logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})") + logger.info( + f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})" + ) # 第一步:动作检查 available_actions: Dict[str, ActionInfo] = {} diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 343084eb..13ff3641 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -25,6 +25,7 @@ def get_webui_chat_broadcaster(): if _webui_chat_broadcaster is None: try: from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM + _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM) except ImportError: _webui_chat_broadcaster = (None, None) @@ -43,26 +44,28 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: # WebUI 聊天室消息,通过 WebSocket 广播 import time from src.config.config import global_config - - await chat_manager.broadcast({ - "type": "bot_message", - "content": message.processed_plain_text, - "message_type": "text", - "timestamp": time.time(), - "sender": { - "name": global_config.bot.nickname, - "avatar": None, - "is_bot": True, + + await chat_manager.broadcast( + { + "type": "bot_message", + "content": message.processed_plain_text, + "message_type": "text", + "timestamp": time.time(), + "sender": { + "name": global_config.bot.nickname, + "avatar": None, + "is_bot": True, + }, } - }) - + ) + # 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库 # 无需手动保存 - + if show_log: logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室") return True - + # 直接调用API发送消息 await get_global_api().send_message(message) if show_log: diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index d7c7c792..5c498e30 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -181,8 +181,12 @@ class ActionPlanner: found_ids = set(matches) missing_ids = found_ids - available_ids if missing_ids: - logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...") - logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中") + logger.info( + f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..." 
+ ) + logger.info( + f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中" + ) def _replace(match: re.Match[str]) -> str: msg_id = match.group(0) @@ -234,17 +238,11 @@ class ActionPlanner: target_message = message_id_list[-1][1] logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message") - if ( - action != "no_reply" - and target_message is not None - and self._is_message_from_self(target_message) - ): + if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message): logger.info( f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply" ) - reasoning = ( - f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}" - ) + reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}" action = "no_reply" target_message = None @@ -295,10 +293,9 @@ class ActionPlanner: def _is_message_from_self(self, message: "DatabaseMessages") -> bool: """判断消息是否由机器人自身发送""" try: - return ( - str(message.user_info.user_id) == str(global_config.bot.qq_account) - and (message.user_info.platform or "") == (global_config.bot.platform or "") - ) + return str(message.user_info.user_id) == str(global_config.bot.qq_account) and ( + message.user_info.platform or "" + ) == (global_config.bot.platform or "") except AttributeError: logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段") return False @@ -780,20 +777,20 @@ class ActionPlanner: json_content_start = json_start_pos + 7 # ```json的长度 # 提取从```json之后到内容结尾的所有内容 incomplete_json_str = content[json_content_start:].strip() - + # 提取JSON之前的内容作为推理文本 if json_start_pos > 0: reasoning_content = content[:json_start_pos].strip() reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE) reasoning_content = reasoning_content.strip() - + if incomplete_json_str: try: # 清理可能的注释和格式问题 json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str) json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) json_str = json_str.strip() - + if json_str: # 尝试按行分割,每行可能是一个JSON对象 lines = [line.strip() for line in json_str.split("\n") if line.strip()] @@ -808,7 +805,7 @@ class ActionPlanner: json_objects.append(item) except json.JSONDecodeError: pass - + # 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组 if not json_objects: try: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index b45a53b3..5592e6cf 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -959,7 +959,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: b header = f"[{i + 1}] {anon_name}说 " else: header = f"{anon_name}说 " - + output_lines.append(header) stripped_line = content.strip() if stripped_line: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 99b00204..a244fecb 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -130,12 +130,10 @@ class ImageManager: try: # 清理Images表中type为emoji的记录 deleted_images = Images.delete().where(Images.type == "emoji").execute() - + # 清理ImageDescriptions表中type为emoji的记录 - deleted_descriptions = ( - ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute() - ) - + deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute() + total_deleted = deleted_images + deleted_descriptions if total_deleted > 0: logger.info( @@ -194,10 
+192,14 @@ class ImageManager: if cache_record: # 优先使用情感标签,如果没有则使用详细描述 if cache_record.emotion_tags: - logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...") + logger.info( + f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..." + ) return f"[表情包:{cache_record.emotion_tags}]" elif cache_record.description: - logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...") + logger.info( + f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..." + ) return f"[表情包:{cache_record.description}]" except Exception as e: logger.debug(f"查询EmojiDescriptionCache时出错: {e}") diff --git a/src/express/expression_learner.py b/src/express/expression_learner.py index 60aa609a..76cc8408 100644 --- a/src/express/expression_learner.py +++ b/src/express/expression_learner.py @@ -226,19 +226,19 @@ class ExpressionLearner: match_responses = [] try: response = response.strip() - + # 尝试提取JSON代码块(如果存在) json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) if matches: response = matches[0].strip() - + # 移除可能的markdown代码块标记(如果没有找到```json,但可能有```) if not matches: response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE) response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE) response = response.strip() - + # 检查是否已经是标准JSON数组格式 if response.startswith("[") and response.endswith("]"): match_responses = json.loads(response) diff --git a/src/express/expression_reflector.py b/src/express/expression_reflector.py index ecd6a822..0418b435 100644 --- a/src/express/expression_reflector.py +++ b/src/express/expression_reflector.py @@ -13,21 +13,21 @@ logger = get_logger("expression_reflector") class ExpressionReflector: """表达反思器,管理单个聊天流的表达反思提问""" - + def __init__(self, chat_id: str): self.chat_id = chat_id self.last_ask_time: float = 0.0 - + async def check_and_ask(self) -> bool: """ 检查是否需要提问表达反思,如果需要则提问 - + Returns: bool: 是否执行了提问 """ try: logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})") - + if not global_config.expression.reflect: logger.debug(f"[Expression Reflection] 表达反思功能未启用,跳过") return False @@ -48,7 +48,7 @@ class ExpressionReflector: allow_reflect_chat_ids.append(parsed_chat_id) else: logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}") - + if self.chat_id not in allow_reflect_chat_ids: logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过") return False @@ -56,17 +56,21 @@ class ExpressionReflector: # 检查上一次提问时间 current_time = time.time() time_since_last_ask = current_time - self.last_ask_time - + # 5-10分钟间隔,随机选择 min_interval = 10 * 60 # 5分钟 max_interval = 15 * 60 # 10分钟 interval = random.uniform(min_interval, max_interval) - - logger.info(f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask/60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval/60:.2f}分钟)") - + + logger.info( + f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)" + ) + if time_since_last_ask < interval: remaining_time = interval - time_since_last_ask - logger.info(f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time/60:.2f}分钟),跳过") + logger.info( + f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过" + ) return False # 检查是否已经有针对该 
Operator 的 Tracker 在运行 @@ -78,21 +82,22 @@ class ExpressionReflector: # 获取未检查的表达 try: logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达") - expressions = (Expression - .select() - .where((Expression.checked == False) & (Expression.rejected == False)) - .limit(50)) - + expressions = ( + Expression.select().where((Expression.checked == False) & (Expression.rejected == False)).limit(50) + ) + expr_list = list(expressions) logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达") - + if not expr_list: logger.info(f"[Expression Reflection] 没有可用的表达,跳过") return False target_expr: Expression = random.choice(expr_list) - logger.info(f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}") - + logger.info( + f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}" + ) + # 生成询问文本 ask_text = _generate_ask_text(target_expr) if not ask_text: @@ -102,31 +107,33 @@ class ExpressionReflector: logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问") # 发送给 Operator await _send_to_operator(operator_config, ask_text, target_expr) - + # 更新上一次提问时间 self.last_ask_time = current_time logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}") - + return True - + except Exception as e: logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") import traceback + logger.error(traceback.format_exc()) return False except Exception as e: logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") import traceback + logger.error(traceback.format_exc()) return False class ExpressionReflectorManager: """表达反思管理器,管理多个聊天流的表达反思实例""" - + def __init__(self): self.reflectors: Dict[str, ExpressionReflector] = {} - + def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector: """获取或创建指定聊天流的表达反思实例""" if chat_id not in self.reflectors: @@ -141,6 +148,7 @@ expression_reflector_manager = ExpressionReflectorManager() async def _check_tracker_exists(operator_config: str) -> bool: """检查指定 Operator 是否已有活跃的 Tracker""" from src.express.reflect_tracker import reflect_tracker_manager + chat_manager = get_chat_manager() chat_stream = None @@ -150,12 +158,12 @@ async def _check_tracker_exists(operator_config: str) -> bool: platform = parts[0] id_str = parts[1] stream_type = parts[2] - + user_info = None group_info = None - + from maim_message import UserInfo, GroupInfo - + if stream_type == "group": group_info = GroupInfo(group_id=id_str, platform=platform) user_info = UserInfo(user_id="system", user_nickname="System", platform=platform) @@ -203,12 +211,12 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression): platform = parts[0] id_str = parts[1] stream_type = parts[2] - + user_info = None group_info = None - + from maim_message import UserInfo, GroupInfo - + if stream_type == "group": group_info = GroupInfo(group_id=id_str, platform=platform) user_info = UserInfo(user_id="system", user_nickname="System", platform=platform) @@ -232,20 +240,13 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression): return stream_id = chat_stream.stream_id - + # 注册 Tracker from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager - + tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time()) reflect_tracker_manager.add_tracker(stream_id, tracker) - + # 发送消息 - await send_api.text_to_stream( - text=text, - stream_id=stream_id, - typing=True - ) + await 
send_api.text_to_stream(text=text, stream_id=stream_id, typing=True) logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}") - - - diff --git a/src/express/reflect_tracker.py b/src/express/reflect_tracker.py index 7984e299..8973eaf4 100644 --- a/src/express/reflect_tracker.py +++ b/src/express/reflect_tracker.py @@ -17,21 +17,20 @@ if TYPE_CHECKING: logger = get_logger("reflect_tracker") + class ReflectTracker: def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float): self.chat_stream = chat_stream self.expression = expression self.created_time = created_time # self.message_count = 0 # Replaced by checking message list length - self.last_check_msg_count = 0 + self.last_check_msg_count = 0 self.max_message_count = 30 self.max_duration = 15 * 60 # 15 minutes - + # LLM for judging response - self.judge_model = LLMRequest( - model_set=model_config.model_task_config.utils, request_type="reflect.tracker" - ) - + self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker") + self._init_prompts() def _init_prompts(self): @@ -72,16 +71,16 @@ class ReflectTracker: if time.time() - self.created_time > self.max_duration: logger.info(f"ReflectTracker for expr {self.expression.id} timed out (duration).") return True - + # Fetch messages since creation msg_list = get_raw_msg_by_timestamp_with_chat( chat_id=self.chat_stream.stream_id, timestamp_start=self.created_time, timestamp_end=time.time(), ) - + current_msg_count = len(msg_list) - + # Check message limit if current_msg_count > self.max_message_count: logger.info(f"ReflectTracker for expr {self.expression.id} timed out (message count).") @@ -90,9 +89,9 @@ class ReflectTracker: # If no new messages since last check, skip if current_msg_count <= self.last_check_msg_count: return False - + self.last_check_msg_count = current_msg_count - + # Build context block # Use simple readable format context_block = build_readable_messages( @@ -109,78 +108,83 @@ class ReflectTracker: "reflect_judge_prompt", situation=self.expression.situation, style=self.expression.style, - context_block=context_block + context_block=context_block, ) - + logger.info(f"ReflectTracker LLM Prompt: {prompt}") - + response, _ = await self.judge_model.generate_response_async(prompt, temperature=0.1) - + logger.info(f"ReflectTracker LLM Response: {response}") - + # Parse JSON import json import re from json_repair import repair_json - + json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) if not matches: # Try to parse raw response if no code block matches = [response] - + json_obj = json.loads(repair_json(matches[0])) - + judgment = json_obj.get("judgment") - + if judgment == "Approve": self.expression.checked = True self.expression.rejected = False self.expression.save() logger.info(f"Expression {self.expression.id} approved by operator.") return True - + elif judgment == "Reject": self.expression.checked = True corrected_situation = json_obj.get("corrected_situation") corrected_style = json_obj.get("corrected_style") - + # 检查是否有更新 has_update = bool(corrected_situation or corrected_style) - + if corrected_situation: self.expression.situation = corrected_situation if corrected_style: self.expression.style = corrected_style - + # 如果拒绝但未更新,标记为 rejected=1 if not has_update: self.expression.rejected = True else: self.expression.rejected = False - + self.expression.save() - + if has_update: - logger.info(f"Expression 
{self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}") + logger.info( + f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}" + ) else: - logger.info(f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1.") + logger.info( + f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1." + ) return True - + elif judgment == "Ignore": logger.info(f"ReflectTracker for expr {self.expression.id} judged as Ignore.") return False - + except Exception as e: logger.error(f"Error in ReflectTracker check: {e}") return False return False + # Global manager for trackers class ReflectTrackerManager: def __init__(self): - self.trackers: Dict[str, ReflectTracker] = {} # chat_id -> tracker + self.trackers: Dict[str, ReflectTracker] = {} # chat_id -> tracker def add_tracker(self, chat_id: str, tracker: ReflectTracker): self.trackers[chat_id] = tracker @@ -192,5 +196,5 @@ class ReflectTrackerManager: if chat_id in self.trackers: del self.trackers[chat_id] -reflect_tracker_manager = ReflectTrackerManager() +reflect_tracker_manager = ReflectTrackerManager() diff --git a/src/hippo_memorizer/chat_history_summarizer.py b/src/hippo_memorizer/chat_history_summarizer.py index 3f1b62e0..840f349d 100644 --- a/src/hippo_memorizer/chat_history_summarizer.py +++ b/src/hippo_memorizer/chat_history_summarizer.py @@ -315,7 +315,9 @@ class ChatHistorySummarizer: before_count = len(self.current_batch.messages) self.current_batch.messages.extend(new_messages) self.current_batch.end_time = current_time - logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息") + logger.info( + f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息" + ) # 更新批次后持久化 self._persist_topic_cache() else: @@ -361,9 +363,7 @@ class ChatHistorySummarizer: else: time_str = f"{time_since_last_check / 3600:.1f}小时" - logger.info( - f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}" - ) + logger.info(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}") # 检查“话题检查”触发条件 should_check = False @@ -413,7 +413,7 @@ class ChatHistorySummarizer: # 说明 bot 没有参与这段对话,不应该记录 bot_user_id = str(global_config.bot.qq_account) has_bot_message = False - + for msg in messages: if msg.user_info.user_id == bot_user_id: has_bot_message = True @@ -426,7 +426,9 @@ class ChatHistorySummarizer: return # 2. 构造编号后的消息字符串和参与者信息 - numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages) + numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = ( + self._build_numbered_messages_for_llm(messages) + ) # 3. 
调用 LLM 识别话题,并得到 topic -> indices existing_topics = list(self.topic_cache.keys()) @@ -588,9 +590,7 @@ class ChatHistorySummarizer: if not numbered_lines: return False, {} - history_topics_block = ( - "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)" - ) + history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)" messages_block = "\n".join(numbered_lines) prompt = await global_prompt_manager.format_prompt( @@ -607,6 +607,7 @@ class ChatHistorySummarizer: ) import re + logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}") logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}") @@ -895,4 +896,3 @@ class ChatHistorySummarizer: init_prompt() - diff --git a/src/jargon/jargon_explainer.py b/src/jargon/jargon_explainer.py index 251af098..67a008b5 100644 --- a/src/jargon/jargon_explainer.py +++ b/src/jargon/jargon_explainer.py @@ -44,9 +44,7 @@ class JargonExplainer: request_type="jargon.explain", ) - def match_jargon_from_messages( - self, messages: List[Any] - ) -> List[Dict[str, str]]: + def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]: """ 通过直接匹配数据库中的jargon字符串来提取黑话 @@ -57,7 +55,7 @@ class JargonExplainer: List[Dict[str, str]]: 提取到的黑话列表,每个元素包含content """ start_time = time.time() - + if not messages: return [] @@ -67,8 +65,10 @@ class JargonExplainer: # 跳过机器人自己的消息 if is_bot_message(msg): continue - - msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip() + + msg_text = ( + getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "" + ).strip() if msg_text: message_texts.append(msg_text) @@ -79,9 +79,7 @@ class JargonExplainer: combined_text = " ".join(message_texts) # 查询所有有meaning的jargon记录 - query = Jargon.select().where( - (Jargon.meaning.is_null(False)) & (Jargon.meaning != "") - ) + query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != "")) # 根据all_global配置决定查询逻辑 if global_config.jargon.all_global: @@ -98,7 +96,7 @@ class JargonExplainer: # 执行查询并匹配 matched_jargon: Dict[str, Dict[str, str]] = {} query_time = time.time() - + for jargon in query: content = jargon.content or "" if not content or not content.strip(): @@ -123,13 +121,13 @@ class JargonExplainer: pattern = re.escape(content) # 使用单词边界或中文字符边界来匹配,避免部分匹配 # 对于中文,使用Unicode字符类;对于英文,使用单词边界 - if re.search(r'[\u4e00-\u9fff]', content): + if re.search(r"[\u4e00-\u9fff]", content): # 包含中文,使用更宽松的匹配 search_pattern = pattern else: # 纯英文/数字,使用单词边界 - search_pattern = r'\b' + pattern + r'\b' - + search_pattern = r"\b" + pattern + r"\b" + if re.search(search_pattern, combined_text, re.IGNORECASE): # 找到匹配,记录(去重) if content not in matched_jargon: @@ -139,7 +137,7 @@ class JargonExplainer: total_time = match_time - start_time query_duration = query_time - start_time match_duration = match_time - query_time - + logger.info( f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, " f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话" @@ -147,9 +145,7 @@ class JargonExplainer: return list(matched_jargon.values()) - async def explain_jargon( - self, messages: List[Any], chat_context: str - ) -> Optional[str]: + async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]: """ 解释上下文中的黑话 @@ -183,7 +179,7 @@ class JargonExplainer: jargon_explanations: List[str] = [] for entry in jargon_list: content = entry["content"] - + # 根据是否开启全局黑话,决定查询方式 if global_config.jargon.all_global: # 
开启全局黑话:查询所有is_global=True的记录 @@ -239,9 +235,7 @@ class JargonExplainer: return summary -async def explain_jargon_in_context( - chat_id: str, messages: List[Any], chat_context: str -) -> Optional[str]: +async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]: """ 解释上下文中的黑话(便捷函数) @@ -255,4 +249,3 @@ async def explain_jargon_in_context( """ explainer = JargonExplainer(chat_id) return await explainer.explain_jargon(messages, chat_context) - diff --git a/src/jargon/jargon_miner.py b/src/jargon/jargon_miner.py index 0e25af57..77cb15ce 100644 --- a/src/jargon/jargon_miner.py +++ b/src/jargon/jargon_miner.py @@ -17,20 +17,18 @@ from src.chat.utils.chat_message_builder import ( ) from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.jargon.jargon_utils import ( - is_bot_message, - build_context_paragraph, - contains_bot_self_name, - parse_chat_id_list, + is_bot_message, + build_context_paragraph, + contains_bot_self_name, + parse_chat_id_list, chat_id_list_contains, - update_chat_id_list + update_chat_id_list, ) logger = get_logger("jargon") - - def _init_prompt() -> None: prompt_str = """ **聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID** @@ -126,7 +124,6 @@ _init_prompt() _init_inference_prompts() - def _should_infer_meaning(jargon_obj: Jargon) -> bool: """ 判断是否需要进行含义推断 @@ -211,7 +208,9 @@ class JargonMiner: processed_pairs = set() for idx, msg in enumerate(messages): - msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip() + msg_text = ( + getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "" + ).strip() if not msg_text or is_bot_message(msg): continue @@ -270,7 +269,7 @@ class JargonMiner: prompt1 = await global_prompt_manager.format_prompt( "jargon_inference_with_context_prompt", content=content, - bot_name = global_config.bot.nickname, + bot_name=global_config.bot.nickname, raw_content_list=raw_content_text, ) @@ -588,7 +587,6 @@ class JargonMiner: content = entry["content"] raw_content_list = entry["raw_content"] # 已经是列表 - try: # 查询所有content匹配的记录 query = Jargon.select().where(Jargon.content == content) @@ -782,15 +780,15 @@ def search_jargon( # 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含 if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id): continue - + # 只返回有meaning的记录 if not jargon.meaning or jargon.meaning.strip() == "": continue - + results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""}) - + # 达到限制数量后停止 if len(results) >= limit: break - return results \ No newline at end of file + return results diff --git a/src/jargon/jargon_utils.py b/src/jargon/jargon_utils.py index f17889f4..56fe13ad 100644 --- a/src/jargon/jargon_utils.py +++ b/src/jargon/jargon_utils.py @@ -13,19 +13,20 @@ from src.chat.utils.utils import parse_platform_accounts logger = get_logger("jargon") + def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]: """ 解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表) - + Args: chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式) - + Returns: List[List[Any]]: 格式为 [[chat_id, count], ...] 
的列表 """ if not chat_id_value: return [] - + # 如果是字符串,尝试解析为JSON if isinstance(chat_id_value, str): # 尝试解析JSON @@ -54,12 +55,12 @@ def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]: def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]: """ 更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目 - + Args: chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...] target_chat_id: 要更新或添加的chat_id increment: 增加的计数,默认为1 - + Returns: List[List[Any]]: 更新后的chat_id列表 """ @@ -74,22 +75,22 @@ def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, incr item.append(increment) found = True break - + if not found: # 未找到,添加新条目 chat_id_list.append([target_chat_id, increment]) - + return chat_id_list def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool: """ 检查chat_id列表中是否包含指定的chat_id - + Args: chat_id_list: chat_id列表,格式为 [[chat_id, count], ...] target_chat_id: 要查找的chat_id - + Returns: bool: 如果包含则返回True """ @@ -168,10 +169,7 @@ def is_bot_message(msg: Any) -> bool: .strip() .lower() ) - user_id = ( - str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "") - .strip() - ) + user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip() if not platform or not user_id: return False @@ -196,4 +194,4 @@ def is_bot_message(msg: Any) -> bool: bot_accounts[plat] = account bot_account = bot_accounts.get(platform) - return bool(bot_account and user_id == bot_account) \ No newline at end of file + return bool(bot_account and user_id == bot_account) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 74bbffaf..4f1725fd 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -338,8 +338,10 @@ class LLMRequest: if e.__cause__: original_error_type = type(e.__cause__).__name__ original_error_msg = str(e.__cause__) - original_error_info = f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}" - + original_error_info = ( + f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}" + ) + retry_remain -= 1 if retry_remain <= 0: logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。{original_error_info}") diff --git a/src/main.py b/src/main.py index 02702f2c..23e34d61 100644 --- a/src/main.py +++ b/src/main.py @@ -56,7 +56,7 @@ class MainSystem: from src.webui.webui_server import get_webui_server self.webui_server = get_webui_server() - + if webui_mode == "development": logger.info("📝 WebUI 开发模式已启用") logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001") @@ -66,7 +66,7 @@ class MainSystem: logger.info("✅ WebUI 生产模式已启用") logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001") logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build") - + except Exception as e: logger.error(f"❌ 初始化 WebUI 服务器失败: {e}") diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 5820c206..a4de5f26 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -296,7 +296,6 @@ def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]: if not content: continue - if not global_config.jargon.all_global and not jargon.is_global: chat_id_list = parse_chat_id_list(jargon.chat_id) if not chat_id_list_contains(chat_id_list, chat_id): @@ -586,9 +585,7 @@ async def _react_agent_solve_question( step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}}) 
step["observations"] = ["从LLM输出内容中检测到found_answer"] thinking_steps.append(step) - logger.info( - f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}" - ) + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}") return True, found_answer_content, thinking_steps, False if not_enough_info_reason: @@ -1016,9 +1013,7 @@ async def build_memory_retrieval_prompt( if question_results: retrieved_memory = "\n\n".join(question_results) - logger.info( - f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆" - ) + logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆") return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" else: logger.debug("所有问题均未找到答案") diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index f5216de9..d7131505 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -54,7 +54,9 @@ async def search_chat_history( if record.participants: try: participants_data = ( - json.loads(record.participants) if isinstance(record.participants, str) else record.participants + json.loads(record.participants) + if isinstance(record.participants, str) + else record.participants ) if isinstance(participants_data, list): participants_list = [str(p).lower() for p in participants_data] @@ -156,9 +158,7 @@ async def search_chat_history( # 添加关键词 if record.keywords: try: - keywords_data = ( - json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords - ) + keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords if isinstance(keywords_data, list) and keywords_data: keywords_str = "、".join([str(k) for k in keywords_data]) result_parts.append(f"关键词:{keywords_str}") @@ -208,9 +208,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str: return "未提供有效的记忆ID" # 查询记录 - query = ChatHistory.select().where( - (ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list)) - ) + query = ChatHistory.select().where((ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list))) records = list(query.order_by(ChatHistory.start_time.desc())) if not records: @@ -256,9 +254,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str: # 添加关键词 if record.keywords: try: - keywords_data = ( - json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords - ) + keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords if isinstance(keywords_data, list) and keywords_data: keywords_str = "、".join([str(k) for k in keywords_data]) result_parts.append(f"关键词:{keywords_str}") diff --git a/src/plugin_system/base/config_types.py b/src/plugin_system/base/config_types.py index ef0f656c..5979ca54 100644 --- a/src/plugin_system/base/config_types.py +++ b/src/plugin_system/base/config_types.py @@ -12,12 +12,12 @@ from dataclasses import dataclass, field class ConfigField: """ 配置字段定义 - + 用于定义插件配置项的元数据,支持类型验证、UI 渲染等功能。 - + 基础示例: ConfigField(type=str, default="", description="API密钥") - + 完整示例: ConfigField( type=str, @@ -73,9 +73,9 @@ class ConfigField: def get_ui_type(self) -> str: """ 获取 UI 控件类型 - + 如果指定了 input_type 则直接返回,否则根据 type 和 choices 自动推断。 - + Returns: 控件类型字符串 """ @@ -103,7 +103,7 @@ class ConfigField: def to_dict(self) -> Dict[str, Any]: """ 转换为可序列化的字典(用于 API 传输) - + Returns: 
包含所有配置信息的字典 """ @@ -139,9 +139,9 @@ class ConfigField: class ConfigSection: """ 配置节定义 - + 用于描述配置文件中一个 section 的元数据。 - + 示例: ConfigSection( title="API配置", @@ -150,6 +150,7 @@ class ConfigSection: order=1 ) """ + title: str # 显示标题 description: Optional[str] = None # 详细描述 icon: Optional[str] = None # 图标名称 @@ -171,9 +172,9 @@ class ConfigSection: class ConfigTab: """ 配置标签页定义 - + 用于将多个 section 组织到一个标签页中。 - + 示例: ConfigTab( id="general", @@ -182,6 +183,7 @@ class ConfigTab: sections=["plugin", "api"] ) """ + id: str # 标签页 ID title: str # 显示标题 sections: List[str] = field(default_factory=list) # 包含的 section 名称列表 @@ -201,18 +203,18 @@ class ConfigTab: } -@dataclass +@dataclass class ConfigLayout: """ 配置页面布局定义 - + 用于定义插件配置页面的整体布局结构。 - + 布局类型: - "auto": 自动布局,sections 作为折叠面板显示 - "tabs": 标签页布局 - "pages": 分页布局(左侧导航 + 右侧内容) - + 简单示例(标签页布局): ConfigLayout( type="tabs", @@ -222,9 +224,10 @@ class ConfigLayout: ] ) """ + type: str = "auto" # 布局类型: auto, tabs, pages tabs: List[ConfigTab] = field(default_factory=list) # 标签页列表 - + def to_dict(self) -> Dict[str, Any]: """转换为可序列化的字典""" return { @@ -234,37 +237,27 @@ class ConfigLayout: def section_meta( - title: str, - description: Optional[str] = None, - icon: Optional[str] = None, - collapsed: bool = False, - order: int = 0 + title: str, description: Optional[str] = None, icon: Optional[str] = None, collapsed: bool = False, order: int = 0 ) -> Union[str, ConfigSection]: """ 便捷函数:创建 section 元数据 - + 可以在 config_section_descriptions 中使用,提供比纯字符串更丰富的信息。 - + Args: title: 显示标题 description: 详细描述 icon: 图标名称 collapsed: 默认是否折叠 order: 排序权重 - + Returns: ConfigSection 实例 - + 示例: config_section_descriptions = { "api": section_meta("API配置", icon="cloud", order=1), "debug": section_meta("调试设置", collapsed=True, order=99), } """ - return ConfigSection( - title=title, - description=description, - icon=icon, - collapsed=collapsed, - order=order - ) + return ConfigSection(title=title, description=description, icon=icon, collapsed=collapsed, order=order) diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 1fe99b8a..30d33ec2 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -574,14 +574,14 @@ class PluginBase(ABC): def get_webui_config_schema(self) -> Dict[str, Any]: """ 获取 WebUI 配置 Schema - + 返回完整的配置 schema,包含: - 插件基本信息 - 所有 section 及其字段定义 - 布局配置 - + 用于 WebUI 动态生成配置表单。 - + Returns: Dict: 完整的配置 schema """ @@ -596,12 +596,12 @@ class PluginBase(ABC): "sections": {}, "layout": None, } - + # 处理 sections for section_name, fields in self.config_schema.items(): if not isinstance(fields, dict): continue - + section_data = { "name": section_name, "title": section_name, @@ -611,7 +611,7 @@ class PluginBase(ABC): "order": 0, "fields": {}, } - + # 获取 section 元数据 section_meta = self.config_section_descriptions.get(section_name) if section_meta: @@ -625,16 +625,16 @@ class PluginBase(ABC): section_data["order"] = section_meta.order elif isinstance(section_meta, dict): section_data.update(section_meta) - + # 处理字段 for field_name, field_def in fields.items(): if isinstance(field_def, ConfigField): field_data = field_def.to_dict() field_data["name"] = field_name section_data["fields"][field_name] = field_data - + schema["sections"][section_name] = section_data - + # 处理布局 if self.config_layout: schema["layout"] = self.config_layout.to_dict() @@ -644,15 +644,15 @@ class PluginBase(ABC): "type": "auto", "tabs": [], } - + return schema def get_current_config_values(self) -> Dict[str, Any]: """ 获取当前配置值 - + 
返回插件当前的配置值(已从配置文件加载)。 - + Returns: Dict: 当前配置值 """ diff --git a/src/webui/chat_routes.py b/src/webui/chat_routes.py index 0bab8cae..f0403d09 100644 --- a/src/webui/chat_routes.py +++ b/src/webui/chat_routes.py @@ -25,6 +25,7 @@ WEBUI_USER_ID_PREFIX = "webui_user_" class ChatHistoryMessage(BaseModel): """聊天历史消息""" + id: str type: str # 'user' | 'bot' | 'system' content: str @@ -36,17 +37,17 @@ class ChatHistoryMessage(BaseModel): class ChatHistoryManager: """聊天历史管理器 - 使用 SQLite 数据库存储""" - + def __init__(self, max_messages: int = 200): self.max_messages = max_messages - + def _message_to_dict(self, msg: Messages) -> Dict[str, Any]: """将数据库消息转换为前端格式""" # 判断是否是机器人消息 # WebUI 用户的 user_id 以 "webui_" 开头,其他都是机器人消息 user_id = msg.user_id or "" is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX) - + return { "id": msg.message_id, "type": "bot" if is_bot else "user", @@ -56,7 +57,7 @@ class ChatHistoryManager: "sender_id": "bot" if is_bot else user_id, "is_bot": is_bot, } - + def get_history(self, limit: int = 50) -> List[Dict[str, Any]]: """从数据库获取最近的历史记录""" try: @@ -67,25 +68,21 @@ class ChatHistoryManager: .order_by(Messages.time.desc()) .limit(limit) ) - + # 转换为列表并反转(使最旧的消息在前) result = [self._message_to_dict(msg) for msg in messages] result.reverse() - + logger.debug(f"从数据库加载了 {len(result)} 条聊天记录") return result except Exception as e: logger.error(f"从数据库加载聊天记录失败: {e}") return [] - + def clear_history(self) -> int: """清空 WebUI 聊天历史记录""" try: - deleted = ( - Messages.delete() - .where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID) - .execute() - ) + deleted = Messages.delete().where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID).execute() logger.info(f"已清空 {deleted} 条 WebUI 聊天记录") return deleted except Exception as e: @@ -100,31 +97,31 @@ chat_history = ChatHistoryManager() # 存储 WebSocket 连接 class ChatConnectionManager: """聊天连接管理器""" - + def __init__(self): self.active_connections: Dict[str, WebSocket] = {} self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射 - + async def connect(self, websocket: WebSocket, session_id: str, user_id: str): await websocket.accept() self.active_connections[session_id] = websocket self.user_sessions[user_id] = session_id logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}") - + def disconnect(self, session_id: str, user_id: str): if session_id in self.active_connections: del self.active_connections[session_id] if user_id in self.user_sessions and self.user_sessions[user_id] == session_id: del self.user_sessions[user_id] logger.info(f"WebUI 聊天会话已断开: session={session_id}") - + async def send_message(self, session_id: str, message: dict): if session_id in self.active_connections: try: await self.active_connections[session_id].send_json(message) except Exception as e: logger.error(f"发送消息失败: {e}") - + async def broadcast(self, message: dict): """广播消息给所有连接""" for session_id in list(self.active_connections.keys()): @@ -135,16 +132,12 @@ chat_manager = ChatConnectionManager() def create_message_data( - content: str, - user_id: str, - user_name: str, - message_id: Optional[str] = None, - is_at_bot: bool = True + content: str, user_id: str, user_name: str, message_id: Optional[str] = None, is_at_bot: bool = True ) -> Dict[str, Any]: """创建符合麦麦消息格式的消息数据""" if message_id is None: message_id = str(uuid.uuid4()) - + return { "message_info": { "platform": WEBUI_CHAT_PLATFORM, @@ -163,7 +156,7 @@ def create_message_data( }, "additional_config": { "at_bot": is_at_bot, - } + }, }, "message_segment": { "type": 
"seglist", @@ -175,8 +168,8 @@ def create_message_data( { "type": "mention_bot", "data": "1.0", - } - ] + }, + ], }, "raw_message": content, "processed_plain_text": content, @@ -186,10 +179,10 @@ def create_message_data( @router.get("/history") async def get_chat_history( limit: int = Query(default=50, ge=1, le=200), - user_id: Optional[str] = Query(default=None) # 保留参数兼容性,但不用于过滤 + user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤 ): """获取聊天历史记录 - + 所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录 """ history = chat_history.get_history(limit) @@ -217,76 +210,87 @@ async def websocket_chat( user_name: Optional[str] = Query(default="WebUI用户"), ): """WebSocket 聊天端点 - + Args: user_id: 用户唯一标识(由前端生成并持久化) user_name: 用户显示昵称(可修改) """ # 生成会话 ID(每次连接都是新的) session_id = str(uuid.uuid4()) - + # 如果没有提供 user_id,生成一个新的 if not user_id: user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}" elif not user_id.startswith(WEBUI_USER_ID_PREFIX): # 确保 user_id 有正确的前缀 user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}" - + await chat_manager.connect(websocket, session_id, user_id) - + try: # 发送会话信息(包含用户 ID,前端需要保存) - await chat_manager.send_message(session_id, { - "type": "session_info", - "session_id": session_id, - "user_id": user_id, - "user_name": user_name, - "bot_name": global_config.bot.nickname, - }) - + await chat_manager.send_message( + session_id, + { + "type": "session_info", + "session_id": session_id, + "user_id": user_id, + "user_name": user_name, + "bot_name": global_config.bot.nickname, + }, + ) + # 发送历史记录 history = chat_history.get_history(50) if history: - await chat_manager.send_message(session_id, { - "type": "history", - "messages": history, - }) - + await chat_manager.send_message( + session_id, + { + "type": "history", + "messages": history, + }, + ) + # 发送欢迎消息(不保存到历史) - await chat_manager.send_message(session_id, { - "type": "system", - "content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!", - "timestamp": time.time(), - }) - + await chat_manager.send_message( + session_id, + { + "type": "system", + "content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!", + "timestamp": time.time(), + }, + ) + while True: data = await websocket.receive_json() - + if data.get("type") == "message": content = data.get("content", "").strip() if not content: continue - + # 用户可以更新昵称 current_user_name = data.get("user_name", user_name) - + message_id = str(uuid.uuid4()) timestamp = time.time() - + # 广播用户消息给所有连接(包括发送者) # 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库 - await chat_manager.broadcast({ - "type": "user_message", - "content": content, - "message_id": message_id, - "timestamp": timestamp, - "sender": { - "name": current_user_name, - "user_id": user_id, - "is_bot": False, + await chat_manager.broadcast( + { + "type": "user_message", + "content": content, + "message_id": message_id, + "timestamp": timestamp, + "sender": { + "name": current_user_name, + "user_id": user_id, + "is_bot": False, + }, } - }) - + ) + # 创建麦麦消息格式 message_data = create_message_data( content=content, @@ -295,46 +299,59 @@ async def websocket_chat( message_id=message_id, is_at_bot=True, ) - + try: # 显示正在输入状态 - await chat_manager.broadcast({ - "type": "typing", - "is_typing": True, - }) - + await chat_manager.broadcast( + { + "type": "typing", + "is_typing": True, + } + ) + # 调用麦麦的消息处理 await chat_bot.message_process(message_data) - + except Exception as e: logger.error(f"处理消息时出错: {e}") - await chat_manager.send_message(session_id, { - "type": "error", - "content": f"处理消息时出错: {str(e)}", - "timestamp": time.time(), - }) + await 
chat_manager.send_message( + session_id, + { + "type": "error", + "content": f"处理消息时出错: {str(e)}", + "timestamp": time.time(), + }, + ) finally: - await chat_manager.broadcast({ - "type": "typing", - "is_typing": False, - }) - + await chat_manager.broadcast( + { + "type": "typing", + "is_typing": False, + } + ) + elif data.get("type") == "ping": - await chat_manager.send_message(session_id, { - "type": "pong", - "timestamp": time.time(), - }) - + await chat_manager.send_message( + session_id, + { + "type": "pong", + "timestamp": time.time(), + }, + ) + elif data.get("type") == "update_nickname": # 允许用户更新昵称 if new_name := data.get("user_name", "").strip(): current_user_name = new_name - await chat_manager.send_message(session_id, { - "type": "nickname_updated", - "user_name": current_user_name, - "timestamp": time.time(), - }) - + await chat_manager.send_message( + session_id, + { + "type": "nickname_updated", + "user_name": current_user_name, + "timestamp": time.time(), + }, + ) + except WebSocketDisconnect: logger.info(f"WebSocket 断开: session={session_id}, user={user_id}") except Exception as e: @@ -356,7 +373,7 @@ async def get_chat_info(): def get_webui_chat_broadcaster() -> tuple: """获取 WebUI 聊天广播器,供外部模块使用 - + Returns: (chat_manager, WEBUI_CHAT_PLATFORM) 元组 """ diff --git a/src/webui/config_routes.py b/src/webui/config_routes.py index 392f2d98..f26bcb09 100644 --- a/src/webui/config_routes.py +++ b/src/webui/config_routes.py @@ -5,7 +5,7 @@ import os import tomlkit from fastapi import APIRouter, HTTPException, Body -from typing import Any +from typing import Any, Annotated from src.common.logger import get_logger from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT @@ -41,6 +41,12 @@ from src.webui.config_schema import ConfigSchemaGenerator logger = get_logger("webui") +# 模块级别的类型别名(解决 B008 ruff 错误) +ConfigBody = Annotated[dict[str, Any], Body()] +SectionBody = Annotated[Any, Body()] +RawContentBody = Annotated[str, Body(embed=True)] +PathBody = Annotated[dict[str, str], Body()] + router = APIRouter(prefix="/config", tags=["config"]) @@ -90,7 +96,7 @@ async def get_bot_config_schema(): return {"success": True, "schema": schema} except Exception as e: logger.error(f"获取配置架构失败: {e}") - raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e @router.get("/schema/model") @@ -101,7 +107,7 @@ async def get_model_config_schema(): return {"success": True, "schema": schema} except Exception as e: logger.error(f"获取模型配置架构失败: {e}") - raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e # ===== 子配置架构获取接口 ===== @@ -174,7 +180,7 @@ async def get_config_section_schema(section_name: str): return {"success": True, "schema": schema} except Exception as e: logger.error(f"获取配置节架构失败: {e}") - raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e # ===== 配置读取接口 ===== @@ -196,7 +202,7 @@ async def get_bot_config(): raise except Exception as e: logger.error(f"读取配置文件失败: {e}") - raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e @router.get("/model") @@ -215,21 +221,21 @@ async def get_model_config(): raise except Exception as e: logger.error(f"读取配置文件失败: {e}") - raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") + 
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e # ===== 配置更新接口 ===== @router.post("/bot") -async def update_bot_config(config_data: dict[str, Any] = Body(...)): +async def update_bot_config(config_data: ConfigBody): """更新麦麦主程序配置""" try: # 验证配置数据 try: Config.from_dict(config_data) except Exception as e: - raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") + raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e # 保存配置文件 config_path = os.path.join(CONFIG_DIR, "bot_config.toml") @@ -242,18 +248,18 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)): raise except Exception as e: logger.error(f"保存配置文件失败: {e}") - raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e @router.post("/model") -async def update_model_config(config_data: dict[str, Any] = Body(...)): +async def update_model_config(config_data: ConfigBody): """更新模型配置""" try: # 验证配置数据 try: APIAdapterConfig.from_dict(config_data) except Exception as e: - raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") + raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e # 保存配置文件 config_path = os.path.join(CONFIG_DIR, "model_config.toml") @@ -266,14 +272,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)): raise except Exception as e: logger.error(f"保存配置文件失败: {e}") - raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e # ===== 配置节更新接口 ===== @router.post("/bot/section/{section_name}") -async def update_bot_config_section(section_name: str, section_data: Any = Body(...)): +async def update_bot_config_section(section_name: str, section_data: SectionBody): """更新麦麦主程序配置的指定节(保留注释和格式)""" try: # 读取现有配置 @@ -304,7 +310,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body( try: Config.from_dict(config_data) except Exception as e: - raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") + raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e # 保存配置(tomlkit.dump 会保留注释) with open(config_path, "w", encoding="utf-8") as f: @@ -316,7 +322,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body( raise except Exception as e: logger.error(f"更新配置节失败: {e}") - raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e # ===== 原始 TOML 文件操作接口 ===== @@ -338,24 +344,24 @@ async def get_bot_config_raw(): raise except Exception as e: logger.error(f"读取配置文件失败: {e}") - raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e @router.post("/bot/raw") -async def update_bot_config_raw(raw_content: str = Body(..., embed=True)): +async def update_bot_config_raw(raw_content: RawContentBody): """更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)""" try: # 验证 TOML 格式 try: config_data = tomlkit.loads(raw_content) except Exception as e: - raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") + raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e # 验证配置数据结构 try: Config.from_dict(config_data) except Exception as e: - raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") + raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e # 保存配置文件 config_path = os.path.join(CONFIG_DIR, 
"bot_config.toml") @@ -368,11 +374,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)): raise except Exception as e: logger.error(f"保存配置文件失败: {e}") - raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e @router.post("/model/section/{section_name}") -async def update_model_config_section(section_name: str, section_data: Any = Body(...)): +async def update_model_config_section(section_name: str, section_data: SectionBody): """更新模型配置的指定节(保留注释和格式)""" try: # 读取现有配置 @@ -403,7 +409,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod try: APIAdapterConfig.from_dict(config_data) except Exception as e: - raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") + raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e # 保存配置(tomlkit.dump 会保留注释) with open(config_path, "w", encoding="utf-8") as f: @@ -415,7 +421,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod raise except Exception as e: logger.error(f"更新配置节失败: {e}") - raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e # ===== 适配器配置管理接口 ===== @@ -425,11 +431,11 @@ def _normalize_adapter_path(path: str) -> str: """将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)""" if not path: return path - + # 如果已经是绝对路径,直接返回 if os.path.isabs(path): return path - + # 相对路径,转换为相对于项目根目录的绝对路径 return os.path.normpath(os.path.join(PROJECT_ROOT, path)) @@ -438,17 +444,17 @@ def _to_relative_path(path: str) -> str: """尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径""" if not path or not os.path.isabs(path): return path - + try: # 尝试获取相对路径 rel_path = os.path.relpath(path, PROJECT_ROOT) # 如果相对路径不是以 .. 
开头(说明文件在项目目录内),则返回相对路径 - if not rel_path.startswith('..'): + if not rel_path.startswith(".."): return rel_path except (ValueError, TypeError): # 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError pass - + # 无法转换为相对路径,返回绝对路径 return path @@ -463,6 +469,7 @@ async def get_adapter_config_path(): return {"success": True, "path": None} import json + with open(webui_data_path, "r", encoding="utf-8") as f: webui_data = json.load(f) @@ -472,10 +479,11 @@ async def get_adapter_config_path(): # 将路径规范化为绝对路径 abs_path = _normalize_adapter_path(adapter_config_path) - + # 检查文件是否存在并返回最后修改时间 if os.path.exists(abs_path): import datetime + mtime = os.path.getmtime(abs_path) last_modified = datetime.datetime.fromtimestamp(mtime).isoformat() # 返回相对路径(如果可能) @@ -487,11 +495,11 @@ async def get_adapter_config_path(): except Exception as e: logger.error(f"获取适配器配置路径失败: {e}") - raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e @router.post("/adapter-config/path") -async def save_adapter_config_path(data: dict[str, str] = Body(...)): +async def save_adapter_config_path(data: PathBody): """保存适配器配置文件路径偏好""" try: path = data.get("path") @@ -511,10 +519,10 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)): # 将路径规范化为绝对路径 abs_path = _normalize_adapter_path(path) - + # 尝试转换为相对路径保存(如果文件在项目目录内) save_path = _to_relative_path(abs_path) - + # 更新路径 webui_data["adapter_config_path"] = save_path @@ -530,7 +538,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)): raise except Exception as e: logger.error(f"保存适配器配置路径失败: {e}") - raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e @router.get("/adapter-config") @@ -542,7 +550,7 @@ async def get_adapter_config(path: str): # 将路径规范化为绝对路径 abs_path = _normalize_adapter_path(path) - + # 检查文件是否存在 if not os.path.exists(abs_path): raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}") @@ -562,11 +570,11 @@ async def get_adapter_config(path: str): raise except Exception as e: logger.error(f"读取适配器配置失败: {e}") - raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e @router.post("/adapter-config") -async def save_adapter_config(data: dict[str, str] = Body(...)): +async def save_adapter_config(data: PathBody): """保存适配器配置到指定路径""" try: path = data.get("path") @@ -579,17 +587,16 @@ async def save_adapter_config(data: dict[str, str] = Body(...)): # 将路径规范化为绝对路径 abs_path = _normalize_adapter_path(path) - + # 检查文件扩展名 if not abs_path.endswith(".toml"): raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件") # 验证 TOML 格式 try: - import tomlkit tomlkit.loads(content) except Exception as e: - raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") + raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e # 确保目录存在 dir_path = os.path.dirname(abs_path) @@ -607,5 +614,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)): raise except Exception as e: logger.error(f"保存适配器配置失败: {e}") - raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") - + raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e diff --git a/src/webui/config_schema.py b/src/webui/config_schema.py index c1608bc4..d160000a 100644 --- a/src/webui/config_schema.py +++ b/src/webui/config_schema.py @@ -117,7 +117,7 @@ class ConfigSchemaGenerator: if 
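Note on `_normalize_adapter_path` / `_to_relative_path` above: they form a pair, resolving incoming paths against the project root but storing them back in project-relative form when possible. A standalone sketch of the same logic, with `PROJECT_ROOT` stubbed to the current directory for the example:

```python
import os

PROJECT_ROOT = os.path.abspath(".")  # stand-in; the real value comes from the module


def normalize_adapter_path(path: str) -> str:
    """Resolve relative paths against the project root; pass absolute paths through."""
    if not path:
        return path
    if os.path.isabs(path):
        return path
    return os.path.normpath(os.path.join(PROJECT_ROOT, path))


def to_relative_path(path: str) -> str:
    """Prefer storing a project-relative path; otherwise keep the absolute one."""
    if not path or not os.path.isabs(path):
        return path
    try:
        rel_path = os.path.relpath(path, PROJECT_ROOT)
        if not rel_path.startswith(".."):
            return rel_path
    except (ValueError, TypeError):
        # On Windows, relpath raises ValueError for paths on different drives,
        # so the absolute path is kept unchanged.
        pass
    return path


abs_path = normalize_adapter_path("config/adapter.toml")
print(to_relative_path(abs_path))  # back to the project-relative form for storage
```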
next_line.startswith('"""') or next_line.startswith("'''"): # 单行文档字符串 if next_line.count('"""') == 2 or next_line.count("'''") == 2: - description_lines.append(next_line.strip('"""').strip("'''").strip()) + description_lines.append(next_line.replace('"""', "").replace("'''", "").strip()) else: # 多行文档字符串 quote = '"""' if next_line.startswith('"""') else "'''" @@ -135,7 +135,7 @@ class ConfigSchemaGenerator: next_line = lines[i + 1].strip() if next_line.startswith('"""') or next_line.startswith("'''"): if next_line.count('"""') == 2 or next_line.count("'''") == 2: - description_lines.append(next_line.strip('"""').strip("'''").strip()) + description_lines.append(next_line.replace('"""', "").replace("'''", "").strip()) else: quote = '"""' if next_line.startswith('"""') else "'''" description_lines.append(next_line.strip(quote).strip()) @@ -199,13 +199,13 @@ class ConfigSchemaGenerator: return FieldType.ARRAY, None, items # 处理基本类型 - if field_type is bool or field_type == bool: + if field_type is bool: return FieldType.BOOLEAN, None, None - elif field_type is int or field_type == int: + elif field_type is int: return FieldType.INTEGER, None, None - elif field_type is float or field_type == float: + elif field_type is float: return FieldType.NUMBER, None, None - elif field_type is str or field_type == str: + elif field_type is str: return FieldType.STRING, None, None elif field_type is dict or origin is dict: return FieldType.OBJECT, None, None diff --git a/src/webui/emoji_routes.py b/src/webui/emoji_routes.py index e2aa6875..94f77b95 100644 --- a/src/webui/emoji_routes.py +++ b/src/webui/emoji_routes.py @@ -3,20 +3,25 @@ from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form from fastapi.responses import FileResponse from pydantic import BaseModel -from typing import Optional, List +from typing import Optional, List, Annotated from src.common.logger import get_logger from src.common.database.database_model import Emoji from .token_manager import get_token_manager -import json import time import os import hashlib -import base64 from PIL import Image import io logger = get_logger("webui.emoji") +# 模块级别的类型别名(解决 B008 ruff 错误) +EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")] +EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")] +DescriptionForm = Annotated[str, Form(description="表情包描述")] +EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")] +IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")] + # 创建路由器 router = APIRouter(prefix="/emoji", tags=["Emoji"]) @@ -592,10 +597,10 @@ class EmojiUploadResponse(BaseModel): @router.post("/upload", response_model=EmojiUploadResponse) async def upload_emoji( - file: UploadFile = File(..., description="表情包图片文件"), - description: str = Form("", description="表情包描述"), - emotion: str = Form("", description="情感标签,多个用逗号分隔"), - is_registered: bool = Form(True, description="是否直接注册"), + file: EmojiFile, + description: DescriptionForm = "", + emotion: EmotionForm = "", + is_registered: IsRegisteredForm = True, authorization: Optional[str] = Header(None), ): """ @@ -713,9 +718,9 @@ async def upload_emoji( @router.post("/batch/upload") async def batch_upload_emoji( - files: List[UploadFile] = File(..., description="多个表情包图片文件"), - emotion: str = Form("", description="情感标签,多个用逗号分隔"), - is_registered: bool = Form(True, description="是否直接注册"), + files: EmojiFiles, + emotion: EmotionForm = "", + is_registered: IsRegisteredForm = True, authorization: Optional[str] = Header(None), ): """ @@ -749,11 
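Note on the config_schema.py change above: moving from chained `strip()` calls to `replace()` is a behavior fix, not just formatting. `str.strip` treats its argument as a set of characters, so quote characters that belong to the docstring text itself get removed as well. A small demonstration with an invented docstring line:

```python
line = '"""\'重要\'的配置说明"""'

old_way = line.strip('"""').strip("'''").strip()
new_way = line.replace('"""', "").replace("'''", "").strip()

print(old_way)  # 重要'的配置说明   <- the text's own leading quote was eaten
print(new_way)  # '重要'的配置说明  <- only the triple-quote fences are removed
```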
+754,13 @@ async def batch_upload_emoji( # 验证文件类型 if file.content_type not in allowed_types: results["failed"] += 1 - results["details"].append({ - "filename": file.filename, - "success": False, - "error": f"不支持的文件类型: {file.content_type}", - }) + results["details"].append( + { + "filename": file.filename, + "success": False, + "error": f"不支持的文件类型: {file.content_type}", + } + ) continue # 读取文件内容 @@ -761,11 +768,13 @@ async def batch_upload_emoji( if not file_content: results["failed"] += 1 - results["details"].append({ - "filename": file.filename, - "success": False, - "error": "文件内容为空", - }) + results["details"].append( + { + "filename": file.filename, + "success": False, + "error": "文件内容为空", + } + ) continue # 验证图片 @@ -774,11 +783,13 @@ async def batch_upload_emoji( img_format = img.format.lower() if img.format else "png" except Exception as e: results["failed"] += 1 - results["details"].append({ - "filename": file.filename, - "success": False, - "error": f"无效的图片: {str(e)}", - }) + results["details"].append( + { + "filename": file.filename, + "success": False, + "error": f"无效的图片: {str(e)}", + } + ) continue # 计算哈希 @@ -787,11 +798,13 @@ async def batch_upload_emoji( # 检查重复 if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash): results["failed"] += 1 - results["details"].append({ - "filename": file.filename, - "success": False, - "error": "已存在相同的表情包", - }) + results["details"].append( + { + "filename": file.filename, + "success": False, + "error": "已存在相同的表情包", + } + ) continue # 生成文件名并保存 @@ -829,19 +842,23 @@ async def batch_upload_emoji( ) results["uploaded"] += 1 - results["details"].append({ - "filename": file.filename, - "success": True, - "id": emoji.id, - }) + results["details"].append( + { + "filename": file.filename, + "success": True, + "id": emoji.id, + } + ) except Exception as e: results["failed"] += 1 - results["details"].append({ - "filename": file.filename, - "success": False, - "error": str(e), - }) + results["details"].append( + { + "filename": file.filename, + "success": False, + "error": str(e), + } + ) results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个" return results @@ -850,4 +867,4 @@ async def batch_upload_emoji( raise except Exception as e: logger.exception(f"批量上传表情包失败: {e}") - raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e \ No newline at end of file + raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e diff --git a/src/webui/git_mirror_service.py b/src/webui/git_mirror_service.py index df00cde9..a6a9b1bc 100644 --- a/src/webui/git_mirror_service.py +++ b/src/webui/git_mirror_service.py @@ -602,9 +602,9 @@ class GitMirrorService: # 执行 git clone(在线程池中运行以避免阻塞) loop = asyncio.get_event_loop() - def run_git_clone(): + def run_git_clone(clone_cmd=cmd): return subprocess.run( - cmd, + clone_cmd, capture_output=True, text=True, timeout=300, # 5分钟超时 diff --git a/src/webui/knowledge_routes.py b/src/webui/knowledge_routes.py index 717e20ca..af4594b6 100644 --- a/src/webui/knowledge_routes.py +++ b/src/webui/knowledge_routes.py @@ -1,4 +1,5 @@ """知识库图谱可视化 API 路由""" + from typing import List, Optional from fastapi import APIRouter, Query from pydantic import BaseModel @@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"]) class KnowledgeNode(BaseModel): """知识节点""" + id: str type: str # 'entity' or 'paragraph' content: str @@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel): class KnowledgeEdge(BaseModel): """知识边""" + source: str target: str weight: float @@ -28,12 +31,14 @@ 
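Note on the `def run_git_clone(clone_cmd=cmd)` change in git_mirror_service.py above: binding the command as a default argument freezes its value at definition time instead of call time, which is the usual fix for Ruff/bugbear B023-style warnings about closures capturing an outer variable. The general behavior it guards against:

```python
commands = [["git", "clone", "repo-a"], ["git", "clone", "repo-b"]]

late_bound = [lambda: cmd for cmd in commands]      # all closures share one variable cell
frozen = [lambda cmd=cmd: cmd for cmd in commands]  # default argument snapshots the value

print([f() for f in late_bound])
# [['git', 'clone', 'repo-b'], ['git', 'clone', 'repo-b']]  <- both see the last value
print([f() for f in frozen])
# [['git', 'clone', 'repo-a'], ['git', 'clone', 'repo-b']]  <- each keeps its own
```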
class KnowledgeEdge(BaseModel): class KnowledgeGraph(BaseModel): """知识图谱""" + nodes: List[KnowledgeNode] edges: List[KnowledgeEdge] class KnowledgeStats(BaseModel): """知识库统计信息""" + total_nodes: int total_edges: int entity_nodes: int @@ -45,7 +50,7 @@ def _load_kg_manager(): """延迟加载 KGManager""" try: from src.chat.knowledge.kg_manager import KGManager - + kg_manager = KGManager() kg_manager.load_from_file() return kg_manager @@ -58,31 +63,26 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph: """将 DiGraph 转换为 JSON 格式""" if kg_manager is None or kg_manager.graph is None: return KnowledgeGraph(nodes=[], edges=[]) - + graph = kg_manager.graph nodes = [] edges = [] - + # 转换节点 node_list = graph.get_node_list() for node_id in node_list: try: node_data = graph[node_id] # 节点类型: "ent" -> "entity", "pg" -> "paragraph" - node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" - content = node_data['content'] if 'content' in node_data else node_id - create_time = node_data['create_time'] if 'create_time' in node_data else None - - nodes.append(KnowledgeNode( - id=node_id, - type=node_type, - content=content, - create_time=create_time - )) + node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" + content = node_data["content"] if "content" in node_data else node_id + create_time = node_data["create_time"] if "create_time" in node_data else None + + nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time)) except Exception as e: logger.warning(f"跳过节点 {node_id}: {e}") continue - + # 转换边 edge_list = graph.get_edge_list() for edge_tuple in edge_list: @@ -91,37 +91,35 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph: source, target = edge_tuple[0], edge_tuple[1] # 通过 graph[source, target] 获取边的属性数据 edge_data = graph[source, target] - + # edge_data 支持 [] 操作符但不支持 .get() - weight = edge_data['weight'] if 'weight' in edge_data else 1.0 - create_time = edge_data['create_time'] if 'create_time' in edge_data else None - update_time = edge_data['update_time'] if 'update_time' in edge_data else None - - edges.append(KnowledgeEdge( - source=source, - target=target, - weight=weight, - create_time=create_time, - update_time=update_time - )) + weight = edge_data["weight"] if "weight" in edge_data else 1.0 + create_time = edge_data["create_time"] if "create_time" in edge_data else None + update_time = edge_data["update_time"] if "update_time" in edge_data else None + + edges.append( + KnowledgeEdge( + source=source, target=target, weight=weight, create_time=create_time, update_time=update_time + ) + ) except Exception as e: logger.warning(f"跳过边 {edge_tuple}: {e}") continue - + return KnowledgeGraph(nodes=nodes, edges=edges) @router.get("/graph", response_model=KnowledgeGraph) async def get_knowledge_graph( limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"), - node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph") + node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"), ): """获取知识图谱(限制节点数量) - + Args: limit: 返回的最大节点数,默认 100,最大 10000 node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落) - + Returns: KnowledgeGraph: 包含指定数量节点和相关边的知识图谱 """ @@ -130,46 +128,43 @@ async def get_knowledge_graph( if kg_manager is None: logger.warning("KGManager 未初始化,返回空图谱") return KnowledgeGraph(nodes=[], edges=[]) - + graph = kg_manager.graph all_node_list = graph.get_node_list() - + # 按类型过滤节点 if node_type == "entity": - all_node_list = [n for n in 
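Note on `_convert_graph_to_json` above: the repeated `data["k"] if "k" in data else default` pattern exists because, as the in-code comment says, the graph's node and edge views support `in` and `[]` but not `dict.get()`. A tiny stand-in object (the real graph class is internal and is only assumed to behave this way) shows the constraint and a helper that could wrap it:

```python
class EdgeView:
    """Mapping-like view: supports `in` and [], deliberately no .get()."""

    def __init__(self, data: dict):
        self._data = data

    def __contains__(self, key) -> bool:
        return key in self._data

    def __getitem__(self, key):
        return self._data[key]


def pick(view, key, default=None):
    """Safe lookup for objects that only offer `in` + [] access."""
    return view[key] if key in view else default


edge = EdgeView({"weight": 2.0, "create_time": 1700000000.0})
print(pick(edge, "weight", 1.0))   # 2.0
print(pick(edge, "update_time"))   # None, the same fallback the route code uses
```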
all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent'] + all_node_list = [ + n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent" + ] elif node_type == "paragraph": - all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg'] - + all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"] + # 限制节点数量 total_nodes = len(all_node_list) if len(all_node_list) > limit: node_list = all_node_list[:limit] else: node_list = all_node_list - + logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})") - + # 转换节点 nodes = [] node_ids = set() for node_id in node_list: try: node_data = graph[node_id] - node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" - content = node_data['content'] if 'content' in node_data else node_id - create_time = node_data['create_time'] if 'create_time' in node_data else None - - nodes.append(KnowledgeNode( - id=node_id, - type=node_type_val, - content=content, - create_time=create_time - )) + node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" + content = node_data["content"] if "content" in node_data else node_id + create_time = node_data["create_time"] if "create_time" in node_data else None + + nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time)) node_ids.add(node_id) except Exception as e: logger.warning(f"跳过节点 {node_id}: {e}") continue - + # 只获取涉及当前节点集的边(保证图的完整性) edges = [] edge_list = graph.get_edge_list() @@ -179,27 +174,25 @@ async def get_knowledge_graph( # 只包含两端都在当前节点集中的边 if source not in node_ids or target not in node_ids: continue - + edge_data = graph[source, target] - weight = edge_data['weight'] if 'weight' in edge_data else 1.0 - create_time = edge_data['create_time'] if 'create_time' in edge_data else None - update_time = edge_data['update_time'] if 'update_time' in edge_data else None - - edges.append(KnowledgeEdge( - source=source, - target=target, - weight=weight, - create_time=create_time, - update_time=update_time - )) + weight = edge_data["weight"] if "weight" in edge_data else 1.0 + create_time = edge_data["create_time"] if "create_time" in edge_data else None + update_time = edge_data["update_time"] if "update_time" in edge_data else None + + edges.append( + KnowledgeEdge( + source=source, target=target, weight=weight, create_time=create_time, update_time=update_time + ) + ) except Exception as e: logger.warning(f"跳过边 {edge_tuple}: {e}") continue - + graph_data = KnowledgeGraph(nodes=nodes, edges=edges) logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边") return graph_data - + except Exception as e: logger.error(f"获取知识图谱失败: {e}", exc_info=True) return KnowledgeGraph(nodes=[], edges=[]) @@ -208,71 +201,59 @@ async def get_knowledge_graph( @router.get("/stats", response_model=KnowledgeStats) async def get_knowledge_stats(): """获取知识库统计信息 - + Returns: KnowledgeStats: 统计信息 """ try: kg_manager = _load_kg_manager() if kg_manager is None or kg_manager.graph is None: - return KnowledgeStats( - total_nodes=0, - total_edges=0, - entity_nodes=0, - paragraph_nodes=0, - avg_connections=0.0 - ) - + return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0) + graph = kg_manager.graph node_list = graph.get_node_list() edge_list = graph.get_edge_list() - + total_nodes = 
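Note on the `/graph` endpoint above: after trimming the node list to `limit` entries, only edges whose two endpoints both sit inside the selected set are returned, so the client never receives a dangling edge. The rule in isolation, with made-up node IDs:

```python
nodes = ["ent:麦麦", "ent:WebUI", "pg:0001", "pg:0002"]
edges = [
    ("ent:麦麦", "pg:0001", 2.0),
    ("ent:WebUI", "pg:0002", 1.0),  # pg:0002 falls outside the limited set below
    ("pg:0002", "pg:0003", 1.0),
]

limit = 3
node_ids = set(nodes[:limit])

kept_edges = [(s, t, w) for s, t, w in edges if s in node_ids and t in node_ids]
print(kept_edges)  # [('ent:麦麦', 'pg:0001', 2.0)]
```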
len(node_list) total_edges = len(edge_list) - + # 统计节点类型 entity_nodes = 0 paragraph_nodes = 0 for node_id in node_list: try: node_data = graph[node_id] - node_type = node_data['type'] if 'type' in node_data else 'ent' - if node_type == 'ent': + node_type = node_data["type"] if "type" in node_data else "ent" + if node_type == "ent": entity_nodes += 1 - elif node_type == 'pg': + elif node_type == "pg": paragraph_nodes += 1 except Exception: continue - + # 计算平均连接数 avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0 - + return KnowledgeStats( total_nodes=total_nodes, total_edges=total_edges, entity_nodes=entity_nodes, paragraph_nodes=paragraph_nodes, - avg_connections=round(avg_connections, 2) + avg_connections=round(avg_connections, 2), ) - + except Exception as e: logger.error(f"获取统计信息失败: {e}", exc_info=True) - return KnowledgeStats( - total_nodes=0, - total_edges=0, - entity_nodes=0, - paragraph_nodes=0, - avg_connections=0.0 - ) + return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0) @router.get("/search", response_model=List[KnowledgeNode]) async def search_knowledge_node(query: str = Query(..., min_length=1)): """搜索知识节点 - + Args: query: 搜索关键词 - + Returns: List[KnowledgeNode]: 匹配的节点列表 """ @@ -280,33 +261,28 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)): kg_manager = _load_kg_manager() if kg_manager is None or kg_manager.graph is None: return [] - + graph = kg_manager.graph node_list = graph.get_node_list() results = [] query_lower = query.lower() - + # 在节点内容中搜索 for node_id in node_list: try: node_data = graph[node_id] - content = node_data['content'] if 'content' in node_data else node_id - node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" - + content = node_data["content"] if "content" in node_data else node_id + node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" + if query_lower in content.lower() or query_lower in node_id.lower(): - create_time = node_data['create_time'] if 'create_time' in node_data else None - results.append(KnowledgeNode( - id=node_id, - type=node_type, - content=content, - create_time=create_time - )) + create_time = node_data["create_time"] if "create_time" in node_data else None + results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time)) except Exception: continue - + logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点") return results[:50] # 限制返回数量 - + except Exception as e: logger.error(f"搜索节点失败: {e}", exc_info=True) return [] diff --git a/src/webui/model_routes.py b/src/webui/model_routes.py index 871512f5..7d8310ee 100644 --- a/src/webui/model_routes.py +++ b/src/webui/model_routes.py @@ -43,25 +43,27 @@ def _normalize_url(url: str) -> str: def _parse_openai_response(data: dict) -> list[dict]: """ 解析 OpenAI 格式的模型列表响应 - + 格式: { "data": [{ "id": "gpt-4", "object": "model", ... 
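Note on the stats endpoint above: `avg_connections` is the average degree when every edge is counted at both of its endpoints, with a guard for the empty graph:

```python
def average_connections(total_nodes: int, total_edges: int) -> float:
    return round((total_edges * 2) / total_nodes, 2) if total_nodes > 0 else 0.0


print(average_connections(4, 3))  # 1.5: three edges shared among four nodes
print(average_connections(0, 0))  # 0.0: empty graph, no division by zero
```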
}] } """ models = [] if "data" in data and isinstance(data["data"], list): for model in data["data"]: if isinstance(model, dict) and "id" in model: - models.append({ - "id": model["id"], - "name": model.get("name") or model["id"], - "owned_by": model.get("owned_by", ""), - }) + models.append( + { + "id": model["id"], + "name": model.get("name") or model["id"], + "owned_by": model.get("owned_by", ""), + } + ) return models def _parse_gemini_response(data: dict) -> list[dict]: """ 解析 Gemini 格式的模型列表响应 - + 格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] } """ models = [] @@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]: model_id = model["name"] if model_id.startswith("models/"): model_id = model_id[7:] # 去掉 "models/" 前缀 - models.append({ - "id": model_id, - "name": model.get("displayName") or model_id, - "owned_by": "google", - }) + models.append( + { + "id": model_id, + "name": model.get("displayName") or model_id, + "owned_by": "google", + } + ) return models @@ -89,55 +93,54 @@ async def _fetch_models_from_provider( ) -> list[dict]: """ 从提供商 API 获取模型列表 - + Args: base_url: 提供商的基础 URL api_key: API 密钥 endpoint: 获取模型列表的端点 parser: 响应解析器类型 ('openai' | 'gemini') client_type: 客户端类型 ('openai' | 'gemini') - + Returns: 模型列表 """ url = f"{_normalize_url(base_url)}{endpoint}" - + # 根据客户端类型设置请求头 headers = {} params = {} - + if client_type == "gemini": # Gemini 使用 URL 参数传递 API Key params["key"] = api_key else: # OpenAI 兼容格式使用 Authorization 头 headers["Authorization"] = f"Bearer {api_key}" - + try: async with httpx.AsyncClient(timeout=30.0) as client: response = await client.get(url, headers=headers, params=params) response.raise_for_status() data = response.json() - except httpx.TimeoutException: - raise HTTPException(status_code=504, detail="请求超时,请稍后重试") + except httpx.TimeoutException as e: + raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e except httpx.HTTPStatusError as e: # 注意:使用 502 Bad Gateway 而不是原始的 401/403, # 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理 if e.response.status_code == 401: - raise HTTPException(status_code=502, detail="API Key 无效或已过期") + raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e elif e.response.status_code == 403: - raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") + raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e elif e.response.status_code == 404: - raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") + raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e else: raise HTTPException( - status_code=502, - detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}" - ) + status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}" + ) from e except Exception as e: logger.error(f"获取模型列表失败: {e}") - raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") - + raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e + # 根据解析器类型解析响应 if parser == "openai": return _parse_openai_response(data) @@ -150,26 +153,26 @@ async def _fetch_models_from_provider( def _get_provider_config(provider_name: str) -> Optional[dict]: """ 从 model_config.toml 获取指定提供商的配置 - + Args: provider_name: 提供商名称 - + Returns: 提供商配置,如果未找到则返回 None """ config_path = os.path.join(CONFIG_DIR, "model_config.toml") if not os.path.exists(config_path): return None - + try: with open(config_path, "r", encoding="utf-8") as f: config_data = tomlkit.load(f) - + providers = 
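Note on the two parsers above: they normalize very different provider payloads into the same `{id, name, owned_by}` shape, and the Gemini one additionally drops the `models/` prefix. A condensed, self-contained version fed with minimal example payloads (payload contents invented for illustration; field names follow the docstrings above each parser):

```python
def parse_openai_response(data: dict) -> list[dict]:
    models = []
    for model in data.get("data", []):
        if isinstance(model, dict) and "id" in model:
            models.append(
                {"id": model["id"], "name": model.get("name") or model["id"], "owned_by": model.get("owned_by", "")}
            )
    return models


def parse_gemini_response(data: dict) -> list[dict]:
    models = []
    for model in data.get("models", []):
        if isinstance(model, dict) and "name" in model:
            model_id = model["name"]
            if model_id.startswith("models/"):
                model_id = model_id[len("models/"):]  # "models/gemini-pro" -> "gemini-pro"
            models.append({"id": model_id, "name": model.get("displayName") or model_id, "owned_by": "google"})
    return models


print(parse_openai_response({"data": [{"id": "gpt-4", "owned_by": "openai"}]}))
# [{'id': 'gpt-4', 'name': 'gpt-4', 'owned_by': 'openai'}]
print(parse_gemini_response({"models": [{"name": "models/gemini-pro", "displayName": "Gemini Pro"}]}))
# [{'id': 'gemini-pro', 'name': 'Gemini Pro', 'owned_by': 'google'}]
```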
config_data.get("api_providers", []) for provider in providers: if provider.get("name") == provider_name: return dict(provider) - + return None except Exception as e: logger.error(f"读取提供商配置失败: {e}") @@ -184,23 +187,23 @@ async def get_provider_models( ): """ 获取指定提供商的可用模型列表 - + 通过提供商名称查找配置,然后请求对应的模型列表端点 """ # 获取提供商配置 provider_config = _get_provider_config(provider_name) if not provider_config: raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}") - + base_url = provider_config.get("base_url") api_key = provider_config.get("api_key") client_type = provider_config.get("client_type", "openai") - + if not base_url: raise HTTPException(status_code=400, detail="提供商配置缺少 base_url") if not api_key: raise HTTPException(status_code=400, detail="提供商配置缺少 api_key") - + # 获取模型列表 models = await _fetch_models_from_provider( base_url=base_url, @@ -209,7 +212,7 @@ async def get_provider_models( parser=parser, client_type=client_type, ) - + return { "success": True, "models": models, @@ -236,7 +239,7 @@ async def get_models_by_url( parser=parser, client_type=client_type, ) - + return { "success": True, "models": models, @@ -251,11 +254,11 @@ async def test_provider_connection( ): """ 测试提供商连接状态 - + 分两步测试: 1. 网络连通性测试:向 base_url 发送请求,检查是否能连接 2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效 - + 返回: - network_ok: 网络是否连通 - api_key_valid: API Key 是否有效(仅在提供 api_key 时返回) @@ -263,11 +266,11 @@ async def test_provider_connection( - error: 错误信息(如果有) """ import time - + base_url = _normalize_url(base_url) if not base_url: raise HTTPException(status_code=400, detail="base_url 不能为空") - + result = { "network_ok": False, "api_key_valid": None, @@ -275,7 +278,7 @@ async def test_provider_connection( "error": None, "http_status": None, } - + # 第一步:测试网络连通性 try: start_time = time.time() @@ -283,11 +286,11 @@ async def test_provider_connection( # 尝试 GET 请求 base_url(不需要 API Key) response = await client.get(base_url) latency = (time.time() - start_time) * 1000 - + result["network_ok"] = True result["latency_ms"] = round(latency, 2) result["http_status"] = response.status_code - + except httpx.ConnectError as e: result["error"] = f"连接失败:无法连接到服务器 ({str(e)})" return result @@ -300,7 +303,7 @@ async def test_provider_connection( except Exception as e: result["error"] = f"未知错误:{str(e)}" return result - + # 第二步:如果提供了 API Key,验证其有效性 if api_key: try: @@ -313,7 +316,7 @@ async def test_provider_connection( # 尝试获取模型列表 models_url = f"{base_url}/models" response = await client.get(models_url, headers=headers) - + if response.status_code == 200: result["api_key_valid"] = True elif response.status_code in (401, 403): @@ -322,12 +325,12 @@ async def test_provider_connection( else: # 其他状态码,可能是端点不支持,但 Key 可能是有效的 result["api_key_valid"] = None - + except Exception as e: # API Key 验证失败不影响网络连通性结果 logger.warning(f"API Key 验证失败: {e}") result["api_key_valid"] = None - + return result @@ -342,10 +345,10 @@ async def test_provider_connection_by_name( model_config_path = os.path.join(CONFIG_DIR, "model_config.toml") if not os.path.exists(model_config_path): raise HTTPException(status_code=404, detail="配置文件不存在") - + with open(model_config_path, "r", encoding="utf-8") as f: config = tomlkit.load(f) - + # 查找提供商 providers = config.get("api_providers", []) provider = None @@ -353,15 +356,15 @@ async def test_provider_connection_by_name( if p.get("name") == provider_name: provider = p break - + if not provider: raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}") - + base_url = provider.get("base_url", "") api_key = 
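Note on `test_provider_connection` above: it runs in two phases, a plain GET against `base_url` to measure reachability and latency, then an optional model-list call to validate the key. A stripped-down sketch of just the first phase (the commented-out URL is a placeholder, not a real endpoint):

```python
import asyncio
import time

import httpx


async def probe(base_url: str) -> dict:
    result = {"network_ok": False, "latency_ms": None, "http_status": None, "error": None}
    try:
        start = time.time()
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(base_url)
        result.update(
            network_ok=True,
            latency_ms=round((time.time() - start) * 1000, 2),
            http_status=response.status_code,
        )
    except httpx.ConnectError as e:
        result["error"] = f"连接失败: {e}"
    except httpx.TimeoutException:
        result["error"] = "连接超时"
    return result


# asyncio.run(probe("https://api.example.com/v1"))  # placeholder URL
```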
provider.get("api_key", "") - + if not base_url: raise HTTPException(status_code=400, detail="提供商配置缺少 base_url") - + # 调用测试接口 return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None) diff --git a/src/webui/plugin_routes.py b/src/webui/plugin_routes.py index 7480d65f..1236b02e 100644 --- a/src/webui/plugin_routes.py +++ b/src/webui/plugin_routes.py @@ -31,8 +31,9 @@ def parse_version(version_str: str) -> tuple[int, int, int]: """ # 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符) import re + # 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀 - base_version = re.split(r'[-.](?:snapshot|dev|alpha|beta|rc)', version_str, flags=re.IGNORECASE)[0] + base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0] parts = base_version.split(".") if len(parts) < 3: @@ -613,7 +614,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[ for field in required_fields: if field not in manifest: raise ValueError(f"缺少必需字段: {field}") - + # 将插件 ID 写入 manifest(用于后续准确识别) # 这样即使文件夹名称改变,也能通过 manifest 准确识别插件 manifest["id"] = request.plugin_id @@ -705,7 +706,7 @@ async def uninstall_plugin( plugin_path = plugins_dir / folder_name # 旧格式:点 old_format_path = plugins_dir / request.plugin_id - + # 优先使用新格式,如果不存在则尝试旧格式 if not plugin_path.exists(): if old_format_path.exists(): @@ -839,7 +840,7 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st plugin_path = plugins_dir / folder_name # 旧格式:点 old_format_path = plugins_dir / request.plugin_id - + # 优先使用新格式,如果不存在则尝试旧格式 if not plugin_path.exists(): if old_format_path.exists(): @@ -1092,21 +1093,21 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> # 尝试从 author.name 和 repository_url 构建标准 ID author_name = None repo_name = None - + # 获取作者名 if "author" in manifest: if isinstance(manifest["author"], dict) and "name" in manifest["author"]: author_name = manifest["author"]["name"] elif isinstance(manifest["author"], str): author_name = manifest["author"] - + # 从 repository_url 获取仓库名 if "repository_url" in manifest: repo_url = manifest["repository_url"].rstrip("/") if repo_url.endswith(".git"): repo_url = repo_url[:-4] repo_name = repo_url.split("/")[-1] - + # 构建 ID if author_name and repo_name: # 标准格式: Author.RepoName @@ -1122,7 +1123,7 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> else: # 直接使用文件夹名 plugin_id = folder_name - + # 将推断的 ID 写入 manifest(方便下次识别) logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}") manifest["id"] = plugin_id @@ -1167,12 +1168,10 @@ class UpdatePluginConfigRequest(BaseModel): @router.get("/config/{plugin_id}/schema") -async def get_plugin_config_schema( - plugin_id: str, authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 获取插件配置 Schema - + 返回插件的完整配置 schema,包含所有 section、字段定义和布局信息。 用于前端动态生成配置表单。 """ @@ -1187,10 +1186,10 @@ async def get_plugin_config_schema( try: # 尝试从已加载的插件中获取 from src.plugin_system.core.plugin_manager import plugin_manager - + # 查找插件实例 plugin_instance = None - + # 遍历所有已加载的插件 for loaded_plugin_name in plugin_manager.list_loaded_plugins(): instance = plugin_manager.get_plugin_instance(loaded_plugin_name) @@ -1204,17 +1203,17 @@ async def get_plugin_config_schema( if manifest_id == plugin_id: plugin_instance = instance break - - if plugin_instance and hasattr(plugin_instance, 
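Note on `parse_version` above: the raw-string regex cuts the version at the first `-` or `.` that introduces a snapshot/dev/alpha/beta/rc suffix, case-insensitively, leaving the numeric core. Quick check with sample version strings:

```python
import re


def base_version(version_str: str) -> str:
    return re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]


for v in ["1.2.3", "1.2.3-snapshot.4", "1.2.3.dev0", "2.0.0-RC1"]:
    print(v, "->", base_version(v))
# 1.2.3 -> 1.2.3
# 1.2.3-snapshot.4 -> 1.2.3
# 1.2.3.dev0 -> 1.2.3
# 2.0.0-RC1 -> 2.0.0
```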
'get_webui_config_schema'): + + if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"): # 从插件实例获取 schema schema = plugin_instance.get_webui_config_schema() return {"success": True, "schema": schema} - + # 如果插件未加载,尝试从文件系统读取 # 查找插件目录 plugins_dir = Path("plugins") plugin_path = None - + for p in plugins_dir.iterdir(): if p.is_dir(): manifest_path = p / "_manifest.json" @@ -1227,18 +1226,19 @@ async def get_plugin_config_schema( break except Exception: continue - + if not plugin_path: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - + # 读取配置文件获取当前配置 config_path = plugin_path / "config.toml" current_config = {} if config_path.exists(): import tomlkit + with open(config_path, "r", encoding="utf-8") as f: current_config = tomlkit.load(f) - + # 构建基础 schema(无法获取完整的 ConfigField 信息) schema = { "plugin_id": plugin_id, @@ -1252,7 +1252,7 @@ async def get_plugin_config_schema( "layout": {"type": "auto", "tabs": []}, "_note": "插件未加载,仅返回当前配置结构", } - + # 从当前配置推断 schema for section_name, section_data in current_config.items(): if isinstance(section_data, dict): @@ -1277,7 +1277,7 @@ async def get_plugin_config_schema( ui_type = "list" elif isinstance(field_value, dict): ui_type = "json" - + schema["sections"][section_name]["fields"][field_name] = { "name": field_name, "type": field_type, @@ -1290,7 +1290,7 @@ async def get_plugin_config_schema( "disabled": False, "order": 0, } - + return {"success": True, "schema": schema} except HTTPException: @@ -1301,12 +1301,10 @@ async def get_plugin_config_schema( @router.get("/config/{plugin_id}") -async def get_plugin_config( - plugin_id: str, authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 获取插件当前配置值 - + 返回插件的当前配置值。 """ # Token 验证 @@ -1321,7 +1319,7 @@ async def get_plugin_config( # 查找插件目录 plugins_dir = Path("plugins") plugin_path = None - + for p in plugins_dir.iterdir(): if p.is_dir(): manifest_path = p / "_manifest.json" @@ -1334,19 +1332,20 @@ async def get_plugin_config( break except Exception: continue - + if not plugin_path: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - + # 读取配置文件 config_path = plugin_path / "config.toml" if not config_path.exists(): return {"success": True, "config": {}, "message": "配置文件不存在"} - + import tomlkit + with open(config_path, "r", encoding="utf-8") as f: config = tomlkit.load(f) - + return {"success": True, "config": dict(config)} except HTTPException: @@ -1358,13 +1357,11 @@ async def get_plugin_config( @router.put("/config/{plugin_id}") async def update_plugin_config( - plugin_id: str, - request: UpdatePluginConfigRequest, - authorization: Optional[str] = Header(None) + plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None) ) -> Dict[str, Any]: """ 更新插件配置 - + 保存新的配置值到插件的配置文件。 """ # Token 验证 @@ -1379,7 +1376,7 @@ async def update_plugin_config( # 查找插件目录 plugins_dir = Path("plugins") plugin_path = None - + for p in plugins_dir.iterdir(): if p.is_dir(): manifest_path = p / "_manifest.json" @@ -1392,23 +1389,25 @@ async def update_plugin_config( break except Exception: continue - + if not plugin_path: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - + config_path = plugin_path / "config.toml" - + # 备份旧配置 import shutil import datetime + if config_path.exists(): backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" backup_path = plugin_path / 
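Note on the plugin config endpoints above: they all locate a plugin the same way, scanning `plugins/*/_manifest.json` and matching the manifest's `id` rather than trusting the folder name. Pulled out as one helper for clarity (a sketch of the shared pattern, not a function that exists in the codebase):

```python
import json
from pathlib import Path
from typing import Optional


def find_plugin_path(plugin_id: str, plugins_dir: Path = Path("plugins")) -> Optional[Path]:
    if not plugins_dir.is_dir():
        return None
    for p in plugins_dir.iterdir():
        if not p.is_dir():
            continue
        manifest_path = p / "_manifest.json"
        if not manifest_path.exists():
            continue
        try:
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)
        except Exception:
            continue  # unreadable manifest: skip it, as the route handlers do
        if manifest.get("id") == plugin_id:
            return p
    return None
```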
backup_name shutil.copy(config_path, backup_path) logger.info(f"已备份配置文件: {backup_path}") - + # 写入新配置(使用 tomlkit 保留注释) import tomlkit + # 先读取原配置以保留注释和格式 existing_doc = tomlkit.document() if config_path.exists(): @@ -1419,14 +1418,10 @@ async def update_plugin_config( existing_doc[key] = value with open(config_path, "w", encoding="utf-8") as f: tomlkit.dump(existing_doc, f) - + logger.info(f"已更新插件配置: {plugin_id}") - - return { - "success": True, - "message": "配置已保存", - "note": "配置更改将在插件重新加载后生效" - } + + return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"} except HTTPException: raise @@ -1436,12 +1431,10 @@ async def update_plugin_config( @router.post("/config/{plugin_id}/reset") -async def reset_plugin_config( - plugin_id: str, authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 重置插件配置为默认值 - + 删除当前配置文件,下次加载插件时将使用默认配置。 """ # Token 验证 @@ -1456,7 +1449,7 @@ async def reset_plugin_config( # 查找插件目录 plugins_dir = Path("plugins") plugin_path = None - + for p in plugins_dir.iterdir(): if p.is_dir(): manifest_path = p / "_manifest.json" @@ -1469,29 +1462,26 @@ async def reset_plugin_config( break except Exception: continue - + if not plugin_path: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - + config_path = plugin_path / "config.toml" - + if not config_path.exists(): return {"success": True, "message": "配置文件不存在,无需重置"} - + # 备份并删除 import shutil import datetime + backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" backup_path = plugin_path / backup_name shutil.move(config_path, backup_path) - + logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}") - - return { - "success": True, - "message": "配置已重置,下次加载插件时将使用默认配置", - "backup": str(backup_path) - } + + return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)} except HTTPException: raise @@ -1501,12 +1491,10 @@ async def reset_plugin_config( @router.post("/config/{plugin_id}/toggle") -async def toggle_plugin( - plugin_id: str, authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 切换插件启用状态 - + 切换插件配置中的 enabled 字段。 """ # Token 验证 @@ -1521,7 +1509,7 @@ async def toggle_plugin( # 查找插件目录 plugins_dir = Path("plugins") plugin_path = None - + for p in plugins_dir.iterdir(): if p.is_dir(): manifest_path = p / "_manifest.json" @@ -1534,40 +1522,40 @@ async def toggle_plugin( break except Exception: continue - + if not plugin_path: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - + config_path = plugin_path / "config.toml" - + import tomlkit - + # 读取当前配置(保留注释和格式) config = tomlkit.document() if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: config = tomlkit.load(f) - + # 切换 enabled 状态 if "plugin" not in config: config["plugin"] = tomlkit.table() - + current_enabled = config["plugin"].get("enabled", True) new_enabled = not current_enabled config["plugin"]["enabled"] = new_enabled - + # 写入配置(保留注释) with open(config_path, "w", encoding="utf-8") as f: tomlkit.dump(config, f) - + status = "启用" if new_enabled else "禁用" logger.info(f"已{status}插件: {plugin_id}") - + return { "success": True, "enabled": new_enabled, "message": f"插件已{status}", - "note": "状态更改将在下次加载插件时生效" + "note": "状态更改将在下次加载插件时生效", } except HTTPException: diff --git a/src/webui/routers/system.py 
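Note on `update_plugin_config` above: the save path has two steps, copy the current `config.toml` to a timestamped backup, then merge the submitted values into the existing tomlkit document so comments and untouched keys stay in place. The same flow in isolation (paths are illustrative):

```python
import datetime
import shutil
from pathlib import Path
from typing import Optional

import tomlkit


def save_plugin_config(plugin_path: Path, new_values: dict) -> Optional[Path]:
    config_path = plugin_path / "config.toml"

    backup_path = None
    if config_path.exists():
        stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        backup_path = plugin_path / f"config.toml.backup.{stamp}"
        shutil.copy(config_path, backup_path)  # keep a restorable snapshot

    doc = tomlkit.document()
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            doc = tomlkit.load(f)  # tomlkit keeps comments and formatting
    for key, value in new_values.items():  # merge instead of overwriting the file
        doc[key] = value

    with open(config_path, "w", encoding="utf-8") as f:
        tomlkit.dump(doc, f)
    return backup_path
```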
b/src/webui/routers/system.py index f826dc30..b78540b5 100644 --- a/src/webui/routers/system.py +++ b/src/webui/routers/system.py @@ -43,7 +43,7 @@ async def restart_maibot(): 注意:此操作会使麦麦暂时离线。 """ import asyncio - + try: # 记录重启操作 print(f"[{datetime.now()}] WebUI 触发重启操作") @@ -54,7 +54,7 @@ async def restart_maibot(): python = sys.executable args = [python] + sys.argv os.execv(python, args) - + # 创建后台任务执行重启 asyncio.create_task(delayed_restart()) diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 2c0a4c48..5997c3ba 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -20,10 +20,10 @@ class WebUIServer: self.port = port self.app = FastAPI(title="MaiBot WebUI") self._server = None - + # 显示 Access Token self._show_access_token() - + # 重要:先注册 API 路由,再设置静态文件 self._register_api_routes() self._setup_static_files() @@ -32,7 +32,7 @@ class WebUIServer: """显示 WebUI Access Token""" try: from src.webui.token_manager import get_token_manager - + token_manager = get_token_manager() current_token = token_manager.get_token() logger.info(f"🔑 WebUI Access Token: {current_token}") @@ -69,7 +69,7 @@ class WebUIServer: # 如果是根路径,直接返回 index.html if not full_path or full_path == "/": return FileResponse(static_path / "index.html", media_type="text/html") - + # 检查是否是静态文件 file_path = static_path / full_path if file_path.is_file() and file_path.exists(): @@ -88,13 +88,15 @@ class WebUIServer: # 导入所有 WebUI 路由 from src.webui.routes import router as webui_router from src.webui.logs_ws import router as logs_router - + logger.info("开始导入 knowledge_routes...") from src.webui.knowledge_routes import router as knowledge_router + logger.info("knowledge_routes 导入成功") - + # 导入本地聊天室路由 from src.webui.chat_routes import router as chat_router + logger.info("chat_routes 导入成功") # 注册路由 diff --git a/test_edge.py b/test_edge.py index a7ee8f05..7981bb30 100644 --- a/test_edge.py +++ b/test_edge.py @@ -8,23 +8,23 @@ if edges: e = edges[0] print(f"Edge tuple: {e}") print(f"Edge tuple type: {type(e)}") - + edge_data = kg.graph[e[0], e[1]] print(f"\nEdge data type: {type(edge_data)}") print(f"Edge data: {edge_data}") print(f"Has 'get' method: {hasattr(edge_data, 'get')}") print(f"Is dict: {isinstance(edge_data, dict)}") - + # 尝试不同的访问方式 try: print(f"\nUsing []: {edge_data['weight']}") except Exception as e: print(f"Using [] failed: {e}") - + try: print(f"Using .get(): {edge_data.get('weight')}") except Exception as e: print(f"Using .get() failed: {e}") - + # 查看所有属性 print(f"\nDir: {[x for x in dir(edge_data) if not x.startswith('_')]}")
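Note on the restart endpoint in routers/system.py above: it relies on `os.execv`, which replaces the running interpreter in place with a fresh one started with the same argv. A minimal sketch of that mechanism; the short sleep is an assumption added here so the HTTP response can flush before the process image is swapped:

```python
import asyncio
import os
import sys


async def delayed_restart(delay: float = 1.0) -> None:
    await asyncio.sleep(delay)             # let the pending response go out (assumed delay)
    python = sys.executable
    os.execv(python, [python] + sys.argv)  # never returns; the new process keeps the PID
```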