diff --git a/bot.py b/bot.py index 3f3a4e9c..77465c40 100644 --- a/bot.py +++ b/bot.py @@ -50,6 +50,7 @@ print("警告:Dev进入不稳定开发状态,任何插件与WebUI均可能 print("\n\n\n\n\n") print("-----------------------------------------") + def run_runner_process(): """ Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。 diff --git a/plugins/MaiBot_MCPBridgePlugin/core/__init__.py b/plugins/MaiBot_MCPBridgePlugin/core/__init__.py index 8e85539f..d5656a8e 100644 --- a/plugins/MaiBot_MCPBridgePlugin/core/__init__.py +++ b/plugins/MaiBot_MCPBridgePlugin/core/__init__.py @@ -1,2 +1 @@ """Core helpers for MCP Bridge Plugin.""" - diff --git a/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py b/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py index 6ae5ff97..f2a6f011 100644 --- a/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py +++ b/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py @@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str: if not mcp_servers: return "" return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2) - diff --git a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py b/plugins/MaiBot_MCPBridgePlugin/mcp_client.py index 0d4eebff..de5abab2 100644 --- a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py +++ b/plugins/MaiBot_MCPBridgePlugin/mcp_client.py @@ -34,28 +34,31 @@ from enum import Enum # 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging try: from src.common.logger import get_logger + logger = get_logger("mcp_client") except ImportError: # Fallback: 使用标准 logging logger = logging.getLogger("mcp_client") if not logger.handlers: handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s')) + handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s")) logger.addHandler(handler) logger.setLevel(logging.INFO) class TransportType(Enum): """MCP 传输类型""" - STDIO = "stdio" # 本地进程通信 - SSE = "sse" # Server-Sent Events (旧版 HTTP) - HTTP = "http" # HTTP Streamable (新版,推荐) + + STDIO = "stdio" # 本地进程通信 + SSE = "sse" # Server-Sent Events (旧版 HTTP) + HTTP = "http" # HTTP Streamable (新版,推荐) STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名 @dataclass class MCPToolInfo: """MCP 工具信息""" + name: str description: str input_schema: Dict[str, Any] @@ -65,6 +68,7 @@ class MCPToolInfo: @dataclass class MCPResourceInfo: """MCP 资源信息""" + uri: str name: str description: str @@ -75,6 +79,7 @@ class MCPResourceInfo: @dataclass class MCPPromptInfo: """MCP 提示模板信息""" + name: str description: str arguments: List[Dict[str, Any]] # [{name, description, required}] @@ -84,6 +89,7 @@ class MCPPromptInfo: @dataclass class MCPServerConfig: """MCP 服务器配置""" + name: str enabled: bool = True transport: TransportType = TransportType.STDIO @@ -99,6 +105,7 @@ class MCPServerConfig: @dataclass class MCPCallResult: """MCP 工具调用结果""" + success: bool content: Any error: Optional[str] = None @@ -108,27 +115,28 @@ class MCPCallResult: class CircuitState(Enum): """断路器状态""" - CLOSED = "closed" # 正常状态,允许请求 - OPEN = "open" # 熔断状态,拒绝请求 + + CLOSED = "closed" # 正常状态,允许请求 + OPEN = "open" # 熔断状态,拒绝请求 HALF_OPEN = "half_open" # 半开状态,允许少量试探请求 @dataclass class CircuitBreaker: """v1.7.0: 断路器 - 防止对故障服务器持续请求 - + 状态转换: - CLOSED -> OPEN: 连续失败次数达到阈值 - OPEN -> HALF_OPEN: 熔断时间到期 - HALF_OPEN -> CLOSED: 试探请求成功 - HALF_OPEN -> OPEN: 试探请求失败 """ - + # 配置 - failure_threshold: int = 5 # 连续失败多少次后熔断 - recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒) - half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用 - + failure_threshold: int = 5 # 连续失败多少次后熔断 + recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒) + half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用 + # 状态 state: CircuitState = field(default=CircuitState.CLOSED) failure_count: int = 0 @@ -136,18 +144,18 @@ class CircuitBreaker: last_failure_time: float = 0.0 last_state_change: float = field(default_factory=time.time) half_open_calls: int = 0 - + def can_execute(self) -> Tuple[bool, Optional[str]]: """检查是否允许执行请求 - + Returns: (是否允许, 拒绝原因) """ current_time = time.time() - + if self.state == CircuitState.CLOSED: return True, None - + if self.state == CircuitState.OPEN: # 检查是否到了恢复时间 time_since_failure = current_time - self.last_failure_time @@ -158,20 +166,20 @@ class CircuitBreaker: else: remaining = self.recovery_timeout - time_since_failure return False, f"断路器熔断中,{remaining:.0f}秒后重试" - + if self.state == CircuitState.HALF_OPEN: # 半开状态,检查是否还有试探配额 if self.half_open_calls < self.half_open_max_calls: return True, None else: return False, "断路器半开状态,等待试探结果" - + return True, None - + def record_success(self) -> None: """记录成功调用""" self.success_count += 1 - + if self.state == CircuitState.HALF_OPEN: # 半开状态下成功,恢复到关闭状态 self._transition_to(CircuitState.CLOSED) @@ -179,12 +187,12 @@ class CircuitBreaker: elif self.state == CircuitState.CLOSED: # 正常状态下成功,重置失败计数 self.failure_count = 0 - + def record_failure(self) -> None: """记录失败调用""" self.failure_count += 1 self.last_failure_time = time.time() - + if self.state == CircuitState.HALF_OPEN: # 半开状态下失败,重新熔断 self._transition_to(CircuitState.OPEN) @@ -194,21 +202,21 @@ class CircuitBreaker: if self.failure_count >= self.failure_threshold: self._transition_to(CircuitState.OPEN) logger.warning(f"断路器熔断(连续失败 {self.failure_count} 次)") - + def _transition_to(self, new_state: CircuitState) -> None: """状态转换""" old_state = self.state self.state = new_state self.last_state_change = time.time() - + if new_state == CircuitState.CLOSED: self.failure_count = 0 self.half_open_calls = 0 elif new_state == CircuitState.HALF_OPEN: self.half_open_calls = 0 - + logger.debug(f"断路器状态: {old_state.value} -> {new_state.value}") - + def reset(self) -> None: """重置断路器""" self.state = CircuitState.CLOSED @@ -216,7 +224,7 @@ class CircuitBreaker: self.success_count = 0 self.half_open_calls = 0 self.last_state_change = time.time() - + def get_status(self) -> Dict[str, Any]: """获取断路器状态""" return { @@ -232,6 +240,7 @@ class CircuitBreaker: @dataclass class ToolCallStats: """工具调用统计""" + tool_key: str total_calls: int = 0 success_calls: int = 0 @@ -239,21 +248,21 @@ class ToolCallStats: total_duration_ms: float = 0.0 last_call_time: Optional[float] = None last_error: Optional[str] = None - + @property def success_rate(self) -> float: """成功率(0-100)""" if self.total_calls == 0: return 0.0 return (self.success_calls / self.total_calls) * 100 - + @property def avg_duration_ms(self) -> float: """平均耗时(毫秒)""" if self.success_calls == 0: return 0.0 return self.total_duration_ms / self.success_calls - + def record_call(self, success: bool, duration_ms: float, error: Optional[str] = None) -> None: """记录一次调用""" self.total_calls += 1 @@ -264,7 +273,7 @@ class ToolCallStats: else: self.failed_calls += 1 self.last_error = error - + def to_dict(self) -> Dict[str, Any]: """转换为字典""" return { @@ -282,6 +291,7 @@ class ToolCallStats: @dataclass class ServerStats: """服务器统计""" + server_name: str connect_count: int = 0 # 连接次数 disconnect_count: int = 0 # 断开次数 @@ -290,26 +300,26 @@ class ServerStats: last_disconnect_time: Optional[float] = None last_heartbeat_time: Optional[float] = None consecutive_failures: int = 0 # 连续失败次数 - + def record_connect(self) -> None: self.connect_count += 1 self.last_connect_time = time.time() self.consecutive_failures = 0 - + def record_disconnect(self) -> None: self.disconnect_count += 1 self.last_disconnect_time = time.time() - + def record_reconnect(self) -> None: self.reconnect_count += 1 self.consecutive_failures = 0 - + def record_failure(self) -> None: self.consecutive_failures += 1 - + def record_heartbeat(self) -> None: self.last_heartbeat_time = time.time() - + def to_dict(self) -> Dict[str, Any]: return { "server_name": self.server_name, @@ -325,7 +335,7 @@ class ServerStats: class MCPClientSession: """MCP 客户端会话,管理与单个 MCP 服务器的连接""" - + def __init__(self, config: MCPServerConfig, call_timeout: float = 60.0): self.config = config self.call_timeout = call_timeout @@ -338,63 +348,63 @@ class MCPClientSession: self._prompts: List[MCPPromptInfo] = [] # v1.2.0: Prompts 支持 self._connected = False self._lock = asyncio.Lock() - + # 功能支持标记(服务器可能不支持某些功能) self._supports_resources: bool = False self._supports_prompts: bool = False - + # 统计信息 self.stats = ServerStats(server_name=config.name) self._tool_stats: Dict[str, ToolCallStats] = {} - + # v1.7.0: 断路器 self._circuit_breaker = CircuitBreaker() - + @property def is_connected(self) -> bool: return self._connected - + @property def tools(self) -> List[MCPToolInfo]: return self._tools.copy() - + @property def resources(self) -> List[MCPResourceInfo]: """v1.2.0: 获取资源列表""" return self._resources.copy() - + @property def prompts(self) -> List[MCPPromptInfo]: """v1.2.0: 获取提示模板列表""" return self._prompts.copy() - + @property def supports_resources(self) -> bool: """v1.2.0: 服务器是否支持 Resources""" return self._supports_resources - + @property def supports_prompts(self) -> bool: """v1.2.0: 服务器是否支持 Prompts""" return self._supports_prompts - + @property def server_name(self) -> str: return self.config.name - + def get_tool_stats(self, tool_name: str) -> Optional[ToolCallStats]: """获取工具统计""" return self._tool_stats.get(tool_name) - + def get_circuit_breaker_status(self) -> Dict[str, Any]: """v1.7.0: 获取断路器状态""" return self._circuit_breaker.get_status() - + def reset_circuit_breaker(self) -> None: """v1.7.0: 重置断路器""" self._circuit_breaker.reset() logger.info(f"[{self.server_name}] 断路器已重置") - + def get_all_tool_stats(self) -> Dict[str, ToolCallStats]: """获取所有工具统计""" return self._tool_stats.copy() @@ -404,7 +414,7 @@ class MCPClientSession: async with self._lock: if self._connected: return True - + try: success = False if self.config.transport == TransportType.STDIO: @@ -416,7 +426,7 @@ class MCPClientSession: else: logger.error(f"[{self.server_name}] 不支持的传输类型: {self.config.transport}") return False - + if success: self.stats.record_connect() # v1.7.0: 连接成功时重置断路器 @@ -424,13 +434,13 @@ class MCPClientSession: else: self.stats.record_failure() return success - + except Exception as e: logger.error(f"[{self.server_name}] 连接失败: {e}") self._connected = False self.stats.record_failure() return False - + async def _connect_stdio(self) -> bool: """通过 stdio 连接 MCP 服务器""" try: @@ -440,31 +450,29 @@ class MCPClientSession: except ImportError: logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") return False - + server_params = StdioServerParameters( - command=self.config.command, - args=self.config.args, - env=self.config.env if self.config.env else None + command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None ) - + self._stdio_context = stdio_client(server_params) self._read_stream, self._write_stream = await self._stdio_context.__aenter__() - + self._session_context = ClientSession(self._read_stream, self._write_stream) self._session = await self._session_context.__aenter__() - + await self._session.initialize() await self._fetch_tools() - + self._connected = True logger.info(f"[{self.server_name}] stdio 连接成功,发现 {len(self._tools)} 个工具") return True - + except Exception as e: logger.error(f"[{self.server_name}] stdio 连接失败: {e}") await self._cleanup() return False - + async def _connect_sse(self) -> bool: """通过 SSE 连接 MCP 服务器""" try: @@ -474,13 +482,13 @@ class MCPClientSession: except ImportError: logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") return False - + if not self.config.url: logger.error(f"[{self.server_name}] SSE 传输需要配置 url") return False - + logger.debug(f"[{self.server_name}] 正在连接 SSE MCP 服务器: {self.config.url}") - + # v1.4.2: 支持 headers 鉴权 sse_kwargs = { "url": self.config.url, @@ -489,23 +497,24 @@ class MCPClientSession: } if self.config.headers: sse_kwargs["headers"] = self.config.headers - + self._sse_context = sse_client(**sse_kwargs) self._read_stream, self._write_stream = await self._sse_context.__aenter__() - + self._session_context = ClientSession(self._read_stream, self._write_stream) self._session = await self._session_context.__aenter__() - + await self._session.initialize() await self._fetch_tools() - + self._connected = True logger.info(f"[{self.server_name}] SSE 连接成功,发现 {len(self._tools)} 个工具") return True - + except Exception as e: logger.error(f"[{self.server_name}] SSE 连接失败: {e}") import traceback + logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") await self._cleanup() return False @@ -519,13 +528,13 @@ class MCPClientSession: except ImportError: logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") return False - + if not self.config.url: logger.error(f"[{self.server_name}] HTTP 传输需要配置 url") return False - + logger.debug(f"[{self.server_name}] 正在连接 HTTP MCP 服务器: {self.config.url}") - + # v1.4.2: 支持 headers 鉴权 http_kwargs = { "url": self.config.url, @@ -534,23 +543,24 @@ class MCPClientSession: } if self.config.headers: http_kwargs["headers"] = self.config.headers - + self._http_context = streamablehttp_client(**http_kwargs) self._read_stream, self._write_stream, self._get_session_id = await self._http_context.__aenter__() - + self._session_context = ClientSession(self._read_stream, self._write_stream) self._session = await self._session_context.__aenter__() - + await self._session.initialize() await self._fetch_tools() - + self._connected = True logger.info(f"[{self.server_name}] HTTP 连接成功,发现 {len(self._tools)} 个工具") return True - + except Exception as e: logger.error(f"[{self.server_name}] HTTP 连接失败: {e}") import traceback + logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") await self._cleanup() return False @@ -559,59 +569,56 @@ class MCPClientSession: """获取 MCP 服务器的工具列表""" if not self._session: return - + try: result = await self._session.list_tools() self._tools = [] - + for tool in result.tools: tool_info = MCPToolInfo( name=tool.name, description=tool.description or f"MCP tool: {tool.name}", - input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {}, - server_name=self.server_name + input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {}, + server_name=self.server_name, ) self._tools.append(tool_info) # 初始化工具统计 if tool.name not in self._tool_stats: self._tool_stats[tool.name] = ToolCallStats(tool_key=tool.name) logger.debug(f"[{self.server_name}] 发现工具: {tool.name}") - + except Exception as e: logger.error(f"[{self.server_name}] 获取工具列表失败: {e}") self._tools = [] async def fetch_resources(self) -> bool: """v1.2.0: 获取 MCP 服务器的资源列表 - + Returns: bool: 是否成功获取(服务器不支持时返回 False) """ if not self._session: return False - + try: - result = await asyncio.wait_for( - self._session.list_resources(), - timeout=self.call_timeout - ) + result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout) self._resources = [] - + for resource in result.resources: resource_info = MCPResourceInfo( uri=str(resource.uri), name=resource.name or str(resource.uri), description=resource.description or "", - mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None, - server_name=self.server_name + mime_type=resource.mimeType if hasattr(resource, "mimeType") else None, + server_name=self.server_name, ) self._resources.append(resource_info) logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}") - + self._supports_resources = True logger.info(f"[{self.server_name}] 获取到 {len(self._resources)} 个资源") return True - + except Exception as e: # 服务器可能不支持 resources,这不是错误 error_str = str(e).lower() @@ -625,44 +632,43 @@ class MCPClientSession: async def fetch_prompts(self) -> bool: """v1.2.0: 获取 MCP 服务器的提示模板列表 - + Returns: bool: 是否成功获取(服务器不支持时返回 False) """ if not self._session: return False - + try: - result = await asyncio.wait_for( - self._session.list_prompts(), - timeout=self.call_timeout - ) + result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout) self._prompts = [] - + for prompt in result.prompts: # 解析参数 arguments = [] - if hasattr(prompt, 'arguments') and prompt.arguments: + if hasattr(prompt, "arguments") and prompt.arguments: for arg in prompt.arguments: - arguments.append({ - "name": arg.name, - "description": arg.description or "", - "required": arg.required if hasattr(arg, 'required') else False, - }) - + arguments.append( + { + "name": arg.name, + "description": arg.description or "", + "required": arg.required if hasattr(arg, "required") else False, + } + ) + prompt_info = MCPPromptInfo( name=prompt.name, description=prompt.description or f"MCP prompt: {prompt.name}", arguments=arguments, - server_name=self.server_name + server_name=self.server_name, ) self._prompts.append(prompt_info) logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}") - + self._supports_prompts = True logger.info(f"[{self.server_name}] 获取到 {len(self._prompts)} 个提示模板") return True - + except Exception as e: # 服务器可能不支持 prompts,这不是错误 error_str = str(e).lower() @@ -676,45 +682,35 @@ class MCPClientSession: async def read_resource(self, uri: str) -> MCPCallResult: """v1.2.0: 读取指定资源的内容 - + Args: uri: 资源 URI - + Returns: MCPCallResult: 包含资源内容的结果 """ start_time = time.time() - + if not self._connected or not self._session: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 未连接" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") + if not self._supports_resources: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 不支持 Resources 功能" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能") + try: - result = await asyncio.wait_for( - self._session.read_resource(uri), - timeout=self.call_timeout - ) - + result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout) + duration_ms = (time.time() - start_time) * 1000 - + # 处理返回内容 content_parts = [] for content in result.contents: - if hasattr(content, 'text'): + if hasattr(content, "text"): content_parts.append(content.text) - elif hasattr(content, 'blob'): + elif hasattr(content, "blob"): # 二进制数据,返回 base64 或提示 import base64 + blob_data = content.blob if len(blob_data) < 10000: # 小于 10KB 返回 base64 content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}") @@ -722,117 +718,85 @@ class MCPClientSession: content_parts.append(f"[二进制数据: {len(blob_data)} bytes]") else: content_parts.append(str(content)) - + return MCPCallResult( - success=True, - content="\n".join(content_parts) if content_parts else "", - duration_ms=duration_ms + success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms ) - + except asyncio.TimeoutError: duration_ms = (time.time() - start_time) * 1000 return MCPCallResult( - success=False, - content=None, - error=f"读取资源超时({self.call_timeout}秒)", - duration_ms=duration_ms + success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms ) except Exception as e: duration_ms = (time.time() - start_time) * 1000 logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}") - return MCPCallResult( - success=False, - content=None, - error=str(e), - duration_ms=duration_ms - ) + return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult: """v1.2.0: 获取提示模板的内容 - + Args: name: 提示模板名称 arguments: 模板参数 - + Returns: MCPCallResult: 包含提示内容的结果 """ start_time = time.time() - + if not self._connected or not self._session: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 未连接" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") + if not self._supports_prompts: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 不支持 Prompts 功能" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能") + try: result = await asyncio.wait_for( - self._session.get_prompt(name, arguments=arguments or {}), - timeout=self.call_timeout + self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout ) - + duration_ms = (time.time() - start_time) * 1000 - + # 处理返回的消息 messages = [] for msg in result.messages: - role = msg.role if hasattr(msg, 'role') else "unknown" + role = msg.role if hasattr(msg, "role") else "unknown" content_text = "" - if hasattr(msg, 'content'): - if hasattr(msg.content, 'text'): + if hasattr(msg, "content"): + if hasattr(msg.content, "text"): content_text = msg.content.text elif isinstance(msg.content, str): content_text = msg.content else: content_text = str(msg.content) messages.append(f"[{role}]: {content_text}") - + return MCPCallResult( - success=True, - content="\n\n".join(messages) if messages else "", - duration_ms=duration_ms + success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms ) - + except asyncio.TimeoutError: duration_ms = (time.time() - start_time) * 1000 return MCPCallResult( - success=False, - content=None, - error=f"获取提示模板超时({self.call_timeout}秒)", - duration_ms=duration_ms + success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms ) except Exception as e: duration_ms = (time.time() - start_time) * 1000 logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}") - return MCPCallResult( - success=False, - content=None, - error=str(e), - duration_ms=duration_ms - ) - + return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) + async def check_health(self) -> bool: """检查连接健康状态(心跳检测) - + 通过调用 list_tools 来验证连接是否正常 """ if not self._connected or not self._session: return False - + try: # 使用 list_tools 作为心跳检测 - await asyncio.wait_for( - self._session.list_tools(), - timeout=10.0 - ) + await asyncio.wait_for(self._session.list_tools(), timeout=10.0) self.stats.record_heartbeat() return True except Exception as e: @@ -841,25 +805,20 @@ class MCPClientSession: self._connected = False self.stats.record_disconnect() return False - + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPCallResult: """调用 MCP 工具""" start_time = time.time() - + # v1.7.0: 断路器检查 can_execute, reject_reason = self._circuit_breaker.can_execute() if not can_execute: - return MCPCallResult( - success=False, - content=None, - error=f"⚡ {reject_reason}", - circuit_broken=True - ) - + return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True) + # 半开状态下增加试探计数 if self._circuit_breaker.state == CircuitState.HALF_OPEN: self._circuit_breaker.half_open_calls += 1 - + if not self._connected or not self._session: error_msg = f"服务器 {self.server_name} 未连接" # 记录失败 @@ -867,38 +826,37 @@ class MCPClientSession: self._tool_stats[tool_name].record_call(False, 0, error_msg) self._circuit_breaker.record_failure() return MCPCallResult(success=False, content=None, error=error_msg) - + try: result = await asyncio.wait_for( - self._session.call_tool(tool_name, arguments=arguments), - timeout=self.call_timeout + self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout ) - + duration_ms = (time.time() - start_time) * 1000 - + # 处理返回内容 content_parts = [] for content in result.content: - if hasattr(content, 'text'): + if hasattr(content, "text"): content_parts.append(content.text) - elif hasattr(content, 'data'): + elif hasattr(content, "data"): content_parts.append(f"[二进制数据: {len(content.data)} bytes]") else: content_parts.append(str(content)) - + # 记录成功 if tool_name in self._tool_stats: self._tool_stats[tool_name].record_call(True, duration_ms) - + # v1.7.0: 断路器记录成功 self._circuit_breaker.record_success() - + return MCPCallResult( success=True, content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)", - duration_ms=duration_ms + duration_ms=duration_ms, ) - + except asyncio.TimeoutError: duration_ms = (time.time() - start_time) * 1000 error_msg = f"工具调用超时({self.call_timeout}秒)" @@ -907,7 +865,7 @@ class MCPClientSession: # v1.7.0: 断路器记录失败 self._circuit_breaker.record_failure() return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms) - + except Exception as e: duration_ms = (time.time() - start_time) * 1000 error_msg = str(e) @@ -928,7 +886,7 @@ class MCPClientSession: if self._connected: self.stats.record_disconnect() await self._cleanup() - + async def _cleanup(self) -> None: """清理资源""" self._connected = False @@ -937,31 +895,31 @@ class MCPClientSession: self._prompts = [] # v1.2.0 self._supports_resources = False # v1.2.0 self._supports_prompts = False # v1.2.0 - + try: - if hasattr(self, '_session_context') and self._session_context: + if hasattr(self, "_session_context") and self._session_context: await self._session_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}") - + try: - if hasattr(self, '_stdio_context') and self._stdio_context: + if hasattr(self, "_stdio_context") and self._stdio_context: await self._stdio_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}") - + try: - if hasattr(self, '_http_context') and self._http_context: + if hasattr(self, "_http_context") and self._http_context: await self._http_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}") - + try: - if hasattr(self, '_sse_context') and self._sse_context: + if hasattr(self, "_sse_context") and self._sse_context: await self._sse_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}") - + self._session = None self._session_context = None self._stdio_context = None @@ -969,27 +927,27 @@ class MCPClientSession: self._sse_context = None self._read_stream = None self._write_stream = None - + logger.debug(f"[{self.server_name}] 连接已关闭") class MCPClientManager: """MCP 客户端管理器,管理多个 MCP 服务器连接 - + 功能: - 管理多个 MCP 服务器连接 - 心跳检测和自动重连 - 调用统计 """ - + _instance: Optional["MCPClientManager"] = None - + def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - + def __init__(self): if self._initialized: return @@ -1000,14 +958,14 @@ class MCPClientManager: self._all_prompts: Dict[str, Tuple[MCPPromptInfo, MCPClientSession]] = {} # v1.2.0 self._settings: Dict[str, Any] = {} self._lock = asyncio.Lock() - + # 心跳检测任务 self._heartbeat_task: Optional[asyncio.Task] = None self._heartbeat_running = False - + # 状态变化回调 self._on_status_change: Optional[callable] = None - + # 全局统计 self._global_stats = { "total_tool_calls": 0, @@ -1015,15 +973,15 @@ class MCPClientManager: "failed_calls": 0, "start_time": time.time(), } - + def configure(self, settings: Dict[str, Any]) -> None: """配置管理器""" self._settings = settings - + def set_status_change_callback(self, callback: callable) -> None: """设置状态变化回调函数""" self._on_status_change = callback - + def _notify_status_change(self) -> None: """通知状态变化""" if self._on_status_change: @@ -1031,27 +989,27 @@ class MCPClientManager: self._on_status_change() except Exception as e: logger.debug(f"状态变化回调出错: {e}") - + @property def all_tools(self) -> Dict[str, Tuple[MCPToolInfo, MCPClientSession]]: """获取所有已注册的工具""" return self._all_tools.copy() - + @property def all_resources(self) -> Dict[str, Tuple[MCPResourceInfo, MCPClientSession]]: """v1.2.0: 获取所有已注册的资源""" return self._all_resources.copy() - + @property def all_prompts(self) -> Dict[str, Tuple[MCPPromptInfo, MCPClientSession]]: """v1.2.0: 获取所有已注册的提示模板""" return self._all_prompts.copy() - + @property def connected_servers(self) -> List[str]: """获取已连接的服务器列表""" return [name for name, client in self._clients.items() if client.is_connected] - + @property def disconnected_servers(self) -> List[str]: """获取已断开的服务器列表""" @@ -1063,36 +1021,38 @@ class MCPClientManager: if config.name in self._clients: logger.warning(f"服务器 {config.name} 已存在") return False - + call_timeout = self._settings.get("call_timeout", 60.0) client = MCPClientSession(config, call_timeout) self._clients[config.name] = client - + if not config.enabled: logger.info(f"服务器 {config.name} 已添加但未启用") return True - + # 尝试连接 retry_attempts = self._settings.get("retry_attempts", 3) retry_interval = self._settings.get("retry_interval", 5.0) - + for attempt in range(1, retry_attempts + 1): if await client.connect(): self._register_tools(client) return True - + if attempt < retry_attempts: - logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") + logger.warning( + f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})" + ) await asyncio.sleep(retry_interval) - + logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})") # 连接失败,但保留在 _clients 中以便后续重连 return False - + def _register_tools(self, client: MCPClientSession) -> None: """注册客户端的工具""" tool_prefix = self._settings.get("tool_prefix", "mcp") - + for tool in client.tools: if tool.name.startswith(f"{tool_prefix}_{client.server_name}_"): tool_key = tool.name @@ -1100,12 +1060,12 @@ class MCPClientManager: tool_key = f"{tool_prefix}_{client.server_name}_{tool.name}" self._all_tools[tool_key] = (tool, client) logger.debug(f"注册 MCP 工具: {tool_key}") - + def _unregister_tools(self, server_name: str) -> List[str]: """注销服务器的工具,返回被注销的工具键列表""" tool_prefix = self._settings.get("tool_prefix", "mcp") prefix = f"{tool_prefix}_{server_name}_" - + keys_to_remove = [k for k in self._all_tools.keys() if k.startswith(prefix)] for key in keys_to_remove: del self._all_tools[key] @@ -1115,7 +1075,7 @@ class MCPClientManager: def _register_resources(self, client: MCPClientSession) -> None: """v1.2.0: 注册客户端的资源""" tool_prefix = self._settings.get("tool_prefix", "mcp") - + for resource in client.resources: # 资源键格式: mcp_{server}_{uri_safe_name} # 将 URI 转换为安全的键名 @@ -1123,12 +1083,12 @@ class MCPClientManager: resource_key = f"{tool_prefix}_{client.server_name}_res_{safe_uri}" self._all_resources[resource_key] = (resource, client) logger.debug(f"注册 MCP 资源: {resource_key}") - + def _unregister_resources(self, server_name: str) -> List[str]: """v1.2.0: 注销服务器的资源""" tool_prefix = self._settings.get("tool_prefix", "mcp") prefix = f"{tool_prefix}_{server_name}_res_" - + keys_to_remove = [k for k in self._all_resources.keys() if k.startswith(prefix)] for key in keys_to_remove: del self._all_resources[key] @@ -1138,56 +1098,56 @@ class MCPClientManager: def _register_prompts(self, client: MCPClientSession) -> None: """v1.2.0: 注册客户端的提示模板""" tool_prefix = self._settings.get("tool_prefix", "mcp") - + for prompt in client.prompts: prompt_key = f"{tool_prefix}_{client.server_name}_prompt_{prompt.name}" self._all_prompts[prompt_key] = (prompt, client) logger.debug(f"注册 MCP 提示模板: {prompt_key}") - + def _unregister_prompts(self, server_name: str) -> List[str]: """v1.2.0: 注销服务器的提示模板""" tool_prefix = self._settings.get("tool_prefix", "mcp") prefix = f"{tool_prefix}_{server_name}_prompt_" - + keys_to_remove = [k for k in self._all_prompts.keys() if k.startswith(prefix)] for key in keys_to_remove: del self._all_prompts[key] logger.debug(f"注销 MCP 提示模板: {key}") return keys_to_remove - + async def remove_server(self, server_name: str) -> bool: """移除 MCP 服务器""" async with self._lock: if server_name not in self._clients: return False - + client = self._clients[server_name] await client.disconnect() self._unregister_tools(server_name) self._unregister_resources(server_name) # v1.2.0 self._unregister_prompts(server_name) # v1.2.0 del self._clients[server_name] - + logger.info(f"服务器 {server_name} 已移除") return True - + async def reconnect_server(self, server_name: str) -> bool: """重新连接服务器""" if server_name not in self._clients: return False - + client = self._clients[server_name] - + async with self._lock: self._unregister_tools(server_name) self._unregister_resources(server_name) # v1.2.0 self._unregister_prompts(server_name) # v1.2.0 await client.disconnect() - + # 尝试重连 retry_attempts = self._settings.get("retry_attempts", 3) retry_interval = self._settings.get("retry_interval", 5.0) - + for attempt in range(1, retry_attempts + 1): if await client.connect(): async with self._lock: @@ -1202,46 +1162,42 @@ class MCPClientManager: client.stats.record_reconnect() logger.info(f"服务器 {server_name} 重连成功") return True - + if attempt < retry_attempts: logger.warning(f"服务器 {server_name} 重连失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") await asyncio.sleep(retry_interval) - + logger.error(f"服务器 {server_name} 重连失败") return False async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult: """调用 MCP 工具""" if tool_key not in self._all_tools: - return MCPCallResult( - success=False, - content=None, - error=f"工具 {tool_key} 不存在" - ) - + return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在") + tool_info, client = self._all_tools[tool_key] - + # 更新全局统计 self._global_stats["total_tool_calls"] += 1 - + result = await client.call_tool(tool_info.name, arguments) - + if result.success: self._global_stats["successful_calls"] += 1 else: self._global_stats["failed_calls"] += 1 - + return result async def fetch_resources_for_server(self, server_name: str) -> bool: """v1.2.0: 获取指定服务器的资源列表""" if server_name not in self._clients: return False - + client = self._clients[server_name] if not client.is_connected: return False - + success = await client.fetch_resources() if success: async with self._lock: @@ -1252,11 +1208,11 @@ class MCPClientManager: """v1.2.0: 获取指定服务器的提示模板列表""" if server_name not in self._clients: return False - + client = self._clients[server_name] if not client.is_connected: return False - + success = await client.fetch_prompts() if success: async with self._lock: @@ -1265,7 +1221,7 @@ class MCPClientManager: async def read_resource(self, uri: str, server_name: Optional[str] = None) -> MCPCallResult: """v1.2.0: 读取资源内容 - + Args: uri: 资源 URI server_name: 指定服务器名称(可选,不指定则自动查找) @@ -1273,36 +1229,29 @@ class MCPClientManager: # 如果指定了服务器 if server_name: if server_name not in self._clients: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {server_name} 不存在" - ) + return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") client = self._clients[server_name] return await client.read_resource(uri) - + # 自动查找拥有该资源的服务器 - for resource_key, (resource_info, client) in self._all_resources.items(): + for _resource_key, (resource_info, client) in self._all_resources.items(): if resource_info.uri == uri: return await client.read_resource(uri) - + # 尝试在所有支持 resources 的服务器上查找 for client in self._clients.values(): if client.is_connected and client.supports_resources: result = await client.read_resource(uri) if result.success: return result - - return MCPCallResult( - success=False, - content=None, - error=f"未找到资源: {uri}" - ) - async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None, - server_name: Optional[str] = None) -> MCPCallResult: + return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}") + + async def get_prompt( + self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None + ) -> MCPCallResult: """v1.2.0: 获取提示模板内容 - + Args: name: 提示模板名称 arguments: 模板参数 @@ -1311,42 +1260,34 @@ class MCPClientManager: # 如果指定了服务器 if server_name: if server_name not in self._clients: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {server_name} 不存在" - ) + return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") client = self._clients[server_name] return await client.get_prompt(name, arguments) - + # 自动查找拥有该提示模板的服务器 - for prompt_key, (prompt_info, client) in self._all_prompts.items(): + for _prompt_key, (prompt_info, client) in self._all_prompts.items(): if prompt_info.name == name: return await client.get_prompt(name, arguments) - - return MCPCallResult( - success=False, - content=None, - error=f"未找到提示模板: {name}" - ) - + + return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}") + # ==================== 心跳检测 ==================== - + async def start_heartbeat(self) -> None: """启动心跳检测任务""" if self._heartbeat_running: logger.warning("心跳检测任务已在运行") return - + heartbeat_enabled = self._settings.get("heartbeat_enabled", True) if not heartbeat_enabled: logger.info("心跳检测已禁用") return - + self._heartbeat_running = True self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) logger.info("心跳检测任务已启动") - + async def stop_heartbeat(self) -> None: """停止心跳检测任务""" self._heartbeat_running = False @@ -1358,52 +1299,52 @@ class MCPClientManager: pass self._heartbeat_task = None logger.info("心跳检测任务已停止") - + async def _heartbeat_loop(self) -> None: """心跳检测循环(v1.5.2: 智能心跳间隔)""" base_interval = self._settings.get("heartbeat_interval", 60.0) auto_reconnect = self._settings.get("auto_reconnect", True) max_reconnect_attempts = self._settings.get("max_reconnect_attempts", 3) - + # v1.5.2: 智能心跳配置 adaptive_enabled = self._settings.get("heartbeat_adaptive", True) max_multiplier = self._settings.get("heartbeat_max_multiplier", 3.0) - + # 每个服务器独立的心跳间隔(根据稳定性动态调整) server_intervals: Dict[str, float] = {} min_interval = max(base_interval * 0.5, 30.0) # 最小间隔 max_interval = base_interval * max_multiplier # 最大间隔 - + mode_str = "智能" if adaptive_enabled else "固定" logger.info(f"心跳检测循环启动,{mode_str}模式,基准间隔: {base_interval}秒") - + while self._heartbeat_running: try: # 使用最小的服务器间隔作为循环间隔 current_interval = min(server_intervals.values()) if server_intervals else base_interval current_interval = max(current_interval, min_interval) - + await asyncio.sleep(current_interval) - + if not self._heartbeat_running: break - + current_time = time.time() - + # 检查所有已启用的服务器 for server_name, client in list(self._clients.items()): if not client.config.enabled: continue - + # 初始化服务器间隔 if server_name not in server_intervals: server_intervals[server_name] = base_interval - + # 检查是否到达该服务器的心跳时间 last_heartbeat = client.stats.last_heartbeat_time or 0 if current_time - last_heartbeat < server_intervals[server_name] * 0.9: continue # 还没到心跳时间 - + if client.is_connected: # 检查健康状态 healthy = await client.check_health() @@ -1435,71 +1376,73 @@ class MCPClientManager: # 达到最大重连次数,降低检测频率 server_intervals[server_name] = max_interval logger.debug(f"[{server_name}] 已达最大重连次数,降低检测频率") - + except asyncio.CancelledError: break except Exception as e: logger.error(f"心跳检测循环出错: {e}") await asyncio.sleep(5) - + async def _try_reconnect(self, server_name: str, max_attempts: int) -> bool: """尝试重连服务器""" client = self._clients.get(server_name) if not client: return False - + if client.stats.consecutive_failures >= max_attempts: logger.warning(f"[{server_name}] 连续失败次数已达上限 ({max_attempts}),暂停重连") return False - + logger.info(f"[{server_name}] 尝试重连 (失败次数: {client.stats.consecutive_failures}/{max_attempts})") - + success = await self.reconnect_server(server_name) if not success: client.stats.record_failure() - + self._notify_status_change() # 重连后更新状态 return success # ==================== 统计和状态 ==================== - + def get_tool_stats(self, tool_key: str) -> Optional[Dict[str, Any]]: """获取指定工具的统计信息""" if tool_key not in self._all_tools: return None - + tool_info, client = self._all_tools[tool_key] stats = client.get_tool_stats(tool_info.name) return stats.to_dict() if stats else None - + def get_all_stats(self) -> Dict[str, Any]: """获取所有统计信息""" server_stats = {} tool_stats = {} - + for server_name, client in self._clients.items(): server_stats[server_name] = client.stats.to_dict() for tool_name, stats in client.get_all_tool_stats().items(): full_key = f"{self._settings.get('tool_prefix', 'mcp')}_{server_name}_{tool_name}" tool_stats[full_key] = stats.to_dict() - + uptime = time.time() - self._global_stats["start_time"] - + return { "global": { **self._global_stats, "uptime_seconds": round(uptime, 2), - "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0, + "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) + if uptime > 0 + else 0, }, "servers": server_stats, "tools": tool_stats, } - + async def shutdown(self) -> None: """关闭所有连接""" # 停止心跳检测 await self.stop_heartbeat() - + async with self._lock: for client in self._clients.values(): await client.disconnect() @@ -1508,7 +1451,7 @@ class MCPClientManager: self._all_resources.clear() # v1.2.0 self._all_prompts.clear() # v1.2.0 logger.info("MCP 客户端管理器已关闭") - + def get_status(self) -> Dict[str, Any]: """获取状态信息""" return { diff --git a/plugins/MaiBot_MCPBridgePlugin/plugin.py b/plugins/MaiBot_MCPBridgePlugin/plugin.py index 6af626d6..1d965e25 100644 --- a/plugins/MaiBot_MCPBridgePlugin/plugin.py +++ b/plugins/MaiBot_MCPBridgePlugin/plugin.py @@ -123,9 +123,11 @@ logger = get_logger("mcp_bridge_plugin") # v1.4.0: 调用链路追踪 # ============================================================================ + @dataclass class ToolCallRecord: """工具调用记录""" + call_id: str timestamp: float tool_name: str @@ -145,46 +147,46 @@ class ToolCallRecord: class ToolCallTracer: """工具调用追踪器""" - + def __init__(self, max_records: int = 100): self._records: deque[ToolCallRecord] = deque(maxlen=max_records) self._enabled: bool = True self._log_enabled: bool = False self._log_path: Optional[Path] = None - + def configure(self, enabled: bool, max_records: int, log_enabled: bool, log_path: Optional[Path] = None) -> None: """配置追踪器""" self._enabled = enabled self._records = deque(self._records, maxlen=max_records) self._log_enabled = log_enabled self._log_path = log_path - + def record(self, record: ToolCallRecord) -> None: """添加调用记录""" if not self._enabled: return - + self._records.append(record) - + if self._log_enabled and self._log_path: self._write_to_log(record) - + def get_recent(self, n: int = 10) -> List[ToolCallRecord]: """获取最近 N 条记录""" return list(self._records)[-n:] - + def get_by_tool(self, tool_name: str) -> List[ToolCallRecord]: """按工具名筛选记录""" return [r for r in self._records if r.tool_name == tool_name] - + def get_by_server(self, server_name: str) -> List[ToolCallRecord]: """按服务器名筛选记录""" return [r for r in self._records if r.server_name == server_name] - + def clear(self) -> None: """清空记录""" self._records.clear() - + def _write_to_log(self, record: ToolCallRecord) -> None: """写入 JSONL 日志文件""" try: @@ -194,7 +196,7 @@ class ToolCallTracer: f.write(json.dumps(asdict(record), ensure_ascii=False) + "\n") except Exception as e: logger.warning(f"写入追踪日志失败: {e}") - + @property def total_records(self) -> int: return len(self._records) @@ -208,9 +210,11 @@ tool_call_tracer = ToolCallTracer() # v1.4.0: 工具调用缓存 # ============================================================================ + @dataclass class CacheEntry: """缓存条目""" + tool_name: str args_hash: str result: str @@ -221,7 +225,7 @@ class CacheEntry: class ToolCallCache: """工具调用缓存(LRU)""" - + def __init__(self, max_entries: int = 200, ttl: int = 300): self._cache: OrderedDict[str, CacheEntry] = OrderedDict() self._max_entries = max_entries @@ -229,54 +233,54 @@ class ToolCallCache: self._enabled = False self._exclude_patterns: List[str] = [] self._stats = {"hits": 0, "misses": 0} - + def configure(self, enabled: bool, ttl: int, max_entries: int, exclude_tools: str) -> None: """配置缓存""" self._enabled = enabled self._ttl = ttl self._max_entries = max_entries self._exclude_patterns = [p.strip() for p in exclude_tools.strip().split("\n") if p.strip()] - + def get(self, tool_name: str, args: Dict) -> Optional[str]: """获取缓存""" if not self._enabled: return None - + if self._is_excluded(tool_name): return None - + key = self._generate_key(tool_name, args) - + if key not in self._cache: self._stats["misses"] += 1 return None - + entry = self._cache[key] - + # 检查是否过期 if time.time() > entry.expires_at: del self._cache[key] self._stats["misses"] += 1 return None - + # LRU: 移到末尾 self._cache.move_to_end(key) entry.hit_count += 1 self._stats["hits"] += 1 - + return entry.result - + def set(self, tool_name: str, args: Dict, result: str) -> None: """设置缓存""" if not self._enabled: return - + if self._is_excluded(tool_name): return - + key = self._generate_key(tool_name, args) now = time.time() - + entry = CacheEntry( tool_name=tool_name, args_hash=key, @@ -284,7 +288,7 @@ class ToolCallCache: created_at=now, expires_at=now + self._ttl, ) - + # 如果已存在,更新 if key in self._cache: self._cache[key] = entry @@ -293,25 +297,25 @@ class ToolCallCache: # 检查容量 self._evict_if_needed() self._cache[key] = entry - + def clear(self) -> None: """清空缓存""" self._cache.clear() self._stats = {"hits": 0, "misses": 0} - + def _generate_key(self, tool_name: str, args: Dict) -> str: """生成缓存键""" args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) content = f"{tool_name}:{args_str}" return hashlib.md5(content.encode()).hexdigest() - + def _is_excluded(self, tool_name: str) -> bool: """检查是否在排除列表中""" for pattern in self._exclude_patterns: if fnmatch.fnmatch(tool_name, pattern): return True return False - + def _evict_if_needed(self) -> None: """必要时淘汰条目""" # 先清理过期的 @@ -319,11 +323,11 @@ class ToolCallCache: expired_keys = [k for k, v in self._cache.items() if now > v.expires_at] for k in expired_keys: del self._cache[k] - + # LRU 淘汰 while len(self._cache) >= self._max_entries: self._cache.popitem(last=False) - + def get_stats(self) -> Dict[str, Any]: """获取缓存统计""" total = self._stats["hits"] + self._stats["misses"] @@ -347,16 +351,17 @@ tool_call_cache = ToolCallCache() # v1.4.0: 工具权限控制 # ============================================================================ + class PermissionChecker: """工具权限检查器""" - + def __init__(self): self._enabled = False self._default_mode = "allow_all" # allow_all 或 deny_all self._rules: List[Dict] = [] self._quick_deny_groups: set = set() self._quick_allow_users: set = set() - + def configure( self, enabled: bool, @@ -368,61 +373,61 @@ class PermissionChecker: """配置权限检查器""" self._enabled = enabled self._default_mode = default_mode if default_mode in ("allow_all", "deny_all") else "allow_all" - + # 解析快捷配置 self._quick_deny_groups = {g.strip() for g in quick_deny_groups.strip().split("\n") if g.strip()} self._quick_allow_users = {u.strip() for u in quick_allow_users.strip().split("\n") if u.strip()} - + try: self._rules = json.loads(rules_json) if rules_json.strip() else [] except json.JSONDecodeError as e: logger.warning(f"权限规则 JSON 解析失败: {e}") self._rules = [] - + def check(self, tool_name: str, chat_id: str, user_id: str, is_group: bool) -> bool: """检查权限 - + Args: tool_name: 工具名称 chat_id: 聊天 ID(群号或私聊 ID) user_id: 用户 ID is_group: 是否为群聊 - + Returns: True 表示允许,False 表示拒绝 """ if not self._enabled: return True - + # 快捷配置优先级最高 # 1. 管理员白名单(始终允许) if user_id and user_id in self._quick_allow_users: return True - + # 2. 禁用群列表(始终拒绝) if is_group and chat_id and chat_id in self._quick_deny_groups: return False - + # 查找匹配的规则 for rule in self._rules: tool_pattern = rule.get("tool", "") if not self._match_tool(tool_pattern, tool_name): continue - + # 找到匹配的规则 mode = rule.get("mode", "") allowed = rule.get("allowed", []) denied = rule.get("denied", []) - + # 构建当前上下文的 ID 列表 context_ids = self._build_context_ids(chat_id, user_id, is_group) - + # 检查 denied 列表(优先级最高) if denied: for ctx_id in context_ids: if self._match_id_list(denied, ctx_id): return False - + # 检查 allowed 列表 if allowed: for ctx_id in context_ids: @@ -431,41 +436,41 @@ class PermissionChecker: # 如果是 whitelist 模式且不在 allowed 中,拒绝 if mode == "whitelist": return False - + # 规则匹配但没有明确允许/拒绝,继续检查下一条规则 - + # 没有匹配的规则,使用默认模式 return self._default_mode == "allow_all" - + def _match_tool(self, pattern: str, tool_name: str) -> bool: """工具名通配符匹配""" if not pattern: return False return fnmatch.fnmatch(tool_name, pattern) - + def _build_context_ids(self, chat_id: str, user_id: str, is_group: bool) -> List[str]: """构建上下文 ID 列表""" ids = [] - + # 用户级别(任何场景生效) if user_id: ids.append(f"qq:{user_id}:user") - + # 场景级别 if is_group and chat_id: ids.append(f"qq:{chat_id}:group") elif chat_id: ids.append(f"qq:{chat_id}:private") - + return ids - + def _match_id_list(self, id_list: List[str], context_id: str) -> bool: """检查 ID 是否在列表中""" for rule_id in id_list: if fnmatch.fnmatch(context_id, rule_id): return True return False - + def get_rules_for_tool(self, tool_name: str) -> List[Dict]: """获取特定工具的权限规则""" return [r for r in self._rules if self._match_tool(r.get("tool", ""), tool_name)] @@ -479,6 +484,7 @@ permission_checker = PermissionChecker() # 工具类型转换 # ============================================================================ + def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: """将 JSON Schema 类型转换为 MaiBot 的 ToolParamType""" type_mapping = { @@ -492,41 +498,43 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: return type_mapping.get(json_type, ToolParamType.STRING) -def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: +def parse_mcp_parameters( + input_schema: Dict[str, Any], +) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: """解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式""" parameters = [] - + if not input_schema: # 为无参数的工具添加占位参数,避免某些模型报错 parameters.append(("_placeholder", ToolParamType.STRING, "占位参数,无需填写", False, None)) return parameters - + properties = input_schema.get("properties", {}) required = input_schema.get("required", []) - + # 如果没有任何参数,添加占位参数 if not properties: parameters.append(("_placeholder", ToolParamType.STRING, "占位参数,无需填写", False, None)) return parameters - + for param_name, param_info in properties.items(): json_type = param_info.get("type", "string") param_type = convert_json_type_to_tool_param_type(json_type) description = param_info.get("description", f"参数 {param_name}") - + if json_type == "array": description = f"{description} (JSON 数组格式)" elif json_type == "object": description = f"{description} (JSON 对象格式)" - + is_required = param_name in required enum_values = param_info.get("enum") - + if enum_values is not None: enum_values = [str(v) for v in enum_values] - + parameters.append((param_name, param_type, description, is_required, enum_values)) - + return parameters @@ -534,28 +542,29 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa # MCP 工具代理 # ============================================================================ + class MCPToolProxy(BaseTool): """MCP 工具代理基类""" - + name: str = "" description: str = "" parameters: List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]] = [] available_for_llm: bool = True - + _mcp_tool_key: str = "" _mcp_original_name: str = "" _mcp_server_name: str = "" - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: """执行 MCP 工具调用""" global _plugin_instance - + call_id = str(uuid.uuid4())[:8] start_time = time.time() - + # 移除 MaiBot 内部标记 args = {k: v for k, v in function_args.items() if k != "llm_called"} - + # 解析 JSON 字符串参数 parsed_args = {} for key, value in args.items(): @@ -569,24 +578,21 @@ class MCPToolProxy(BaseTool): parsed_args[key] = value else: parsed_args[key] = value - + # 获取上下文信息 chat_id, user_id, is_group, user_query = self._get_context_info() - + # v1.4.0: 权限检查 if not permission_checker.check(self.name, chat_id, user_id, is_group): logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}") - return { - "name": self.name, - "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用" - } - + return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"} + logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}") - + # v1.4.0: 检查缓存 cache_hit = False cached_result = tool_call_cache.get(self.name, parsed_args) - + if cached_result is not None: cache_hit = True content = cached_result @@ -597,13 +603,13 @@ class MCPToolProxy(BaseTool): else: # 调用 MCP result = await mcp_manager.call_tool(self._mcp_tool_key, parsed_args) - + if result.success: content = result.content raw_result = content success = True error = "" - + # 存入缓存 tool_call_cache.set(self.name, parsed_args, content) else: @@ -612,7 +618,7 @@ class MCPToolProxy(BaseTool): success = False error = result.error logger.warning(f"MCP 工具 {self.name} 调用失败: {result.error}") - + # v1.3.0: 后处理 post_processed = False processed_result = content @@ -622,9 +628,9 @@ class MCPToolProxy(BaseTool): post_processed = True processed_result = processed_content content = processed_content - + duration_ms = (time.time() - start_time) * 1000 - + # v1.4.0: 记录调用追踪 record = ToolCallRecord( call_id=call_id, @@ -644,16 +650,16 @@ class MCPToolProxy(BaseTool): cache_hit=cache_hit, ) tool_call_tracer.record(record) - + return {"name": self.name, "content": content} - + def _get_context_info(self) -> Tuple[str, str, bool, str]: """获取上下文信息""" chat_id = "" user_id = "" is_group = False user_query = "" - + if self.chat_stream and hasattr(self.chat_stream, "context") and self.chat_stream.context: try: ctx = self.chat_stream.context @@ -663,53 +669,53 @@ class MCPToolProxy(BaseTool): user_id = str(ctx.user_id) if ctx.user_id else "" if hasattr(ctx, "is_group"): is_group = bool(ctx.is_group) - + last_message = ctx.get_last_message() if last_message and hasattr(last_message, "processed_plain_text"): user_query = last_message.processed_plain_text or "" except Exception as e: logger.debug(f"获取上下文信息失败: {e}") - + return chat_id, user_id, is_group, user_query async def _post_process_result(self, content: str) -> str: """v1.3.0: 对工具返回结果进行后处理(摘要提炼)""" global _plugin_instance - + if _plugin_instance is None: return content - + settings = _plugin_instance.config.get("settings", {}) - + if not settings.get("post_process_enabled", False): return content - + server_post_config = self._get_server_post_process_config() - + if server_post_config is not None: if not server_post_config.get("enabled", True): return content - + threshold = settings.get("post_process_threshold", 500) if server_post_config and "threshold" in server_post_config: threshold = server_post_config["threshold"] - + content_length = len(content) if content else 0 if content_length <= threshold: return content - + user_query = self._get_context_info()[3] if not user_query: return content - + max_tokens = settings.get("post_process_max_tokens", 500) if server_post_config and "max_tokens" in server_post_config: max_tokens = server_post_config["max_tokens"] - + prompt_template = settings.get("post_process_prompt", "") if server_post_config and "prompt" in server_post_config: prompt_template = server_post_config["prompt"] - + if not prompt_template: prompt_template = """用户问题:{query} @@ -717,13 +723,13 @@ class MCPToolProxy(BaseTool): {result} 请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:""" - + try: prompt = prompt_template.format(query=user_query, result=content) except KeyError as e: logger.warning(f"后处理 prompt 模板格式错误: {e}") return content - + try: processed_content = await self._call_post_process_llm(prompt, max_tokens, settings, server_post_config) if processed_content: @@ -733,11 +739,11 @@ class MCPToolProxy(BaseTool): except Exception as e: logger.error(f"MCP 工具 {self.name} 后处理失败: {e}") return content - + def _get_server_post_process_config(self) -> Optional[Dict[str, Any]]: """获取当前服务器的后处理配置""" global _plugin_instance - + if _plugin_instance is None: return None @@ -745,25 +751,21 @@ class MCPToolProxy(BaseTool): for server_conf in servers: if server_conf.get("name") == self._mcp_server_name: return server_conf.get("post_process") - + return None - + async def _call_post_process_llm( - self, - prompt: str, - max_tokens: int, - settings: Dict[str, Any], - server_config: Optional[Dict[str, Any]] + self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]] ) -> Optional[str]: """调用 LLM 进行后处理""" from src.config.config import model_config from src.config.model_configs import TaskConfig from src.llm_models.utils_model import LLMRequest - + model_name = settings.get("post_process_model", "") if server_config and "model" in server_config: model_name = server_config["model"] - + if model_name: task_config = TaskConfig( model_list=[model_name], @@ -773,59 +775,56 @@ class MCPToolProxy(BaseTool): ) else: task_config = model_config.model_task_config.utils - + llm_request = LLMRequest(model_set=task_config, request_type="mcp_post_process") - + response, (reasoning, model_used, _) = await llm_request.generate_response_async( prompt=prompt, max_tokens=max_tokens, temperature=0.3, ) - + return response.strip() if response else None - + def _format_error_message(self, error: str, duration_ms: float) -> str: """格式化友好的错误消息""" if not error: return "工具调用失败(未知错误)" - + error_lower = error.lower() - + if "未连接" in error or "not connected" in error_lower: return f"⚠️ MCP 服务器 [{self._mcp_server_name}] 未连接,请检查服务器状态或等待自动重连" - + if "超时" in error or "timeout" in error_lower: return f"⏱️ 工具调用超时(耗时 {duration_ms:.0f}ms),服务器响应过慢,请稍后重试" - + if "connection" in error_lower and ("closed" in error_lower or "reset" in error_lower): return f"🔌 与 MCP 服务器 [{self._mcp_server_name}] 的连接已断开,正在尝试重连..." - + if "invalid" in error_lower and "argument" in error_lower: return f"❌ 参数错误: {error}" - + return f"❌ 工具调用失败: {error}" - + async def direct_execute(self, **function_args) -> Dict[str, Any]: """直接执行(供其他插件调用)""" return await self.execute(function_args) def create_mcp_tool_class( - tool_key: str, - tool_info: MCPToolInfo, - tool_prefix: str, - disabled: bool = False + tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False ) -> Type[MCPToolProxy]: """根据 MCP 工具信息动态创建 BaseTool 子类""" parameters = parse_mcp_parameters(tool_info.input_schema) - + class_name = f"MCPTool_{tool_info.server_name}_{tool_info.name}".replace("-", "_").replace(".", "_") tool_name = tool_key.replace("-", "_").replace(".", "_") - + description = tool_info.description if not description.endswith(f"[来自 MCP 服务器: {tool_info.server_name}]"): description = f"{description} [来自 MCP 服务器: {tool_info.server_name}]" - + tool_class = type( class_name, (MCPToolProxy,), @@ -837,31 +836,27 @@ def create_mcp_tool_class( "_mcp_tool_key": tool_key, "_mcp_original_name": tool_info.name, "_mcp_server_name": tool_info.server_name, - } + }, ) - + return tool_class class MCPToolRegistry: """MCP 工具注册表""" - + def __init__(self): self._tool_classes: Dict[str, Type[MCPToolProxy]] = {} self._tool_infos: Dict[str, ToolInfo] = {} - + def register_tool( - self, - tool_key: str, - tool_info: MCPToolInfo, - tool_prefix: str, - disabled: bool = False + self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False ) -> Tuple[ToolInfo, Type[MCPToolProxy]]: """注册 MCP 工具""" tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled) - + self._tool_classes[tool_key] = tool_class - + info = ToolInfo( name=tool_class.name, tool_description=tool_class.description, @@ -870,9 +865,9 @@ class MCPToolRegistry: component_type=ComponentType.TOOL, ) self._tool_infos[tool_key] = info - + return info, tool_class - + def unregister_tool(self, tool_key: str) -> bool: """注销工具""" if tool_key in self._tool_classes: @@ -880,11 +875,11 @@ class MCPToolRegistry: del self._tool_infos[tool_key] return True return False - + def get_all_components(self) -> List[Tuple[ComponentInfo, Type]]: """获取所有工具组件""" return [(self._tool_infos[key], self._tool_classes[key]) for key in self._tool_classes.keys()] - + def clear(self) -> None: """清空所有注册""" self._tool_classes.clear() @@ -902,9 +897,10 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None # 内置工具 # ============================================================================ + class MCPReadResourceTool(BaseTool): """v1.2.0: MCP 资源读取工具""" - + name = "mcp_read_resource" description = "读取 MCP 服务器提供的资源内容(如文件、数据库记录等)。使用前请先用 mcp_status 查看可用资源。" parameters = [ @@ -912,28 +908,28 @@ class MCPReadResourceTool(BaseTool): ("server_name", ToolParamType.STRING, "指定服务器名称(可选,不指定则自动查找)", False, None), ] available_for_llm = True - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: uri = function_args.get("uri", "") server_name = function_args.get("server_name") - + if not uri: return {"name": self.name, "content": "❌ 请提供资源 URI"} - + result = await mcp_manager.read_resource(uri, server_name) - + if result.success: return {"name": self.name, "content": result.content} else: return {"name": self.name, "content": f"❌ 读取资源失败: {result.error}"} - + async def direct_execute(self, **function_args) -> Dict[str, Any]: return await self.execute(function_args) class MCPGetPromptTool(BaseTool): """v1.2.0: MCP 提示模板工具""" - + name = "mcp_get_prompt" description = "获取 MCP 服务器提供的提示模板内容。使用前请先用 mcp_status 查看可用模板。" parameters = [ @@ -942,29 +938,29 @@ class MCPGetPromptTool(BaseTool): ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), ] available_for_llm = True - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: prompt_name = function_args.get("name", "") arguments_str = function_args.get("arguments", "") server_name = function_args.get("server_name") - + if not prompt_name: return {"name": self.name, "content": "❌ 请提供提示模板名称"} - + arguments = None if arguments_str: try: arguments = json.loads(arguments_str) except json.JSONDecodeError: return {"name": self.name, "content": "❌ 参数格式错误,请使用 JSON 对象格式"} - + result = await mcp_manager.get_prompt(prompt_name, arguments, server_name) - + if result.success: return {"name": self.name, "content": result.content} else: return {"name": self.name, "content": f"❌ 获取提示模板失败: {result.error}"} - + async def direct_execute(self, **function_args) -> Dict[str, Any]: return await self.execute(function_args) @@ -973,40 +969,41 @@ class MCPGetPromptTool(BaseTool): # v1.8.0: 工具链代理工具 # ============================================================================ + class ToolChainProxyBase(BaseTool): """工具链代理基类""" - + name: str = "" description: str = "" parameters: List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]] = [] available_for_llm: bool = True - + _chain_name: str = "" - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: """执行工具链""" # 移除内部标记 args = {k: v for k, v in function_args.items() if k != "llm_called"} - + logger.debug(f"执行工具链 {self._chain_name},参数: {args}") - + result = await tool_chain_manager.execute_chain(self._chain_name, args) - + if result.success: # 构建输出 output_parts = [] output_parts.append(result.final_output) - + # 可选:添加执行摘要 # output_parts.append(f"\n\n---\n执行摘要:\n{result.to_summary()}") - + return {"name": self.name, "content": "\n".join(output_parts)} else: error_msg = f"⚠️ 工具链执行失败: {result.error}" if result.step_results: error_msg += f"\n\n执行详情:\n{result.to_summary()}" return {"name": self.name, "content": error_msg} - + async def direct_execute(self, **function_args) -> Dict[str, Any]: return await self.execute(function_args) @@ -1017,17 +1014,17 @@ def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBa parameters = [] for param_name, param_desc in chain.input_params.items(): parameters.append((param_name, ToolParamType.STRING, param_desc, True, None)) - + # 生成类名和工具名 class_name = f"ToolChain_{chain.name}".replace("-", "_").replace(".", "_") tool_name = f"chain_{chain.name}".replace("-", "_").replace(".", "_") - + # 构建描述 description = chain.description if chain.steps: step_names = [s.tool_name.split("_")[-1] for s in chain.steps[:3]] description += f" (执行流程: {' → '.join(step_names)}{'...' if len(chain.steps) > 3 else ''})" - + tool_class = type( class_name, (ToolChainProxyBase,), @@ -1037,25 +1034,25 @@ def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBa "parameters": parameters, "available_for_llm": True, "_chain_name": chain.name, - } + }, ) - + return tool_class class ToolChainRegistry: """工具链注册表""" - + def __init__(self): self._tool_classes: Dict[str, Type[ToolChainProxyBase]] = {} self._tool_infos: Dict[str, ToolInfo] = {} - + def register_chain(self, chain: ToolChainDefinition) -> Tuple[ToolInfo, Type[ToolChainProxyBase]]: """注册工具链为组合工具""" tool_class = create_chain_tool_class(chain) - + self._tool_classes[chain.name] = tool_class - + info = ToolInfo( name=tool_class.name, tool_description=tool_class.description, @@ -1064,9 +1061,9 @@ class ToolChainRegistry: component_type=ComponentType.TOOL, ) self._tool_infos[chain.name] = info - + return info, tool_class - + def unregister_chain(self, chain_name: str) -> bool: """注销工具链""" if chain_name in self._tool_classes: @@ -1074,11 +1071,11 @@ class ToolChainRegistry: del self._tool_infos[chain_name] return True return False - + def get_all_components(self) -> List[Tuple[ComponentInfo, Type]]: """获取所有工具链组件""" return [(self._tool_infos[key], self._tool_classes[key]) for key in self._tool_classes.keys()] - + def clear(self) -> None: """清空所有注册""" self._tool_classes.clear() @@ -1091,52 +1088,55 @@ tool_chain_registry = ToolChainRegistry() class MCPStatusTool(BaseTool): """MCP 状态查询工具""" - + name = "mcp_status" description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息" parameters = [ - ("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"]), + ( + "query_type", + ToolParamType.STRING, + "查询类型", + False, + ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"], + ), ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), ] available_for_llm = True - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: query_type = function_args.get("query_type", "status") server_name = function_args.get("server_name") - + result_parts = [] - + if query_type in ("status", "all"): result_parts.append(self._format_status(server_name)) - + if query_type in ("tools", "all"): result_parts.append(self._format_tools(server_name)) - + if query_type in ("chains", "all"): result_parts.append(self._format_chains()) - + if query_type in ("resources", "all"): result_parts.append(self._format_resources(server_name)) - + if query_type in ("prompts", "all"): result_parts.append(self._format_prompts(server_name)) - + if query_type in ("stats", "all"): result_parts.append(self._format_stats(server_name)) - + # v1.4.0: 追踪记录 if query_type in ("trace",): result_parts.append(self._format_trace()) - + # v1.4.0: 缓存状态 if query_type in ("cache",): result_parts.append(self._format_cache()) - - return { - "name": self.name, - "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型" - } - + + return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"} + def _format_status(self, server_name: Optional[str] = None) -> str: status = mcp_manager.get_status() lines = ["📊 MCP 桥接插件状态"] @@ -1145,24 +1145,24 @@ class MCPStatusTool(BaseTool): lines.append(f" 已断开: {status['disconnected_servers']}") lines.append(f" 可用工具数: {status['total_tools']}") lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}") - + lines.append("\n🔌 服务器详情:") - for name, info in status['servers'].items(): + for name, info in status["servers"].items(): if server_name and name != server_name: continue - status_icon = "✅" if info['connected'] else "❌" - enabled_text = "" if info['enabled'] else " (已禁用)" + status_icon = "✅" if info["connected"] else "❌" + enabled_text = "" if info["enabled"] else " (已禁用)" lines.append(f" {status_icon} {name}{enabled_text}") lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}") - if info['consecutive_failures'] > 0: + if info["consecutive_failures"] > 0: lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次") - + return "\n".join(lines) - + def _format_tools(self, server_name: Optional[str] = None) -> str: tools = mcp_manager.all_tools lines = ["🔧 可用 MCP 工具"] - + by_server: Dict[str, List[str]] = {} for tool_key, (tool_info, _) in tools.items(): if server_name and tool_info.server_name != server_name: @@ -1170,78 +1170,78 @@ class MCPStatusTool(BaseTool): if tool_info.server_name not in by_server: by_server[tool_info.server_name] = [] by_server[tool_info.server_name].append(f" • {tool_key}: {tool_info.description[:50]}...") - + for srv_name, tool_list in by_server.items(): lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个工具):") lines.extend(tool_list) - + if not by_server: lines.append(" (无可用工具)") - + return "\n".join(lines) - + def _format_stats(self, server_name: Optional[str] = None) -> str: stats = mcp_manager.get_all_stats() lines = ["📈 调用统计"] - - g = stats['global'] + + g = stats["global"] lines.append(f" 总调用次数: {g['total_tool_calls']}") lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}") - if g['total_tool_calls'] > 0: - success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100 + if g["total_tool_calls"] > 0: + success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100 lines.append(f" 成功率: {success_rate:.1f}%") lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒") - + return "\n".join(lines) - + def _format_resources(self, server_name: Optional[str] = None) -> str: resources = mcp_manager.all_resources if not resources: return "📦 当前没有可用的 MCP 资源" - + lines = ["📦 可用 MCP 资源"] by_server: Dict[str, List[MCPResourceInfo]] = {} - for key, (resource_info, _) in resources.items(): + for _key, (resource_info, _) in resources.items(): if server_name and resource_info.server_name != server_name: continue if resource_info.server_name not in by_server: by_server[resource_info.server_name] = [] by_server[resource_info.server_name].append(resource_info) - + for srv_name, resource_list in by_server.items(): lines.append(f"\n🔌 {srv_name} ({len(resource_list)} 个资源):") for res in resource_list: lines.append(f" • {res.name}: {res.uri}") - + return "\n".join(lines) - + def _format_prompts(self, server_name: Optional[str] = None) -> str: prompts = mcp_manager.all_prompts if not prompts: return "📝 当前没有可用的 MCP 提示模板" - + lines = ["📝 可用 MCP 提示模板"] by_server: Dict[str, List[MCPPromptInfo]] = {} - for key, (prompt_info, _) in prompts.items(): + for _key, (prompt_info, _) in prompts.items(): if server_name and prompt_info.server_name != server_name: continue if prompt_info.server_name not in by_server: by_server[prompt_info.server_name] = [] by_server[prompt_info.server_name].append(prompt_info) - + for srv_name, prompt_list in by_server.items(): lines.append(f"\n🔌 {srv_name} ({len(prompt_list)} 个模板):") for prompt in prompt_list: lines.append(f" • {prompt.name}") - + return "\n".join(lines) - + def _format_trace(self) -> str: """v1.4.0: 格式化追踪记录""" records = tool_call_tracer.get_recent(10) if not records: return "🔍 暂无调用追踪记录" - + lines = ["🔍 最近调用追踪记录"] for r in reversed(records): status = "✅" if r.success else "❌" @@ -1250,9 +1250,9 @@ class MCPStatusTool(BaseTool): lines.append(f" {status}{cache}{post} {r.tool_name} ({r.duration_ms:.0f}ms)") if r.error: lines.append(f" 错误: {r.error[:50]}") - + return "\n".join(lines) - + def _format_cache(self) -> str: """v1.4.0: 格式化缓存状态""" stats = tool_call_cache.get_stats() @@ -1263,13 +1263,13 @@ class MCPStatusTool(BaseTool): lines.append(f" 命中: {stats['hits']}, 未命中: {stats['misses']}") lines.append(f" 命中率: {stats['hit_rate']}") return "\n".join(lines) - + def _format_chains(self) -> str: """v1.8.0: 格式化工具链列表""" chains = tool_chain_manager.get_all_chains() if not chains: return "🔗 当前没有配置工具链" - + lines = ["🔗 工具链列表"] for name, chain in chains.items(): status = "✅" if chain.enabled else "❌" @@ -1277,15 +1277,15 @@ class MCPStatusTool(BaseTool): lines.append(f" 描述: {chain.description[:50]}...") lines.append(f" 步骤: {len(chain.steps)} 个") for i, step in enumerate(chain.steps[:3]): - lines.append(f" {i+1}. {step.tool_name}") + lines.append(f" {i + 1}. {step.tool_name}") if len(chain.steps) > 3: lines.append(f" ... 还有 {len(chain.steps) - 3} 个步骤") if chain.input_params: params = ", ".join(chain.input_params.keys()) lines.append(f" 参数: {params}") - + return "\n".join(lines) - + async def direct_execute(self, **function_args) -> Dict[str, Any]: return await self.execute(function_args) @@ -1294,6 +1294,7 @@ class MCPStatusTool(BaseTool): # 命令处理 # ============================================================================ + class MCPStatusCommand(BaseCommand): """MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态""" @@ -1308,27 +1309,27 @@ class MCPStatusCommand(BaseCommand): if subcommand == "reconnect": return await self._handle_reconnect(arg) - + # v1.4.0: 追踪命令 if subcommand == "trace": return await self._handle_trace(arg) - + # v1.4.0: 缓存命令 if subcommand == "cache": return await self._handle_cache(arg) - + # v1.4.0: 权限命令 if subcommand == "perm": return await self._handle_perm(arg) - + # v1.6.0: 导出命令 if subcommand == "export": return await self._handle_export(arg) - + # v1.7.0: 工具搜索命令 if subcommand == "search": return await self._handle_search(arg) - + # v1.8.0: 工具链命令 if subcommand == "chain": return await self._handle_chain(arg) @@ -1341,7 +1342,7 @@ class MCPStatusCommand(BaseCommand): """查找相似的服务器名称""" name_lower = name.lower() all_servers = list(mcp_manager._clients.keys()) - + # 简单的相似度匹配:包含关系或前缀匹配 similar = [] for srv in all_servers: @@ -1350,7 +1351,7 @@ class MCPStatusCommand(BaseCommand): similar.append(srv) elif srv_lower.startswith(name_lower[:3]) if len(name_lower) >= 3 else False: similar.append(srv) - + return similar[:max_results] async def _handle_reconnect(self, server_name: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: @@ -1384,7 +1385,7 @@ class MCPStatusCommand(BaseCommand): await self.send_text(f"{status} {srv}") return (True, None, True) - + async def _handle_trace(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.4.0: 处理追踪命令""" if arg and arg.isdigit(): @@ -1397,11 +1398,11 @@ class MCPStatusCommand(BaseCommand): else: # /mcp trace - 最近 10 条 records = tool_call_tracer.get_recent(10) - + if not records: await self.send_text("🔍 暂无调用追踪记录\n\n用法: /mcp trace [数量|工具名]") return (True, None, True) - + lines = [f"🔍 调用追踪记录 ({len(records)} 条)"] lines.append("-" * 30) for i, r in enumerate(reversed(records)): @@ -1415,17 +1416,17 @@ class MCPStatusCommand(BaseCommand): lines.append(f" 错误: {r.error[:50]}") if i < len(records) - 1: lines.append("") - + await self.send_text("\n".join(lines)) return (True, None, True) - + async def _handle_cache(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.4.0: 处理缓存命令""" if arg == "clear": tool_call_cache.clear() await self.send_text("✅ 缓存已清空") return (True, None, True) - + stats = tool_call_cache.get_stats() lines = ["🗄️ 缓存状态"] lines.append(f"├ 启用: {'是' if stats['enabled'] else '否'}") @@ -1434,22 +1435,22 @@ class MCPStatusCommand(BaseCommand): lines.append(f"├ 命中: {stats['hits']}") lines.append(f"├ 未命中: {stats['misses']}") lines.append(f"└ 命中率: {stats['hit_rate']}") - + await self.send_text("\n".join(lines)) return (True, None, True) - + async def _handle_perm(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.4.0: 处理权限命令""" global _plugin_instance - + if _plugin_instance is None: await self.send_text("❌ 插件未初始化") return (True, None, True) - + perm_config = _plugin_instance.config.get("permissions", {}) enabled = perm_config.get("perm_enabled", False) default_mode = perm_config.get("perm_default_mode", "allow_all") - + if arg: # 查看特定工具的权限 rules = permission_checker.get_rules_for_tool(arg) @@ -1478,17 +1479,17 @@ class MCPStatusCommand(BaseCommand): lines.append(f"├ 管理员白名单: {allow_count} 人") lines.append(f"└ 高级规则: {len(permission_checker._rules)} 条") await self.send_text("\n".join(lines)) - + return (True, None, True) - + async def _handle_export(self, format_type: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.6.0: 处理导出命令""" global _plugin_instance - + if _plugin_instance is None: await self.send_text("❌ 插件未初始化") return (True, None, True) - + servers_section = _plugin_instance.config.get("servers", {}) if not isinstance(servers_section, dict): servers_section = {} @@ -1513,7 +1514,7 @@ class MCPStatusCommand(BaseCommand): lines.append("") lines.append(pretty) await self.send_text("\n".join(lines)) - + return (True, None, True) async def _handle_search(self, query: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: @@ -1558,7 +1559,7 @@ class MCPStatusCommand(BaseCommand): # 按服务器分组显示 by_server: Dict[str, List[Tuple[str, Any]]] = {} - for tool_key, tool_info, client in matched: + for tool_key, tool_info, _client in matched: server_name = tool_info.server_name if server_name not in by_server: by_server[server_name] = [] @@ -1570,11 +1571,11 @@ class MCPStatusCommand(BaseCommand): for srv_name, tool_list in by_server.items(): lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个):") - + # 单服务器或结果少于 15 个时显示全部 show_all = single_server or len(matched) <= 15 display_limit = len(tool_list) if show_all else 5 - + for tool_key, tool_info in tool_list[:display_limit]: desc = tool_info.description[:40] + "..." if len(tool_info.description) > 40 else tool_info.description lines.append(f" • {tool_key}") @@ -1590,10 +1591,10 @@ class MCPStatusCommand(BaseCommand): if not arg or not arg.strip(): # 显示工具链列表和帮助 chains = tool_chain_manager.get_all_chains() - + lines = ["🔗 工具链管理"] lines.append("") - + if chains: lines.append(f"已配置 {len(chains)} 个工具链:") for name, chain in chains.items(): @@ -1602,7 +1603,7 @@ class MCPStatusCommand(BaseCommand): lines.append(f" {status} {name} ({steps_count} 步)") else: lines.append("当前没有配置工具链") - + lines.append("") lines.append("命令:") lines.append(" /mcp chain list 查看所有工具链") @@ -1611,20 +1612,20 @@ class MCPStatusCommand(BaseCommand): lines.append(" /mcp chain reload 重新加载配置") lines.append("") lines.append("💡 在 WebUI「工具链」配置区编辑工具链") - + await self.send_text("\n".join(lines)) return (True, None, True) - + parts = arg.strip().split(maxsplit=2) sub_action = parts[0].lower() - + if sub_action == "list": # 列出所有工具链 chains = tool_chain_manager.get_all_chains() if not chains: await self.send_text("🔗 当前没有配置工具链") return (True, None, True) - + lines = [f"🔗 工具链列表 ({len(chains)} 个)"] for name, chain in chains.items(): status = "✅" if chain.enabled else "❌" @@ -1633,10 +1634,10 @@ class MCPStatusCommand(BaseCommand): lines.append(f" 步骤: {' → '.join([s.tool_name.split('_')[-1] for s in chain.steps[:4]])}") if chain.input_params: lines.append(f" 参数: {', '.join(chain.input_params.keys())}") - + await self.send_text("\n".join(lines)) return (True, None, True) - + elif sub_action == "reload": # 重新加载工具链配置 global _plugin_instance @@ -1644,8 +1645,9 @@ class MCPStatusCommand(BaseCommand): _plugin_instance._load_tool_chains() chains = tool_chain_manager.get_all_chains() from src.plugin_system.core.component_registry import component_registry + registered = 0 - for name, chain in tool_chain_manager.get_enabled_chains().items(): + for name, _chain in tool_chain_manager.get_enabled_chains().items(): tool_name = f"chain_{name}".replace("-", "_").replace(".", "_") if component_registry.get_component_info(tool_name, ComponentType.TOOL): registered += 1 @@ -1662,27 +1664,27 @@ class MCPStatusCommand(BaseCommand): else: await self.send_text("❌ 插件未初始化") return (True, None, True) - + elif sub_action == "test" and len(parts) >= 2: # 测试执行工具链 chain_name = parts[1] args_json = parts[2] if len(parts) > 2 else "{}" - + chain = tool_chain_manager.get_chain(chain_name) if not chain: await self.send_text(f"❌ 工具链 '{chain_name}' 不存在") return (True, None, True) - + try: input_args = json.loads(args_json) except json.JSONDecodeError: await self.send_text("❌ 参数 JSON 格式错误") return (True, None, True) - + await self.send_text(f"🔄 正在执行工具链 {chain_name}...") - + result = await tool_chain_manager.execute_chain(chain_name, input_args) - + lines = [] if result.success: lines.append(f"✅ 工具链执行成功 ({result.total_duration_ms:.0f}ms)") @@ -1702,15 +1704,15 @@ class MCPStatusCommand(BaseCommand): lines.append("") lines.append("执行详情:") lines.append(result.to_summary()) - + await self.send_text("\n".join(lines)) return (True, None, True) - + else: # 查看特定工具链详情 chain_name = sub_action chain = tool_chain_manager.get_chain(chain_name) - + if not chain: # 尝试模糊匹配 all_chains = tool_chain_manager.get_all_chains() @@ -1720,22 +1722,22 @@ class MCPStatusCommand(BaseCommand): msg += f"\n💡 你是不是想找: {', '.join(similar[:3])}" await self.send_text(msg) return (True, None, True) - + lines = [f"🔗 工具链: {chain.name}"] lines.append(f"状态: {'✅ 启用' if chain.enabled else '❌ 禁用'}") lines.append(f"描述: {chain.description}") lines.append("") - + if chain.input_params: lines.append("📥 输入参数:") for param, desc in chain.input_params.items(): lines.append(f" • {param}: {desc}") lines.append("") - + lines.append(f"📋 执行步骤 ({len(chain.steps)} 个):") for i, step in enumerate(chain.steps): optional_tag = " (可选)" if step.optional else "" - lines.append(f" {i+1}. {step.tool_name}{optional_tag}") + lines.append(f" {i + 1}. {step.tool_name}{optional_tag}") if step.description: lines.append(f" {step.description}") if step.output_key: @@ -1743,10 +1745,10 @@ class MCPStatusCommand(BaseCommand): if step.args_template: args_preview = json.dumps(step.args_template, ensure_ascii=False)[:60] lines.append(f" 参数: {args_preview}...") - + lines.append("") lines.append(f"💡 测试: /mcp chain test {chain.name} " + '{"参数": "值"}') - + await self.send_text("\n".join(lines)) return (True, None, True) @@ -1786,14 +1788,14 @@ class MCPStatusCommand(BaseCommand): if tools: lines.append("\n🔧 可用工具:") by_server = {} - for key, (info, _) in tools.items(): + for _key, (info, _) in tools.items(): if server_name and info.server_name != server_name: continue by_server.setdefault(info.server_name, []).append(info.name) # 如果指定了服务器名,显示全部工具;否则折叠显示 show_all = server_name is not None - + for srv, tool_list in by_server.items(): lines.append(f" 📦 {srv} ({len(tool_list)})") if show_all: @@ -1861,13 +1863,13 @@ class MCPImportCommand(BaseCommand): async def execute(self) -> Tuple[bool, Optional[str], bool]: """执行导入命令""" global _plugin_instance - + if _plugin_instance is None: await self.send_text("❌ 插件未初始化") return (True, None, True) - + content = self.matched_groups.get("content", "") - + if not content or not content.strip(): # 显示使用帮助 help_text = """📥 MCP 配置导入 @@ -1983,52 +1985,53 @@ class MCPImportCommand(BaseCommand): # 事件处理器 # ============================================================================ + class MCPStartupHandler(BaseEventHandler): """MCP 启动事件处理器""" - + event_type = EventType.ON_START handler_name = "mcp_startup_handler" handler_description = "MCP 桥接插件启动处理器" weight = 0 intercept_message = False - + async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: """处理启动事件""" global _plugin_instance - + if _plugin_instance is None: logger.warning("MCP 桥接插件实例未初始化") return (False, True, None, None, None) - + logger.info("MCP 桥接插件收到 ON_START 事件,开始连接 MCP 服务器...") await _plugin_instance._async_connect_servers() - + await mcp_manager.start_heartbeat() - + return (True, True, None, None, None) class MCPStopHandler(BaseEventHandler): """MCP 停止事件处理器""" - + event_type = EventType.ON_STOP handler_name = "mcp_stop_handler" handler_description = "MCP 桥接插件停止处理器" weight = 0 intercept_message = False - + async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: """处理停止事件""" global _plugin_instance - + logger.info("MCP 桥接插件收到 ON_STOP 事件,正在关闭...") if _plugin_instance is not None: await _plugin_instance._stop_status_refresher() - + await mcp_manager.shutdown() mcp_tool_registry.clear() - + logger.info("MCP 桥接插件已关闭所有连接") return (True, True, None, None, None) @@ -2037,16 +2040,17 @@ class MCPStopHandler(BaseEventHandler): # 主插件类 # ============================================================================ + @register_plugin class MCPBridgePlugin(BasePlugin): """MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot""" - + plugin_name: str = "mcp_bridge_plugin" enable_plugin: bool = False # 默认禁用,用户需在 WebUI 手动启用 dependencies: List[str] = [] python_dependencies: List[str] = ["mcp"] config_file_name: str = "config.toml" - + config_section_descriptions = { "guide": section_meta("📖 快速入门", order=1), "plugin": section_meta("🔘 插件开关", order=2), @@ -2058,7 +2062,7 @@ class MCPBridgePlugin(BasePlugin): "permissions": section_meta("🔐 权限控制", collapsed=True, order=21), "settings": section_meta("⚙️ 高级设置", collapsed=True, order=30), } - + config_schema: dict = { # 新手引导区(只读) "guide": { @@ -2505,7 +2509,7 @@ class MCPBridgePlugin(BasePlugin): label="📋 工具链列表", input_type="textarea", rows=20, - placeholder='''[ + placeholder="""[ { "name": "search_and_detail", "description": "先搜索再获取详情", @@ -2515,7 +2519,7 @@ class MCPBridgePlugin(BasePlugin): {"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}} ] } -]''', +]""", hint="每个工具链包含 name、description、input_params、steps", order=30, ), @@ -2653,9 +2657,9 @@ mcp_bing_*""", label="📜 高级权限规则(可选)", input_type="textarea", rows=10, - placeholder='''[ + placeholder="""[ {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} -]''', +]""", hint="格式: qq:ID:group/private/user,工具名支持通配符 *", order=10, ), @@ -2710,43 +2714,43 @@ mcp_bing_*""", ), }, } - + @staticmethod def _fix_config_multiline_strings(config_path: Path) -> bool: """修复配置文件中的多行字符串格式问题 - + 处理两种情况: 1. 带转义 \\n 的单行字符串(json.dumps 生成) 2. 跨越多行但使用普通双引号的字符串(控制字符错误) - + Returns: bool: 是否进行了修复 """ if not config_path.exists(): return False - + try: content = config_path.read_text(encoding="utf-8") - + # 情况1: 修复带转义 \n 的单行字符串 # 匹配: key = "内容包含\n的字符串" pattern1 = r'^(\s*\w+\s*=\s*)"((?:[^"\\]|\\.)*\\n(?:[^"\\]|\\.)*)"(\s*)$' - + # 情况2: 修复跨越多行的普通双引号字符串 # 匹配: key = "第一行 # 第二行 # 第三行" pattern2_start = r'^(\s*\w+\s*=\s*)"([^"]*?)$' # 开始行 pattern2_end = r'^([^"]*)"(\s*)$' # 结束行 - + lines = content.split("\n") fixed_lines = [] modified = False - + i = 0 while i < len(lines): line = lines[i] - + # 情况1: 单行带转义换行符 match1 = re.match(pattern1, line) if match1: @@ -2754,24 +2758,26 @@ mcp_bing_*""", value = match1.group(2) suffix = match1.group(3) # 将转义的换行符还原为实际换行符 - unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") + unescaped = ( + value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") + ) fixed_line = f'{prefix}"""{unescaped}"""{suffix}' fixed_lines.append(fixed_line) modified = True i += 1 continue - + # 情况2: 跨越多行的字符串 match2_start = re.match(pattern2_start, line) if match2_start: prefix = match2_start.group(1) first_part = match2_start.group(2) - + # 收集后续行直到找到结束引号 multiline_parts = [first_part] j = i + 1 found_end = False - + while j < len(lines): next_line = lines[j] match2_end = re.match(pattern2_end, next_line) @@ -2784,7 +2790,7 @@ mcp_bing_*""", else: multiline_parts.append(next_line) j += 1 - + if found_end and len(multiline_parts) > 1: # 合并为三引号字符串 full_value = "\n".join(multiline_parts) @@ -2793,27 +2799,27 @@ mcp_bing_*""", modified = True i = j continue - + fixed_lines.append(line) i += 1 - + if modified: config_path.write_text("\n".join(fixed_lines), encoding="utf-8") logger.info("已自动修复配置文件中的多行字符串格式") return True - + return False except Exception as e: logger.warning(f"修复配置文件格式失败: {e}") return False - + def __init__(self, *args, **kwargs): global _plugin_instance - + # 在父类初始化前尝试修复配置文件格式 config_path = Path(__file__).parent / "config.toml" self._fix_config_multiline_strings(config_path) - + super().__init__(*args, **kwargs) self._initialized = False self._status_refresh_running = False @@ -2821,11 +2827,11 @@ mcp_bing_*""", self._last_persisted_display_hash: str = "" self._last_servers_config_error: str = "" _plugin_instance = self - + # 配置 MCP 管理器 settings = self.config.get("settings", {}) mcp_manager.configure(settings) - + # v1.4.0: 配置追踪器 trace_log_path = Path(__file__).parent / "logs" / "trace.jsonl" tool_call_tracer.configure( @@ -2834,7 +2840,7 @@ mcp_bing_*""", log_enabled=settings.get("trace_log_enabled", False), log_path=trace_log_path, ) - + # v1.4.0: 配置缓存 tool_call_cache.configure( enabled=settings.get("cache_enabled", False), @@ -2842,7 +2848,7 @@ mcp_bing_*""", max_entries=settings.get("cache_max_entries", 200), exclude_tools=settings.get("cache_exclude_tools", ""), ) - + # v1.4.0: 配置权限检查器 perm_config = self.config.get("permissions", {}) permission_checker.configure( @@ -2852,12 +2858,12 @@ mcp_bing_*""", quick_deny_groups=perm_config.get("quick_deny_groups", ""), quick_allow_users=perm_config.get("quick_allow_users", ""), ) - + # 注册状态变化回调 mcp_manager.set_status_change_callback(self._update_status_display) - + # v2.0: 服务器配置统一由 servers.claude_config_json 提供(不再通过 WebUI 导入/快速添加写入旧 servers.list) - + # v1.8.0: 初始化工具链管理器 tool_chain_manager.set_executor(mcp_manager) self._load_tool_chains() @@ -2881,38 +2887,38 @@ mcp_bing_*""", self._last_persisted_display_hash = digest except Exception as e: logger.debug(f"写回运行状态到配置文件失败: {e}") - + def _process_quick_add_chain(self) -> None: """v1.8.0: 处理快速添加工具链表单""" chains_config = self.config.get("tool_chains", {}) - + # 检查是否触发添加 add_trigger = chains_config.get("quick_chain_add", "").strip().upper() if add_trigger != "ADD": return - + # 获取表单数据 chain_name = chains_config.get("quick_chain_name", "").strip() chain_desc = chains_config.get("quick_chain_desc", "").strip() params_str = chains_config.get("quick_chain_params", "").strip() steps_str = chains_config.get("quick_chain_steps", "").strip() - + # 验证必填字段 if not chain_name: logger.warning("快速添加工具链: 名称不能为空") self._clear_quick_chain_fields() return - + if not chain_desc: logger.warning("快速添加工具链: 描述不能为空") self._clear_quick_chain_fields() return - + if not steps_str: logger.warning("快速添加工具链: 步骤不能为空") self._clear_quick_chain_fields() return - + # 解析输入参数 input_params = {} if params_str: @@ -2924,41 +2930,43 @@ mcp_bing_*""", param_name = parts[0].strip() param_desc = parts[1].strip() if len(parts) > 1 else param_name input_params[param_name] = param_desc - + # 解析步骤 steps = [] for line in steps_str.split("\n"): line = line.strip() if not line: continue - + parts = line.split("|") if len(parts) < 2: logger.warning(f"快速添加工具链: 步骤格式错误: {line}") continue - + tool_name = parts[0].strip() args_str = parts[1].strip() if len(parts) > 1 else "{}" output_key = parts[2].strip() if len(parts) > 2 else "" - + # 解析参数 JSON try: args_template = json.loads(args_str) if args_str else {} except json.JSONDecodeError: logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}") args_template = {} - - steps.append({ - "tool_name": tool_name, - "args_template": args_template, - "output_key": output_key, - }) - + + steps.append( + { + "tool_name": tool_name, + "args_template": args_template, + "output_key": output_key, + } + ) + if not steps: logger.warning("快速添加工具链: 没有有效的步骤") self._clear_quick_chain_fields() return - + # 构建新工具链 new_chain = { "name": chain_name, @@ -2967,36 +2975,36 @@ mcp_bing_*""", "steps": steps, "enabled": True, } - + # 获取现有工具链列表 chains_json = chains_config.get("chains_list", "[]") try: chains_list = json.loads(chains_json) if chains_json.strip() else [] except json.JSONDecodeError: chains_list = [] - + # 检查是否已存在同名工具链 for existing in chains_list: if existing.get("name") == chain_name: logger.info(f"快速添加: 工具链 {chain_name} 已存在,将更新") chains_list.remove(existing) break - + # 添加新工具链 chains_list.append(new_chain) new_chains_json = json.dumps(chains_list, ensure_ascii=False, indent=2) - + # 更新配置 self.config["tool_chains"]["chains_list"] = new_chains_json - + # 清空表单字段 self._clear_quick_chain_fields() - + # 保存到配置文件 self._save_chains_list(new_chains_json) - + logger.info(f"快速添加: 已添加工具链 {chain_name} ({len(steps)} 个步骤)") - + def _clear_quick_chain_fields(self) -> None: """清空快速添加工具链表单字段""" if "tool_chains" not in self.config: @@ -3006,7 +3014,7 @@ mcp_bing_*""", self.config["tool_chains"]["quick_chain_params"] = "" self.config["tool_chains"]["quick_chain_steps"] = "" self.config["tool_chains"]["quick_chain_add"] = "" - + def _save_chains_list(self, chains_json: str) -> None: """保存工具链列表到配置文件""" try: @@ -3015,12 +3023,12 @@ mcp_bing_*""", logger.info("工具链列表已保存到配置文件") except Exception as e: logger.warning(f"保存工具链列表失败: {e}") - + def _load_tool_chains(self) -> None: """v1.8.0: 加载工具链配置""" # 先处理快速添加 self._process_quick_add_chain() - + chains_config = self.config.get("tool_chains", {}) if not isinstance(chains_config, dict): chains_config = {} @@ -3056,7 +3064,9 @@ mcp_bing_*""", if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict): self.config["tool_chains"] = {} self.config["tool_chains"]["chains_list"] = chains_json - logger.info("检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)") + logger.info( + "检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)" + ) chains_config = self.config.get("tool_chains", {}) if not isinstance(chains_config, dict): @@ -3065,32 +3075,32 @@ mcp_bing_*""", if not chains_config.get("chains_enabled", True): logger.info("工具链功能已禁用") return - + chains_json = str(chains_config.get("chains_list", "[]") or "") if not chains_json or not chains_json.strip(): return - + # 清空现有工具链 tool_chain_manager.clear() tool_chain_registry.clear() - + # 加载新配置 loaded, errors = tool_chain_manager.load_from_json(chains_json) - + if errors: for err in errors: logger.warning(f"工具链配置错误: {err}") - + if loaded > 0: logger.info(f"已加载 {loaded} 个工具链") # 注册工具链到组件系统 self._register_tool_chains() self._update_chains_status_display() - + def _register_tool_chains(self) -> None: """v1.8.1: 将工具链注册到 MaiBot 组件系统,使 LLM 可调用""" from src.plugin_system.core.component_registry import component_registry - + chain_count = 0 for chain_name, chain in tool_chain_manager.get_enabled_chains().items(): try: @@ -3110,16 +3120,16 @@ mcp_bing_*""", logger.warning(f"⚠️ 工具链注册被跳过(可能已存在): {tool_class.name}") except Exception as e: logger.error(f"注册工具链 {chain_name} 失败: {e}") - + if chain_count > 0: logger.info(f"已注册 {chain_count} 个工具链到组件系统") - + def _register_tools_to_react(self) -> int: """v1.9.0: 将 MCP 工具注册到记忆检索 ReAct 系统(软流程) - + 这样 MaiBot 的 ReAct Agent 在检索记忆时可以调用 MCP 工具, 实现 LLM 自主决策的多轮工具调用。 - + Returns: int: 成功注册的工具数量 """ @@ -3128,36 +3138,33 @@ mcp_bing_*""", except ImportError: logger.warning("无法导入记忆检索工具注册模块,跳过 ReAct 工具注册") return 0 - + react_config = self.config.get("react", {}) filter_mode = react_config.get("filter_mode", "whitelist") tool_filter = react_config.get("tool_filter", "").strip() - + # 解析过滤列表(支持 # 注释) filter_patterns = [] for line in tool_filter.split("\n"): line = line.strip() if line and not line.startswith("#"): filter_patterns.append(line) - + registered_count = 0 disabled_tools = self._get_disabled_tools() registered_tools = [] # 记录已注册的工具名 - + for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): tool_name = tool_key.replace("-", "_").replace(".", "_") - + # 跳过禁用的工具 if tool_name in disabled_tools: continue - + # 应用过滤器 if filter_patterns: - matched = any( - fnmatch.fnmatch(tool_name, p) or tool_name == p - for p in filter_patterns - ) - + matched = any(fnmatch.fnmatch(tool_name, p) or tool_name == p for p in filter_patterns) + if filter_mode == "whitelist": # 白名单模式:只注册匹配的 if not matched: @@ -3166,23 +3173,24 @@ mcp_bing_*""", # 黑名单模式:排除匹配的 if matched: continue - + try: # 转换参数格式 parameters = self._convert_mcp_params_to_react_format(tool_info.input_schema) - + # 创建异步执行函数(使用闭包捕获 tool_key) def make_execute_func(tk: str): - async def execute_func(**kwargs) -> str: + async def _execute_func(**kwargs) -> str: result = await mcp_manager.call_tool(tk, kwargs) if result.success: return result.content or "(无返回内容)" else: return f"工具调用失败: {result.error}" - return execute_func - + + return _execute_func + execute_func = make_execute_func(tool_key) - + # 注册到 ReAct 系统 register_memory_retrieval_tool( name=f"mcp_{tool_name}", @@ -3190,24 +3198,26 @@ mcp_bing_*""", parameters=parameters, execute_func=execute_func, ) - + registered_count += 1 registered_tools.append(f"mcp_{tool_name}") logger.debug(f"🔄 注册 ReAct 工具: mcp_{tool_name}") - + except Exception as e: logger.warning(f"注册 ReAct 工具 {tool_name} 失败: {e}") - + if registered_count > 0: mode_str = "白名单" if filter_mode == "whitelist" else "黑名单" logger.info(f"已注册 {registered_count} 个 MCP 工具到 ReAct 系统 (过滤模式: {mode_str})") - + # 更新状态显示 self._update_react_status_display(registered_tools, filter_mode, filter_patterns) - + return registered_count - - def _update_react_status_display(self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]) -> None: + + def _update_react_status_display( + self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str] + ) -> None: """更新 ReAct 工具状态显示""" if not registered_tools: status_text = "(未注册任何工具)" @@ -3222,40 +3232,42 @@ mcp_bing_*""", if len(registered_tools) > 20: lines.append(f" ... 还有 {len(registered_tools) - 20} 个") status_text = "\n".join(lines) - + # 更新内存配置 if "react" not in self.config: self.config["react"] = {} self.config["react"]["react_status"] = status_text - + def _convert_mcp_params_to_react_format(self, input_schema: Dict) -> List[Dict[str, Any]]: """将 MCP 工具参数转换为 ReAct 工具参数格式""" parameters = [] - + if not input_schema: return parameters - + properties = input_schema.get("properties", {}) required = input_schema.get("required", []) - + for param_name, param_info in properties.items(): param_type = param_info.get("type", "string") description = param_info.get("description", f"参数 {param_name}") is_required = param_name in required - - parameters.append({ - "name": param_name, - "type": param_type, - "description": description, - "required": is_required, - }) - + + parameters.append( + { + "name": param_name, + "type": param_type, + "description": description, + "required": is_required, + } + ) + return parameters - + def _update_chains_status_display(self) -> None: """v1.8.0: 更新工具链状态显示""" chains = tool_chain_manager.get_all_chains() - + if not chains: status_text = "(无工具链配置)" else: @@ -3265,40 +3277,41 @@ mcp_bing_*""", # 显示工具链基本信息 lines.append(f"{status} chain_{name}") lines.append(f" 描述: {chain.description[:40]}{'...' if len(chain.description) > 40 else ''}") - + # 显示输入参数 if chain.input_params: params = ", ".join(chain.input_params.keys()) lines.append(f" 参数: {params}") - + # 显示步骤 lines.append(f" 步骤: {len(chain.steps)} 个") for i, step in enumerate(chain.steps): opt = " (可选)" if step.optional else "" out = f" → {step.output_key}" if step.output_key else "" - lines.append(f" {i+1}. {step.tool_name}{out}{opt}") + lines.append(f" {i + 1}. {step.tool_name}{out}{opt}") lines.append("") - + status_text = "\n".join(lines) - + # 更新内存配置 if "tool_chains" not in self.config: self.config["tool_chains"] = {} self.config["tool_chains"]["chains_status"] = status_text - + def _get_disabled_tools(self) -> set: """v1.4.0: 获取禁用的工具列表""" tools_config = self.config.get("tools", {}) disabled_str = tools_config.get("disabled_tools", "") return {t.strip() for t in disabled_str.strip().split("\n") if t.strip()} - + async def _async_connect_servers(self) -> None: """异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)""" import asyncio + settings = self.config.get("settings", {}) servers_config = self._load_mcp_servers_config() - + if not servers_config: logger.warning("未配置任何 MCP 服务器") self._initialized = True @@ -3308,7 +3321,7 @@ mcp_bing_*""", self._start_status_refresher() self._persist_runtime_displays() return - + auto_connect = settings.get("auto_connect", True) if not auto_connect: logger.info("auto_connect 已禁用,跳过自动连接") @@ -3319,27 +3332,27 @@ mcp_bing_*""", self._start_status_refresher() self._persist_runtime_displays() return - + tool_prefix = settings.get("tool_prefix", "mcp") disabled_tools = self._get_disabled_tools() enable_resources = settings.get("enable_resources", False) enable_prompts = settings.get("enable_prompts", False) - + # 解析所有服务器配置 enabled_configs: List[MCPServerConfig] = [] for idx, server_conf in enumerate(servers_config): server_name = server_conf.get("name", f"unknown_{idx}") - + if not server_conf.get("enabled", True): logger.info(f"服务器 {server_name} 已禁用,跳过") continue - + try: config = self._parse_server_config(server_conf) enabled_configs.append(config) except Exception as e: logger.error(f"解析服务器 {server_name} 配置失败: {e}") - + if not enabled_configs: logger.warning("没有已启用的 MCP 服务器") self._initialized = True @@ -3349,9 +3362,9 @@ mcp_bing_*""", self._start_status_refresher() self._persist_runtime_displays() return - + logger.info(f"准备并行连接 {len(enabled_configs)} 个 MCP 服务器") - + # v1.5.0: 并行连接所有服务器 async def connect_single_server(config: MCPServerConfig) -> Tuple[MCPServerConfig, bool]: """连接单个服务器""" @@ -3377,15 +3390,12 @@ mcp_bing_*""", except Exception as e: logger.error(f"❌ 服务器 {config.name} 连接异常: {e}") return config, False - + # 并行执行所有连接 start_time = time.time() - results = await asyncio.gather( - *[connect_single_server(cfg) for cfg in enabled_configs], - return_exceptions=True - ) + results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True) connect_duration = time.time() - start_time - + # 统计连接结果 success_count = 0 failed_count = 0 @@ -3399,49 +3409,50 @@ mcp_bing_*""", success_count += 1 else: failed_count += 1 - + logger.info(f"并行连接完成: {success_count} 成功, {failed_count} 失败, 耗时 {connect_duration:.2f}s") - + # 注册所有工具 from src.plugin_system.core.component_registry import component_registry + registered_count = 0 - + for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): tool_name = tool_key.replace("-", "_").replace(".", "_") is_disabled = tool_name in disabled_tools - - info, tool_class = mcp_tool_registry.register_tool( - tool_key, tool_info, tool_prefix, disabled=is_disabled - ) + + info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled) info.plugin_name = self.plugin_name - + if component_registry.register_component(info, tool_class): registered_count += 1 status = "🚫" if is_disabled else "✅" logger.info(f"{status} 注册 MCP 工具: {tool_class.name}") else: logger.warning(f"❌ 注册 MCP 工具失败: {tool_class.name}") - + chains_config = self.config.get("tool_chains", {}) chains_enabled = bool(chains_config.get("chains_enabled", True)) if isinstance(chains_config, dict) else True chain_count = len(tool_chain_manager.get_enabled_chains()) if chains_enabled else 0 - + # v1.9.0: 注册 MCP 工具到记忆检索 ReAct 系统(软流程) react_count = 0 react_config = self.config.get("react", {}) if react_config.get("react_enabled", False): react_count = self._register_tools_to_react() - + self._initialized = True - logger.info(f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具") - + logger.info( + f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具" + ) + # 更新状态显示 self._update_status_display() self._update_tool_list_display() self._update_chains_status_display() self._start_status_refresher() self._persist_runtime_displays() - + def _start_status_refresher(self) -> None: """启动 WebUI 状态刷新任务(不写入磁盘)""" task = getattr(self, "_status_refresh_task", None) @@ -3508,7 +3519,9 @@ mcp_bing_*""", logger.info("检测到旧版 servers.list,已自动迁移为 Claude mcpServers(请在 WebUI 保存一次以固化)") if not claude_json.strip(): - self._last_servers_config_error = "未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)" + self._last_servers_config_error = ( + "未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)" + ) return [] try: @@ -3553,11 +3566,11 @@ mcp_bing_*""", configs.append(cfg) return configs - + def _parse_server_config(self, conf: Dict) -> MCPServerConfig: """解析服务器配置字典""" transport_str = conf.get("transport", "stdio").lower() - + transport_map = { "stdio": TransportType.STDIO, "sse": TransportType.SSE, @@ -3565,7 +3578,7 @@ mcp_bing_*""", "streamable_http": TransportType.STREAMABLE_HTTP, } transport = transport_map.get(transport_str, TransportType.STDIO) - + return MCPServerConfig( name=conf.get("name", "unnamed"), enabled=conf.get("enabled", True), @@ -3576,39 +3589,39 @@ mcp_bing_*""", url=conf.get("url", ""), headers=conf.get("headers", {}), # v1.4.2: 鉴权头支持 ) - + def _update_tool_list_display(self) -> None: """v1.4.0: 更新工具列表显示""" tools = mcp_manager.all_tools disabled_tools = self._get_disabled_tools() - + lines = [] by_server: Dict[str, List[str]] = {} - + for tool_key, (tool_info, _) in tools.items(): tool_name = tool_key.replace("-", "_").replace(".", "_") if tool_info.server_name not in by_server: by_server[tool_info.server_name] = [] - + is_disabled = tool_name in disabled_tools status = " ❌" if is_disabled else "" by_server[tool_info.server_name].append(f" • {tool_name}{status}") - + for srv_name, tool_list in by_server.items(): lines.append(f"📦 {srv_name} ({len(tool_list)}个工具):") lines.extend(tool_list) lines.append("") - + if not by_server: lines.append("(无已注册工具)") - + tool_list_text = "\n".join(lines) - + # 更新内存配置 if "tools" not in self.config: self.config["tools"] = {} self.config["tools"]["tool_list"] = tool_list_text - + def _update_status_display(self) -> None: """更新配置文件中的状态显示字段""" status = mcp_manager.get_status() @@ -3619,7 +3632,7 @@ mcp_bing_*""", if cfg_err: lines.append(f"⚠️ 配置: {cfg_err}") lines.append("") - + lines.append(f"服务器: {status['connected_servers']}/{status['total_servers']} 已连接") lines.append(f"工具数: {status['total_tools']}") if settings.get("enable_resources", False): @@ -3628,13 +3641,13 @@ mcp_bing_*""", lines.append(f"模板数: {status.get('total_prompts', 0)}") lines.append(f"心跳: {'运行中' if status['heartbeat_running'] else '已停止'}") lines.append("") - + tools = mcp_manager.all_tools - + for name, info in status.get("servers", {}).items(): icon = "✅" if info["connected"] else "❌" lines.append(f"{icon} {name} ({info['transport']})") - + # v1.7.0: 显示断路器状态 cb_status = info.get("circuit_breaker", {}) cb_state = cb_status.get("state", "closed") @@ -3642,35 +3655,35 @@ mcp_bing_*""", lines.append(" ⚡ 断路器: 熔断中") elif cb_state == "half_open": lines.append(" ⚡ 断路器: 试探中") - + server_tools = [t.name for key, (t, _) in tools.items() if t.server_name == name] if server_tools: for tool_name in server_tools: lines.append(f" • {tool_name}") else: lines.append(" (无工具)") - + if not status.get("servers"): lines.append("(无服务器)") - + status_text = "\n".join(lines) - + if "status" not in self.config: self.config["status"] = {} self.config["status"]["connection_status"] = status_text - + def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: """返回插件的所有组件""" components: List[Tuple[ComponentInfo, Type]] = [] - + # 事件处理器 components.append((MCPStartupHandler.get_handler_info(), MCPStartupHandler)) components.append((MCPStopHandler.get_handler_info(), MCPStopHandler)) - + # 命令 components.append((MCPStatusCommand.get_command_info(), MCPStatusCommand)) components.append((MCPImportCommand.get_command_info(), MCPImportCommand)) - + # 内置工具 status_tool_info = ToolInfo( name=MCPStatusTool.name, @@ -3680,9 +3693,9 @@ mcp_bing_*""", component_type=ComponentType.TOOL, ) components.append((status_tool_info, MCPStatusTool)) - + settings = self.config.get("settings", {}) - + if settings.get("enable_resources", False): read_resource_info = ToolInfo( name=MCPReadResourceTool.name, @@ -3692,7 +3705,7 @@ mcp_bing_*""", component_type=ComponentType.TOOL, ) components.append((read_resource_info, MCPReadResourceTool)) - + if settings.get("enable_prompts", False): get_prompt_info = ToolInfo( name=MCPGetPromptTool.name, @@ -3702,9 +3715,9 @@ mcp_bing_*""", component_type=ComponentType.TOOL, ) components.append((get_prompt_info, MCPGetPromptTool)) - + return components - + def get_status(self) -> Dict[str, Any]: """获取插件状态""" return { @@ -3714,7 +3727,7 @@ mcp_bing_*""", "trace_records": tool_call_tracer.total_records, "cache_stats": tool_call_cache.get_stats(), } - + def get_stats(self) -> Dict[str, Any]: """获取详细统计信息""" return mcp_manager.get_all_stats() diff --git a/plugins/MaiBot_MCPBridgePlugin/tool_chain.py b/plugins/MaiBot_MCPBridgePlugin/tool_chain.py index 96837b94..6a1530cc 100644 --- a/plugins/MaiBot_MCPBridgePlugin/tool_chain.py +++ b/plugins/MaiBot_MCPBridgePlugin/tool_chain.py @@ -22,21 +22,24 @@ from typing import Any, Dict, List, Optional, Tuple try: from src.common.logger import get_logger + logger = get_logger("mcp_tool_chain") except ImportError: import logging + logger = logging.getLogger("mcp_tool_chain") @dataclass class ToolChainStep: """工具链步骤""" + tool_name: str # 要调用的工具名(如 mcp_server_tool) args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换 output_key: str = "" # 输出存储的键名,供后续步骤引用 description: str = "" # 步骤描述 optional: bool = False # 是否可选(失败时继续执行) - + def to_dict(self) -> Dict[str, Any]: return { "tool_name": self.tool_name, @@ -45,7 +48,7 @@ class ToolChainStep: "description": self.description, "optional": self.optional, } - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep": return cls( @@ -60,12 +63,13 @@ class ToolChainStep: @dataclass class ToolChainDefinition: """工具链定义""" + name: str # 工具链名称(将作为组合工具的名称) description: str # 工具链描述(供 LLM 理解) steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤 input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述} enabled: bool = True # 是否启用 - + def to_dict(self) -> Dict[str, Any]: return { "name": self.name, @@ -74,7 +78,7 @@ class ToolChainDefinition: "input_params": self.input_params, "enabled": self.enabled, } - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition": steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])] @@ -90,12 +94,13 @@ class ToolChainDefinition: @dataclass class ChainExecutionResult: """工具链执行结果""" + success: bool final_output: str # 最终输出(最后一个步骤的结果) step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果 error: str = "" total_duration_ms: float = 0.0 - + def to_summary(self) -> str: """生成执行摘要""" lines = [] @@ -103,7 +108,7 @@ class ChainExecutionResult: status = "✅" if step.get("success") else "❌" tool = step.get("tool_name", "unknown") duration = step.get("duration_ms", 0) - lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)") + lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)") if not step.get("success") and step.get("error"): lines.append(f" 错误: {step['error'][:50]}") return "\n".join(lines) @@ -111,49 +116,49 @@ class ChainExecutionResult: class ToolChainExecutor: """工具链执行器""" - + # 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev} - VAR_PATTERN = re.compile(r'\$\{([^}]+)\}') - + VAR_PATTERN = re.compile(r"\$\{([^}]+)\}") + def __init__(self, mcp_manager): self._mcp_manager = mcp_manager - + def _resolve_tool_key(self, tool_name: str) -> Optional[str]: """解析工具名,返回有效的 tool_key - + 支持: - 直接使用 tool_key(如 mcp_server_tool) - 使用注册后的工具名(会自动转换 - 和 . 为 _) """ all_tools = self._mcp_manager.all_tools - + # 直接匹配 if tool_name in all_tools: return tool_name - + # 尝试转换后匹配(用户可能使用了注册后的名称) normalized = tool_name.replace("-", "_").replace(".", "_") if normalized in all_tools: return normalized - + # 尝试查找包含该名称的工具 for key in all_tools.keys(): if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"): return key - + return None - + async def execute( self, chain: ToolChainDefinition, input_args: Dict[str, Any], ) -> ChainExecutionResult: """执行工具链 - + Args: chain: 工具链定义 input_args: 用户输入的参数 - + Returns: ChainExecutionResult: 执行结果 """ @@ -164,15 +169,15 @@ class ToolChainExecutor: "step": {}, # 各步骤输出,按 output_key 存储 "prev": "", # 上一步的输出 } - + final_output = "" - + # 验证必需的输入参数 missing_params = [] for param_name in chain.input_params.keys(): if param_name not in context["input"]: missing_params.append(param_name) - + if missing_params: return ChainExecutionResult( success=False, @@ -180,7 +185,7 @@ class ToolChainExecutor: error=f"缺少必需参数: {', '.join(missing_params)}", total_duration_ms=(time.time() - start_time) * 1000, ) - + for i, step in enumerate(chain.steps): step_start = time.time() step_result = { @@ -191,96 +196,96 @@ class ToolChainExecutor: "error": "", "duration_ms": 0, } - + try: # 替换参数中的变量 resolved_args = self._resolve_args(step.args_template, context) step_result["resolved_args"] = resolved_args - + # 解析工具名 tool_key = self._resolve_tool_key(step.tool_name) if not tool_key: step_result["error"] = f"工具 {step.tool_name} 不存在" - logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在") - + logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在") + if not step.optional: step_results.append(step_result) return ChainExecutionResult( success=False, final_output="", step_results=step_results, - error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在", + error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在", total_duration_ms=(time.time() - start_time) * 1000, ) step_results.append(step_result) continue - - logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}") - + + logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}") + # 调用工具 result = await self._mcp_manager.call_tool(tool_key, resolved_args) - + step_duration = (time.time() - step_start) * 1000 step_result["duration_ms"] = step_duration - + if result.success: step_result["success"] = True # 确保 content 不为 None content = result.content if result.content is not None else "" step_result["output"] = content - + # 更新上下文 context["prev"] = content if step.output_key: context["step"][step.output_key] = content - + final_output = content content_preview = content[:100] if content else "(空)" - logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...") + logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...") else: step_result["error"] = result.error or "未知错误" - logger.warning(f"工具链步骤 {i+1} 失败: {result.error}") - + logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}") + if not step.optional: step_results.append(step_result) return ChainExecutionResult( success=False, final_output="", step_results=step_results, - error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}", + error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}", total_duration_ms=(time.time() - start_time) * 1000, ) - + except Exception as e: step_duration = (time.time() - step_start) * 1000 step_result["duration_ms"] = step_duration step_result["error"] = str(e) - logger.error(f"工具链步骤 {i+1} 异常: {e}") - + logger.error(f"工具链步骤 {i + 1} 异常: {e}") + if not step.optional: step_results.append(step_result) return ChainExecutionResult( success=False, final_output="", step_results=step_results, - error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}", + error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}", total_duration_ms=(time.time() - start_time) * 1000, ) - + step_results.append(step_result) - + total_duration = (time.time() - start_time) * 1000 - + return ChainExecutionResult( success=True, final_output=final_output, step_results=step_results, total_duration_ms=total_duration, ) - + def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: """解析参数模板,替换变量 - + 支持的变量格式: - ${input.param_name}: 用户输入的参数 - ${step.output_key}: 某个步骤的输出 @@ -288,50 +293,48 @@ class ToolChainExecutor: - ${prev.field}: 上一步输出(JSON)的某个字段 """ resolved = {} - + for key, value in args_template.items(): if isinstance(value, str): resolved[key] = self._substitute_vars(value, context) elif isinstance(value, dict): resolved[key] = self._resolve_args(value, context) elif isinstance(value, list): - resolved[key] = [ - self._substitute_vars(v, context) if isinstance(v, str) else v - for v in value - ] + resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value] else: resolved[key] = value - + return resolved - + def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str: """替换字符串中的变量""" + def replacer(match): var_path = match.group(1) return self._get_var_value(var_path, context) - + return self.VAR_PATTERN.sub(replacer, template) - + def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str: """获取变量值 - + Args: var_path: 变量路径,如 "input.query", "step.search_result", "prev", "prev.id" context: 上下文 """ parts = self._parse_var_path(var_path) - + if not parts: return "" - + # 获取根对象 root = parts[0] if root not in context: logger.warning(f"变量 {var_path} 的根 '{root}' 不存在") return "" - + value = context[root] - + # 遍历路径 for part in parts[1:]: if isinstance(value, str): @@ -349,7 +352,7 @@ class ToolChainExecutor: value = "" else: value = "" - + # 确保返回字符串 if isinstance(value, (dict, list)): return json.dumps(value, ensure_ascii=False) @@ -448,39 +451,39 @@ class ToolChainExecutor: class ToolChainManager: """工具链管理器""" - + _instance: Optional["ToolChainManager"] = None - + def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - + def __init__(self): if self._initialized: return self._initialized = True self._chains: Dict[str, ToolChainDefinition] = {} self._executor: Optional[ToolChainExecutor] = None - + def set_executor(self, mcp_manager) -> None: """设置执行器""" self._executor = ToolChainExecutor(mcp_manager) - + def add_chain(self, chain: ToolChainDefinition) -> bool: """添加工具链""" if not chain.name: logger.error("工具链名称不能为空") return False - + if chain.name in self._chains: logger.warning(f"工具链 {chain.name} 已存在,将被覆盖") - + self._chains[chain.name] = chain logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)") return True - + def remove_chain(self, name: str) -> bool: """移除工具链""" if name in self._chains: @@ -488,19 +491,19 @@ class ToolChainManager: logger.info(f"已移除工具链: {name}") return True return False - + def get_chain(self, name: str) -> Optional[ToolChainDefinition]: """获取工具链""" return self._chains.get(name) - + def get_all_chains(self) -> Dict[str, ToolChainDefinition]: """获取所有工具链""" return self._chains.copy() - + def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]: """获取所有启用的工具链""" return {name: chain for name, chain in self._chains.items() if chain.enabled} - + async def execute_chain( self, chain_name: str, @@ -514,64 +517,64 @@ class ToolChainManager: final_output="", error=f"工具链 {chain_name} 不存在", ) - + if not chain.enabled: return ChainExecutionResult( success=False, final_output="", error=f"工具链 {chain_name} 已禁用", ) - + if not self._executor: return ChainExecutionResult( success=False, final_output="", error="工具链执行器未初始化", ) - + return await self._executor.execute(chain, input_args) - + def load_from_json(self, json_str: str) -> Tuple[int, List[str]]: """从 JSON 字符串加载工具链配置 - + Returns: (成功加载数量, 错误列表) """ errors = [] loaded = 0 - + try: data = json.loads(json_str) if json_str.strip() else [] except json.JSONDecodeError as e: return 0, [f"JSON 解析失败: {e}"] - + if not isinstance(data, list): data = [data] - + for i, item in enumerate(data): try: chain = ToolChainDefinition.from_dict(item) if not chain.name: - errors.append(f"第 {i+1} 个工具链缺少名称") + errors.append(f"第 {i + 1} 个工具链缺少名称") continue if not chain.steps: errors.append(f"工具链 {chain.name} 没有步骤") continue - + self.add_chain(chain) loaded += 1 except Exception as e: - errors.append(f"第 {i+1} 个工具链解析失败: {e}") - + errors.append(f"第 {i + 1} 个工具链解析失败: {e}") + return loaded, errors - + def export_to_json(self, pretty: bool = True) -> str: """导出所有工具链为 JSON""" chains_data = [chain.to_dict() for chain in self._chains.values()] if pretty: return json.dumps(chains_data, ensure_ascii=False, indent=2) return json.dumps(chains_data, ensure_ascii=False) - + def clear(self) -> None: """清空所有工具链""" self._chains.clear() diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index ecf67cd9..76567817 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -238,7 +238,7 @@ class TestCommand(BaseCommand): chat_stream=self.message.chat_stream, reply_reason=reply_reason, enable_chinese_typo=False, - extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"", + extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"', ) if result_status: # 发送生成的回复 diff --git a/pytests/config_test/test_config_base.py b/pytests/config_test/test_config_base.py index fea2fe8e..f67c8c56 100644 --- a/pytests/config_test/test_config_base.py +++ b/pytests/config_test/test_config_base.py @@ -46,6 +46,7 @@ def patch_attrdoc_post_init(): config_base_module.logger = logging.getLogger("config_base_test_logger") + class SimpleClass(ConfigBase): a: int = 1 b: str = "test" @@ -282,7 +283,7 @@ class TestConfigBase: True, "ConfigBase is not Hashable", id="listset-validation-set-configbase-element_reject", - ) + ), ], ) def test_validate_list_set_type(self, annotation, expect_error, error_fragment): @@ -340,7 +341,7 @@ class TestConfigBase: False, None, id="dict-validation-happy-configbase-value", - ) + ), ], ) def test_validate_dict_type(self, annotation, expect_error, error_fragment): @@ -353,13 +354,11 @@ class TestConfigBase: field_name = "mapping" if expect_error: - # Act / Assert with pytest.raises(TypeError) as exc_info: dummy._validate_dict_type(annotation, field_name) assert error_fragment in str(exc_info.value) else: - # Act dummy._validate_dict_type(annotation, field_name) @@ -392,7 +391,7 @@ class TestConfigBase: # Assert assert "字段'field_y'中使用了 Any 类型注解" in caplog.text - + def test_discourage_any_usage_suppressed_warning(self, caplog): class Sample(ConfigBase): _validate_any: bool = False diff --git a/pytests/image_sys_test/image_manager_test.py b/pytests/image_sys_test/image_manager_test.py index 360ba50c..ccc852df 100644 --- a/pytests/image_sys_test/image_manager_test.py +++ b/pytests/image_sys_test/image_manager_test.py @@ -4,7 +4,6 @@ import importlib import pytest from pathlib import Path import importlib.util -import asyncio class DummyLogger: @@ -71,6 +70,7 @@ class DummyLLMRequest: async def generate_response_for_image(self, prompt, image_base64, image_format, temp): return ("dummy description", {}) + class DummySelect: def __init__(self, *a, **k): pass @@ -81,6 +81,7 @@ class DummySelect: def limit(self, n): return self + @pytest.fixture(autouse=True) def patch_external_dependencies(monkeypatch): # Provide dummy implementations as modules so that importing image_manager is safe @@ -103,11 +104,11 @@ def patch_external_dependencies(monkeypatch): # Patch MaiImage data model data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage) monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod) - + # Patch SQLModel select function sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect()) monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod) - + # Patch config values used at import-time cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style")) model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm")) @@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None): if tmp_path is not None: tmpdir = Path(tmp_path) tmpdir.mkdir(parents=True, exist_ok=True) - setattr(mod, "IMAGE_DIR", tmpdir) + mod.IMAGE_DIR = tmpdir except Exception: pass return mod @@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path): # cleanup should run without error mgr.cleanup_invalid_descriptions_in_db() - diff --git a/pytests/webui/test_config_schema.py b/pytests/webui/test_config_schema.py index 41c3f78f..8bbedd80 100644 --- a/pytests/webui/test_config_schema.py +++ b/pytests/webui/test_config_schema.py @@ -1,5 +1,3 @@ -import pytest - from src.config.official_configs import ChatConfig from src.config.config import Config from src.webui.config_schema import ConfigSchemaGenerator diff --git a/pytests/webui/test_emoji_routes.py b/pytests/webui/test_emoji_routes.py index 8bfb5f46..92900726 100644 --- a/pytests/webui/test_emoji_routes.py +++ b/pytests/webui/test_emoji_routes.py @@ -387,7 +387,7 @@ def test_auth_required_list(client): """测试未认证访问列表端点(401)""" # Without mock_token_verify fixture with patch("src.webui.routers.emoji.verify_auth_token", return_value=False): - response = client.get("/emoji/list") + client.get("/emoji/list") # verify_auth_token 返回 False 会触发 HTTPException # 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现 # 这里假设它抛出 401 @@ -397,7 +397,7 @@ def test_auth_required_update(client, sample_emojis): """测试未认证访问更新端点(401)""" with patch("src.webui.routers.emoji.verify_auth_token", return_value=False): emoji_id = sample_emojis[0].id - response = client.patch(f"/emoji/{emoji_id}", json={"description": "test"}) + client.patch(f"/emoji/{emoji_id}", json={"description": "test"}) # Should be unauthorized diff --git a/pytests/webui/test_expression_routes.py b/pytests/webui/test_expression_routes.py index 0be7a4d7..3dcd9fba 100644 --- a/pytests/webui/test_expression_routes.py +++ b/pytests/webui/test_expression_routes.py @@ -1,6 +1,5 @@ """Expression routes pytest tests""" -from datetime import datetime from typing import Generator from unittest.mock import MagicMock @@ -12,7 +11,6 @@ from sqlalchemy import text from sqlmodel import Session, SQLModel, create_engine, select from src.common.database.database_model import Expression -from src.common.database.database import get_db_session def create_test_app() -> FastAPI: diff --git a/scripts/analyze_evaluation_stats.py b/scripts/analyze_evaluation_stats.py index 9eec6243..e18243d1 100644 --- a/scripts/analyze_evaluation_stats.py +++ b/scripts/analyze_evaluation_stats.py @@ -19,7 +19,7 @@ from typing import Dict, List, Set, Tuple project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, project_root) -from src.common.logger import get_logger +from src.common.logger import get_logger # noqa: E402 logger = get_logger("evaluation_stats_analyzer") @@ -38,10 +38,10 @@ def parse_datetime(dt_str: str) -> datetime | None: def analyze_single_file(file_path: str) -> Dict: """ 分析单个JSON文件的统计信息 - + Args: file_path: JSON文件路径 - + Returns: 统计信息字典 """ @@ -65,40 +65,40 @@ def analyze_single_file(file_path: str) -> Dict: "has_reason": False, "reason_count": 0, } - + try: with open(file_path, "r", encoding="utf-8") as f: data = json.load(f) - + # 基本信息 stats["last_updated"] = data.get("last_updated") stats["total_count"] = data.get("total_count", 0) - + results = data.get("manual_results", []) stats["actual_count"] = len(results) - + if not results: return stats - + # 统计通过/不通过 suitable_count = sum(1 for r in results if r.get("suitable") is True) unsuitable_count = sum(1 for r in results if r.get("suitable") is False) stats["suitable_count"] = suitable_count stats["unsuitable_count"] = unsuitable_count stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0 - + # 统计唯一的(situation, style)对 pairs: Set[Tuple[str, str]] = set() for r in results: if "situation" in r and "style" in r: pairs.add((r["situation"], r["style"])) stats["unique_pairs"] = len(pairs) - + # 统计评估者 for r in results: evaluator = r.get("evaluator", "unknown") stats["evaluators"][evaluator] += 1 - + # 统计评估时间 evaluation_dates = [] for r in results: @@ -107,7 +107,7 @@ def analyze_single_file(file_path: str) -> Dict: dt = parse_datetime(evaluated_at) if dt: evaluation_dates.append(dt) - + stats["evaluation_dates"] = evaluation_dates if evaluation_dates: min_date = min(evaluation_dates) @@ -115,18 +115,18 @@ def analyze_single_file(file_path: str) -> Dict: stats["date_range"] = { "start": min_date.isoformat(), "end": max_date.isoformat(), - "duration_days": (max_date - min_date).days + 1 + "duration_days": (max_date - min_date).days + 1, } - + # 检查字段存在性 stats["has_expression_id"] = any("expression_id" in r for r in results) stats["has_reason"] = any(r.get("reason") for r in results) stats["reason_count"] = sum(1 for r in results if r.get("reason")) - + except Exception as e: stats["error"] = str(e) logger.error(f"分析文件 {file_name} 时出错: {e}") - + return stats @@ -136,57 +136,57 @@ def print_file_stats(stats: Dict, index: int = None): print(f"\n{'=' * 80}") print(f"{prefix}文件: {stats['file_name']}") print(f"{'=' * 80}") - + if stats["error"]: print(f"✗ 错误: {stats['error']}") return - + print(f"文件路径: {stats['file_path']}") print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)") - + if stats["last_updated"]: print(f"最后更新: {stats['last_updated']}") - + print("\n【记录统计】") print(f" 文件中的 total_count: {stats['total_count']}") print(f" 实际记录数: {stats['actual_count']}") - - if stats['total_count'] != stats['actual_count']: - diff = stats['total_count'] - stats['actual_count'] + + if stats["total_count"] != stats["actual_count"]: + diff = stats["total_count"] - stats["actual_count"] print(f" ⚠️ 数量不一致,差值: {diff:+d}") - + print("\n【评估结果统计】") print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)") print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)") - + print("\n【唯一性统计】") print(f" 唯一 (situation, style) 对: {stats['unique_pairs']} 条") - if stats['actual_count'] > 0: - duplicate_count = stats['actual_count'] - stats['unique_pairs'] - duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 + if stats["actual_count"] > 0: + duplicate_count = stats["actual_count"] - stats["unique_pairs"] + duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0 print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)") - + print("\n【评估者统计】") - if stats['evaluators']: - for evaluator, count in stats['evaluators'].most_common(): - rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 + if stats["evaluators"]: + for evaluator, count in stats["evaluators"].most_common(): + rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0 print(f" {evaluator}: {count} 条 ({rate:.2f}%)") else: print(" 无评估者信息") - + print("\n【时间统计】") - if stats['date_range']: + if stats["date_range"]: print(f" 最早评估时间: {stats['date_range']['start']}") print(f" 最晚评估时间: {stats['date_range']['end']}") print(f" 评估时间跨度: {stats['date_range']['duration_days']} 天") else: print(" 无时间信息") - + print("\n【字段统计】") print(f" 包含 expression_id: {'是' if stats['has_expression_id'] else '否'}") print(f" 包含 reason: {'是' if stats['has_reason'] else '否'}") - if stats['has_reason']: - rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 + if stats["has_reason"]: + rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0 print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)") @@ -195,35 +195,35 @@ def print_summary(all_stats: List[Dict]): print(f"\n{'=' * 80}") print("汇总统计") print(f"{'=' * 80}") - + total_files = len(all_stats) valid_files = [s for s in all_stats if not s.get("error")] error_files = [s for s in all_stats if s.get("error")] - + print("\n【文件统计】") print(f" 总文件数: {total_files}") print(f" 成功解析: {len(valid_files)}") print(f" 解析失败: {len(error_files)}") - + if error_files: print("\n 失败文件列表:") for stats in error_files: print(f" - {stats['file_name']}: {stats['error']}") - + if not valid_files: print("\n没有成功解析的文件") return - + # 汇总记录统计 - total_records = sum(s['actual_count'] for s in valid_files) - total_suitable = sum(s['suitable_count'] for s in valid_files) - total_unsuitable = sum(s['unsuitable_count'] for s in valid_files) + total_records = sum(s["actual_count"] for s in valid_files) + total_suitable = sum(s["suitable_count"] for s in valid_files) + total_unsuitable = sum(s["unsuitable_count"] for s in valid_files) total_unique_pairs = set() - + # 收集所有唯一的(situation, style)对 for stats in valid_files: try: - with open(stats['file_path'], "r", encoding="utf-8") as f: + with open(stats["file_path"], "r", encoding="utf-8") as f: data = json.load(f) results = data.get("manual_results", []) for r in results: @@ -231,23 +231,31 @@ def print_summary(all_stats: List[Dict]): total_unique_pairs.add((r["situation"], r["style"])) except Exception: pass - + print("\n【记录汇总】") print(f" 总记录数: {total_records:,} 条") - print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条") - print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条") + print( + f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" + if total_records > 0 + else " 通过: 0 条" + ) + print( + f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" + if total_records > 0 + else " 不通过: 0 条" + ) print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,} 条") - + if total_records > 0: duplicate_count = total_records - len(total_unique_pairs) duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0 print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)") - + # 汇总评估者统计 all_evaluators = Counter() for stats in valid_files: - all_evaluators.update(stats['evaluators']) - + all_evaluators.update(stats["evaluators"]) + print("\n【评估者汇总】") if all_evaluators: for evaluator, count in all_evaluators.most_common(): @@ -255,12 +263,12 @@ def print_summary(all_stats: List[Dict]): print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)") else: print(" 无评估者信息") - + # 汇总时间范围 all_dates = [] for stats in valid_files: - all_dates.extend(stats['evaluation_dates']) - + all_dates.extend(stats["evaluation_dates"]) + if all_dates: min_date = min(all_dates) max_date = max(all_dates) @@ -268,9 +276,9 @@ def print_summary(all_stats: List[Dict]): print(f" 最早评估时间: {min_date.isoformat()}") print(f" 最晚评估时间: {max_date.isoformat()}") print(f" 总时间跨度: {(max_date - min_date).days + 1} 天") - + # 文件大小汇总 - total_size = sum(s['file_size'] for s in valid_files) + total_size = sum(s["file_size"] for s in valid_files) avg_size = total_size / len(valid_files) if valid_files else 0 print("\n【文件大小汇总】") print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)") @@ -282,35 +290,35 @@ def main(): logger.info("=" * 80) logger.info("开始分析评估结果统计信息") logger.info("=" * 80) - + if not os.path.exists(TEMP_DIR): print(f"\n✗ 错误:未找到temp目录: {TEMP_DIR}") logger.error(f"未找到temp目录: {TEMP_DIR}") return - + # 查找所有JSON文件 json_files = glob.glob(os.path.join(TEMP_DIR, "*.json")) - + if not json_files: print(f"\n✗ 错误:temp目录下未找到JSON文件: {TEMP_DIR}") logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}") return - + json_files.sort() # 按文件名排序 - + print(f"\n找到 {len(json_files)} 个JSON文件") print("=" * 80) - + # 分析每个文件 all_stats = [] for i, json_file in enumerate(json_files, 1): stats = analyze_single_file(json_file) all_stats.append(stats) print_file_stats(stats, index=i) - + # 打印汇总统计 print_summary(all_stats) - + print(f"\n{'=' * 80}") print("分析完成") print(f"{'=' * 80}") @@ -318,5 +326,3 @@ def main(): if __name__ == "__main__": main() - - diff --git a/scripts/delete_lpmm_items.py b/scripts/delete_lpmm_items.py index 2eb37ded..e6e40fea 100644 --- a/scripts/delete_lpmm_items.py +++ b/scripts/delete_lpmm_items.py @@ -171,7 +171,9 @@ def main(): sys.exit(1) if not args.raw_index: - logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3") + logger.info( + f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3" + ) sys.exit(1) # 解析索引列表(1-based) diff --git a/scripts/evaluate_expressions_count_analysis.py b/scripts/evaluate_expressions_count_analysis.py index db1f4e71..1bff5b69 100644 --- a/scripts/evaluate_expressions_count_analysis.py +++ b/scripts/evaluate_expressions_count_analysis.py @@ -22,11 +22,11 @@ from collections import defaultdict project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, project_root) -from src.common.database.database_model import Expression -from src.common.database.database import db -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config +from src.common.database.database_model import Expression # noqa: E402 +from src.common.database.database import db # noqa: E402 +from src.common.logger import get_logger # noqa: E402 +from src.llm_models.utils_model import LLMRequest # noqa: E402 +from src.config.config import model_config # noqa: E402 logger = get_logger("expression_evaluator_count_analysis_llm") @@ -38,13 +38,13 @@ COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results. def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]: """ 加载已有的评估结果 - + Returns: (已有结果列表, 已评估的项目(situation, style)元组集合) """ if not os.path.exists(COUNT_ANALYSIS_FILE): return [], set() - + try: with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f: data = json.load(f) @@ -61,22 +61,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]: def save_results(evaluation_results: List[Dict]): """ 保存评估结果到文件 - + Args: evaluation_results: 评估结果列表 """ try: os.makedirs(TEMP_DIR, exist_ok=True) - + data = { "last_updated": datetime.now().isoformat(), "total_count": len(evaluation_results), - "evaluation_results": evaluation_results + "evaluation_results": evaluation_results, } - + with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) - + logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}") print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)") except Exception as e: @@ -84,70 +84,70 @@ def save_results(evaluation_results: List[Dict]): print(f"\n✗ 保存评估结果失败: {e}") -def select_expressions_for_evaluation( - evaluated_pairs: Set[Tuple[str, str]] = None -) -> List[Expression]: +def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]: """ 选择用于评估的表达方式 选择所有count>1的项目,然后选择两倍数量的count=1的项目 - + Args: evaluated_pairs: 已评估的项目集合,用于避免重复 - + Returns: 选中的表达方式列表 """ if evaluated_pairs is None: evaluated_pairs = set() - + try: # 查询所有表达方式 all_expressions = list(Expression.select()) - + if not all_expressions: logger.warning("数据库中没有表达方式记录") return [] - + # 过滤出未评估的项目 - unevaluated = [ - expr for expr in all_expressions - if (expr.situation, expr.style) not in evaluated_pairs - ] - + unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs] + if not unevaluated: logger.warning("所有项目都已评估完成") return [] - + # 按count分组 count_eq1 = [expr for expr in unevaluated if expr.count == 1] count_gt1 = [expr for expr in unevaluated if expr.count > 1] - + logger.info(f"未评估项目中:count=1的有{len(count_eq1)}条,count>1的有{len(count_gt1)}条") - + # 选择所有count>1的项目 selected_count_gt1 = count_gt1.copy() - + # 选择count=1的项目,数量为count>1数量的2倍 count_gt1_count = len(selected_count_gt1) count_eq1_needed = count_gt1_count * 2 - + if len(count_eq1) < count_eq1_needed: - logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条") + logger.warning( + f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条" + ) count_eq1_needed = len(count_eq1) - + # 随机选择count=1的项目 selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else [] - + selected = selected_count_gt1 + selected_count_eq1 random.shuffle(selected) # 打乱顺序 - - logger.info(f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)") - + + logger.info( + f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)" + ) + return selected - + except Exception as e: logger.error(f"选择表达方式失败: {e}") import traceback + logger.error(traceback.format_exc()) return [] @@ -155,11 +155,11 @@ def select_expressions_for_evaluation( def create_evaluation_prompt(situation: str, style: str) -> str: """ 创建评估提示词 - + Args: situation: 情境 style: 风格 - + Returns: 评估提示词 """ @@ -181,34 +181,32 @@ def create_evaluation_prompt(situation: str, style: str) -> str: }} 如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。 请严格按照JSON格式输出,不要包含其他内容。""" - + return prompt async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]: """ 执行单次LLM评估 - + Args: situation: 情境 style: 风格 llm: LLM请求实例 - + Returns: (suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息 """ try: prompt = create_evaluation_prompt(situation, style) logger.debug(f"正在评估表达方式: situation={situation}, style={style}") - + response, (reasoning, model_name, _) = await llm.generate_response_async( - prompt=prompt, - temperature=0.6, - max_tokens=1024 + prompt=prompt, temperature=0.6, max_tokens=1024 ) - + logger.debug(f"LLM响应: {response}") - + # 解析JSON响应 try: evaluation = json.loads(response) @@ -218,13 +216,13 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> evaluation = json.loads(json_match.group()) else: raise ValueError("无法从响应中提取JSON格式的评估结果") from e - + suitable = evaluation.get("suitable", False) reason = evaluation.get("reason", "未提供理由") - + logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") return suitable, reason, None - + except Exception as e: logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}") return False, f"评估过程出错: {str(e)}", str(e) @@ -233,23 +231,25 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict: """ 使用LLM评估单个表达方式 - + Args: expression: 表达方式对象 llm: LLM请求实例 - + Returns: 评估结果字典 """ - logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}") - + logger.info( + f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}" + ) + suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm) - + if error: suitable = False - + logger.info(f"评估完成: {'通过' if suitable else '不通过'}") - + return { "situation": expression.situation, "style": expression.style, @@ -258,28 +258,28 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di "reason": reason, "error": error, "evaluator": "llm", - "evaluated_at": datetime.now().isoformat() + "evaluated_at": datetime.now().isoformat(), } def perform_statistical_analysis(evaluation_results: List[Dict]): """ 对评估结果进行统计分析 - + Args: evaluation_results: 评估结果列表 """ if not evaluation_results: print("\n没有评估结果可供分析") return - + print("\n" + "=" * 60) print("统计分析结果") print("=" * 60) - + # 按count分组统计 count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0}) - + for result in evaluation_results: count = result.get("count", 1) suitable = result.get("suitable", False) @@ -288,7 +288,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]): count_groups[count]["suitable"] += 1 else: count_groups[count]["unsuitable"] += 1 - + # 显示每个count的统计 print("\n【按count分组统计】") print("-" * 60) @@ -298,21 +298,21 @@ def perform_statistical_analysis(evaluation_results: List[Dict]): suitable = group["suitable"] unsuitable = group["unsuitable"] pass_rate = (suitable / total * 100) if total > 0 else 0 - + print(f"Count = {count}:") print(f" 总数: {total}") print(f" 通过: {suitable} ({pass_rate:.2f}%)") - print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)") + print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)") print() - + # 比较count=1和count>1 count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0} count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0} - + for result in evaluation_results: count = result.get("count", 1) suitable = result.get("suitable", False) - + if count == 1: count_eq1_group["total"] += 1 if suitable: @@ -325,34 +325,34 @@ def perform_statistical_analysis(evaluation_results: List[Dict]): count_gt1_group["suitable"] += 1 else: count_gt1_group["unsuitable"] += 1 - + print("\n【Count=1 vs Count>1 对比】") print("-" * 60) - + eq1_total = count_eq1_group["total"] eq1_suitable = count_eq1_group["suitable"] eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0 - + gt1_total = count_gt1_group["total"] gt1_suitable = count_gt1_group["suitable"] gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0 - + print("Count = 1:") print(f" 总数: {eq1_total}") print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)") - print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)") + print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)") print() print("Count > 1:") print(f" 总数: {gt1_total}") print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)") - print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)") + print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)") print() - + # 进行卡方检验(简化版,使用2x2列联表) if eq1_total > 0 and gt1_total > 0: print("【统计显著性检验】") print("-" * 60) - + # 构建2x2列联表 # 通过 不通过 # count=1 a b @@ -361,7 +361,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]): b = eq1_total - eq1_suitable c = gt1_suitable d = gt1_total - gt1_suitable - + # 计算卡方统计量(简化版,使用Pearson卡方检验) n = eq1_total + gt1_total if n > 0: @@ -370,13 +370,13 @@ def perform_statistical_analysis(evaluation_results: List[Dict]): e_b = (eq1_total * (b + d)) / n e_c = (gt1_total * (a + c)) / n e_d = (gt1_total * (b + d)) / n - + # 检查期望频数是否足够大(卡方检验要求每个期望频数>=5) min_expected = min(e_a, e_b, e_c, e_d) if min_expected < 5: print("警告:期望频数小于5,卡方检验可能不准确") print("建议使用Fisher精确检验") - + # 计算卡方值 chi_square = 0 if e_a > 0: @@ -387,26 +387,26 @@ def perform_statistical_analysis(evaluation_results: List[Dict]): chi_square += ((c - e_c) ** 2) / e_c if e_d > 0: chi_square += ((d - e_d) ** 2) / e_d - + # 自由度 = (行数-1) * (列数-1) = 1 df = 1 - + # 临界值(α=0.05) chi_square_critical_005 = 3.841 chi_square_critical_001 = 6.635 - + print(f"卡方统计量: {chi_square:.4f}") print(f"自由度: {df}") print(f"临界值 (α=0.05): {chi_square_critical_005}") print(f"临界值 (α=0.01): {chi_square_critical_001}") - + if chi_square >= chi_square_critical_001: print("结论: 在α=0.01水平下,count=1和count>1的合格率存在显著差异(p<0.01)") elif chi_square >= chi_square_critical_005: print("结论: 在α=0.05水平下,count=1和count>1的合格率存在显著差异(p<0.05)") else: print("结论: 在α=0.05水平下,count=1和count>1的合格率不存在显著差异(p≥0.05)") - + # 计算差异大小 diff = abs(eq1_pass_rate - gt1_pass_rate) print(f"\n合格率差异: {diff:.2f}%") @@ -420,16 +420,16 @@ def perform_statistical_analysis(evaluation_results: List[Dict]): print("数据不足,无法进行统计检验") else: print("数据不足,无法进行count=1和count>1的对比分析") - + # 保存统计分析结果 analysis_result = { "analysis_time": datetime.now().isoformat(), "count_groups": {str(k): v for k, v in count_groups.items()}, "count_eq1": count_eq1_group, "count_gt1": count_gt1_group, - "total_evaluated": len(evaluation_results) + "total_evaluated": len(evaluation_results), } - + try: analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json") with open(analysis_file, "w", encoding="utf-8") as f: @@ -444,7 +444,7 @@ async def main(): logger.info("=" * 60) logger.info("开始表达方式按count分组的LLM评估和统计分析") logger.info("=" * 60) - + # 初始化数据库连接 try: db.connect(reuse_if_open=True) @@ -452,97 +452,95 @@ async def main(): except Exception as e: logger.error(f"数据库连接失败: {e}") return - + # 加载已有评估结果 existing_results, evaluated_pairs = load_existing_results() evaluation_results = existing_results.copy() - + if evaluated_pairs: print(f"\n已加载 {len(existing_results)} 条已有评估结果") print(f"已评估项目数: {len(evaluated_pairs)}") - + # 检查是否需要继续评估(检查是否还有未评估的count>1项目) # 先查询未评估的count>1项目数量 try: all_expressions = list(Expression.select()) unevaluated_count_gt1 = [ - expr for expr in all_expressions - if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs + expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs ] has_unevaluated = len(unevaluated_count_gt1) > 0 except Exception as e: logger.error(f"查询未评估项目失败: {e}") has_unevaluated = False - + if has_unevaluated: print("\n" + "=" * 60) print("开始LLM评估") print("=" * 60) print("评估结果会自动保存到文件\n") - + # 创建LLM实例 print("创建LLM实例...") try: llm = LLMRequest( model_set=model_config.model_task_config.tool_use, - request_type="expression_evaluator_count_analysis_llm" + request_type="expression_evaluator_count_analysis_llm", ) print("✓ LLM实例创建成功\n") except Exception as e: logger.error(f"创建LLM实例失败: {e}") import traceback + logger.error(traceback.format_exc()) print(f"\n✗ 创建LLM实例失败: {e}") db.close() return - + # 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目) - expressions = select_expressions_for_evaluation( - evaluated_pairs=evaluated_pairs - ) - + expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs) + if not expressions: print("\n没有可评估的项目") else: print(f"\n已选择 {len(expressions)} 条表达方式进行评估") print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条") print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n") - + batch_results = [] for i, expression in enumerate(expressions, 1): print(f"LLM评估进度: {i}/{len(expressions)}") print(f" Situation: {expression.situation}") print(f" Style: {expression.style}") print(f" Count: {expression.count}") - + llm_result = await llm_evaluate_expression(expression, llm) - + print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}") - if llm_result.get('error'): + if llm_result.get("error"): print(f" 错误: {llm_result['error']}") print() - + batch_results.append(llm_result) # 使用 (situation, style) 作为唯一标识 evaluated_pairs.add((llm_result["situation"], llm_result["style"])) - + # 添加延迟以避免API限流 await asyncio.sleep(0.3) - + # 将当前批次结果添加到总结果中 evaluation_results.extend(batch_results) - + # 保存结果 save_results(evaluation_results) else: print(f"\n所有count>1的项目都已评估完成,已有 {len(evaluation_results)} 条评估结果") - + # 进行统计分析 if len(evaluation_results) > 0: perform_statistical_analysis(evaluation_results) else: print("\n没有评估结果可供分析") - + # 关闭数据库连接 try: db.close() @@ -553,4 +551,3 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) - diff --git a/scripts/evaluate_expressions_llm_v6.py b/scripts/evaluate_expressions_llm_v6.py index cb9a86ff..f18acda4 100644 --- a/scripts/evaluate_expressions_llm_v6.py +++ b/scripts/evaluate_expressions_llm_v6.py @@ -20,9 +20,9 @@ from typing import List, Dict, Set, Tuple project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, project_root) -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config -from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest # noqa: E402 +from src.config.config import model_config # noqa: E402 +from src.common.logger import get_logger # noqa: E402 logger = get_logger("expression_evaluator_llm") @@ -33,7 +33,7 @@ TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp") def load_manual_results() -> List[Dict]: """ 加载人工评估结果(自动读取temp目录下所有JSON文件并合并) - + Returns: 人工评估结果列表(已去重) """ @@ -42,62 +42,62 @@ def load_manual_results() -> List[Dict]: print("\n✗ 错误:未找到temp目录") print(" 请先运行 evaluate_expressions_manual.py 进行人工评估") return [] - + # 查找所有JSON文件 json_files = glob.glob(os.path.join(TEMP_DIR, "*.json")) - + if not json_files: logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}") print("\n✗ 错误:temp目录下未找到JSON文件") print(" 请先运行 evaluate_expressions_manual.py 进行人工评估") return [] - + logger.info(f"找到 {len(json_files)} 个JSON文件") print(f"\n找到 {len(json_files)} 个JSON文件:") for json_file in json_files: print(f" - {os.path.basename(json_file)}") - + # 读取并合并所有JSON文件 all_results = [] seen_pairs: Set[Tuple[str, str]] = set() # 用于去重 - + for json_file in json_files: try: with open(json_file, "r", encoding="utf-8") as f: data = json.load(f) results = data.get("manual_results", []) - + # 去重:使用(situation, style)作为唯一标识 for result in results: if "situation" not in result or "style" not in result: logger.warning(f"跳过无效数据(缺少必要字段): {result}") continue - + pair = (result["situation"], result["style"]) if pair not in seen_pairs: seen_pairs.add(pair) all_results.append(result) - + logger.info(f"从 {os.path.basename(json_file)} 加载了 {len(results)} 条结果") except Exception as e: logger.error(f"加载文件 {json_file} 失败: {e}") print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}") continue - + logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)") print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)") - + return all_results def create_evaluation_prompt(situation: str, style: str) -> str: """ 创建评估提示词 - + Args: situation: 情境 style: 风格 - + Returns: 评估提示词 """ @@ -119,51 +119,50 @@ def create_evaluation_prompt(situation: str, style: str) -> str: }} 如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。 请严格按照JSON格式输出,不要包含其他内容。""" - + return prompt async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]: """ 执行单次LLM评估 - + Args: situation: 情境 style: 风格 llm: LLM请求实例 - + Returns: (suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息 """ try: prompt = create_evaluation_prompt(situation, style) logger.debug(f"正在评估表达方式: situation={situation}, style={style}") - + response, (reasoning, model_name, _) = await llm.generate_response_async( - prompt=prompt, - temperature=0.6, - max_tokens=1024 + prompt=prompt, temperature=0.6, max_tokens=1024 ) - + logger.debug(f"LLM响应: {response}") - + # 解析JSON响应 try: evaluation = json.loads(response) except json.JSONDecodeError as e: import re + json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL) if json_match: evaluation = json.loads(json_match.group()) else: raise ValueError("无法从响应中提取JSON格式的评估结果") from e - + suitable = evaluation.get("suitable", False) reason = evaluation.get("reason", "未提供理由") - + logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") return suitable, reason, None - + except Exception as e: logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}") return False, f"评估过程出错: {str(e)}", str(e) @@ -172,68 +171,68 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict: """ 使用LLM评估单个表达方式 - + Args: situation: 情境 style: 风格 llm: LLM请求实例 - + Returns: 评估结果字典 """ logger.info(f"开始评估表达方式: situation={situation}, style={style}") - + suitable, reason, error = await _single_llm_evaluation(situation, style, llm) - + if error: suitable = False - + logger.info(f"评估完成: {'通过' if suitable else '不通过'}") - + return { "situation": situation, "style": style, "suitable": suitable, "reason": reason, "error": error, - "evaluator": "llm" + "evaluator": "llm", } def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict: """ 对比人工评估和LLM评估的结果 - + Args: manual_results: 人工评估结果列表 llm_results: LLM评估结果列表 method_name: 评估方法名称(用于标识) - + Returns: 对比分析结果字典 """ # 按(situation, style)建立映射 llm_dict = {(r["situation"], r["style"]): r for r in llm_results} - + total = len(manual_results) matched = 0 true_positives = 0 true_negatives = 0 false_positives = 0 false_negatives = 0 - + for manual_result in manual_results: pair = (manual_result["situation"], manual_result["style"]) llm_result = llm_dict.get(pair) if llm_result is None: continue - + manual_suitable = manual_result["suitable"] llm_suitable = llm_result["suitable"] - + if manual_suitable == llm_suitable: matched += 1 - + if manual_suitable and llm_suitable: true_positives += 1 elif not manual_suitable and not llm_suitable: @@ -242,30 +241,36 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met false_positives += 1 elif manual_suitable and not llm_suitable: false_negatives += 1 - + accuracy = (matched / total * 100) if total > 0 else 0 - precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0 - recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0 + precision = ( + (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0 + ) + recall = ( + (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0 + ) f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0 - specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0 - + specificity = ( + (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0 + ) + # 计算人工效标的不合适率 manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数 manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0 - + # 计算经过LLM删除后剩余项目中的不合适率 # 在所有项目中,移除LLM判定为不合适的项目后,剩下的项目 = TP + FP(LLM判定为合适的项目) # 在这些剩下的项目中,按人工评定的不合适项目 = FP(人工认为不合适,但LLM认为合适) llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数(保留的项目) llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0 - + # 两者百分比相减(评估LLM评定修正后的不合适率是否有降低) rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate - + random_baseline = 50.0 accuracy_above_random = accuracy - random_baseline accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0 - + return { "method": method_name, "total": total, @@ -283,29 +288,29 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met "specificity": specificity, "manual_unsuitable_rate": manual_unsuitable_rate, "llm_kept_unsuitable_rate": llm_kept_unsuitable_rate, - "rate_difference": rate_difference + "rate_difference": rate_difference, } async def main(count: int | None = None): """ 主函数 - + Args: count: 随机选取的数据条数,如果为None则使用全部数据 """ logger.info("=" * 60) logger.info("开始表达方式LLM评估") logger.info("=" * 60) - + # 1. 加载人工评估结果 print("\n步骤1: 加载人工评估结果") manual_results = load_manual_results() if not manual_results: return - + print(f"成功加载 {len(manual_results)} 条人工评估结果") - + # 如果指定了数量,随机选择指定数量的数据 if count is not None: if count <= 0: @@ -317,7 +322,7 @@ async def main(count: int | None = None): random.seed() # 使用系统时间作为随机种子 manual_results = random.sample(manual_results, count) print(f"随机选取 {len(manual_results)} 条数据进行评估") - + # 验证数据完整性 valid_manual_results = [] for r in manual_results: @@ -325,62 +330,58 @@ async def main(count: int | None = None): valid_manual_results.append(r) else: logger.warning(f"跳过无效数据: {r}") - + if len(valid_manual_results) != len(manual_results): print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过") - + print(f"有效数据: {len(valid_manual_results)} 条") - + # 2. 创建LLM实例并评估 print("\n步骤2: 创建LLM实例") try: - llm = LLMRequest( - model_set=model_config.model_task_config.tool_use, - request_type="expression_evaluator_llm" - ) + llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm") except Exception as e: logger.error(f"创建LLM实例失败: {e}") import traceback + logger.error(traceback.format_exc()) return - + print("\n步骤3: 开始LLM评估") llm_results = [] for i, manual_result in enumerate(valid_manual_results, 1): print(f"LLM评估进度: {i}/{len(valid_manual_results)}") - llm_results.append(await evaluate_expression_llm( - manual_result["situation"], - manual_result["style"], - llm - )) + llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm)) await asyncio.sleep(0.3) - + # 5. 输出FP和FN项目(在评估结果之前) llm_dict = {(r["situation"], r["style"]): r for r in llm_results} - + # 5.1 输出FP项目(人工评估不通过但LLM误判为通过) print("\n" + "=" * 60) print("人工评估不通过但LLM误判为通过的项目(FP - False Positive)") print("=" * 60) - + fp_items = [] for manual_result in valid_manual_results: pair = (manual_result["situation"], manual_result["style"]) llm_result = llm_dict.get(pair) if llm_result is None: continue - + # 人工评估不通过,但LLM评估通过(FP情况) if not manual_result["suitable"] and llm_result["suitable"]: - fp_items.append({ - "situation": manual_result["situation"], - "style": manual_result["style"], - "manual_suitable": manual_result["suitable"], - "llm_suitable": llm_result["suitable"], - "llm_reason": llm_result.get("reason", "未提供理由"), - "llm_error": llm_result.get("error") - }) - + fp_items.append( + { + "situation": manual_result["situation"], + "style": manual_result["style"], + "manual_suitable": manual_result["suitable"], + "llm_suitable": llm_result["suitable"], + "llm_reason": llm_result.get("reason", "未提供理由"), + "llm_error": llm_result.get("error"), + } + ) + if fp_items: print(f"\n共找到 {len(fp_items)} 条误判项目:\n") for idx, item in enumerate(fp_items, 1): @@ -389,36 +390,38 @@ async def main(count: int | None = None): print(f"Style: {item['style']}") print("人工评估: 不通过 ❌") print("LLM评估: 通过 ✅ (误判)") - if item.get('llm_error'): + if item.get("llm_error"): print(f"LLM错误: {item['llm_error']}") print(f"LLM理由: {item['llm_reason']}") print() else: print("\n✓ 没有误判项目(所有人工评估不通过的项目都被LLM正确识别为不通过)") - + # 5.2 输出FN项目(人工评估通过但LLM误判为不通过) print("\n" + "=" * 60) print("人工评估通过但LLM误判为不通过的项目(FN - False Negative)") print("=" * 60) - + fn_items = [] for manual_result in valid_manual_results: pair = (manual_result["situation"], manual_result["style"]) llm_result = llm_dict.get(pair) if llm_result is None: continue - + # 人工评估通过,但LLM评估不通过(FN情况) if manual_result["suitable"] and not llm_result["suitable"]: - fn_items.append({ - "situation": manual_result["situation"], - "style": manual_result["style"], - "manual_suitable": manual_result["suitable"], - "llm_suitable": llm_result["suitable"], - "llm_reason": llm_result.get("reason", "未提供理由"), - "llm_error": llm_result.get("error") - }) - + fn_items.append( + { + "situation": manual_result["situation"], + "style": manual_result["style"], + "manual_suitable": manual_result["suitable"], + "llm_suitable": llm_result["suitable"], + "llm_reason": llm_result.get("reason", "未提供理由"), + "llm_error": llm_result.get("error"), + } + ) + if fn_items: print(f"\n共找到 {len(fn_items)} 条误删项目:\n") for idx, item in enumerate(fn_items, 1): @@ -427,33 +430,41 @@ async def main(count: int | None = None): print(f"Style: {item['style']}") print("人工评估: 通过 ✅") print("LLM评估: 不通过 ❌ (误删)") - if item.get('llm_error'): + if item.get("llm_error"): print(f"LLM错误: {item['llm_error']}") print(f"LLM理由: {item['llm_reason']}") print() else: print("\n✓ 没有误删项目(所有人工评估通过的项目都被LLM正确识别为通过)") - + # 6. 对比分析并输出结果 comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估") - + print("\n" + "=" * 60) print("评估结果(以人工评估为标准)") print("=" * 60) - + # 详细评估结果(核心指标优先) print(f"\n--- {comparison['method']} ---") print(f" 总数: {comparison['total']} 条") print() # print(" 【核心能力指标】") print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)") - print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})") - print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个") + print( + f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})" + ) + print( + f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个" + ) # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)") print() print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)") - print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})") - print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个") + print( + f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})" + ) + print( + f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个" + ) # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)") print() print(" 【其他指标】") @@ -464,12 +475,18 @@ async def main(count: int | None = None): print() print(" 【不合适率分析】") print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%") - print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}") + print( + f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}" + ) print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适") print() print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%") - print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})") - print(f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%") + print( + f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})" + ) + print( + f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%" + ) print() # print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%") # print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%") @@ -480,21 +497,22 @@ async def main(count: int | None = None): print(f" TN (正确识别为不合适): {comparison['true_negatives']} ⭐") print(f" FP (误判为合适): {comparison['false_positives']} ⚠️") print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️") - + # 7. 保存结果到JSON文件 output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json") try: os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, "w", encoding="utf-8") as f: - json.dump({ - "manual_results": valid_manual_results, - "llm_results": llm_results, - "comparison": comparison - }, f, ensure_ascii=False, indent=2) + json.dump( + {"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison}, + f, + ensure_ascii=False, + indent=2, + ) logger.info(f"\n评估结果已保存到: {output_file}") except Exception as e: logger.warning(f"保存结果到文件失败: {e}") - + print("\n" + "=" * 60) print("评估完成") print("=" * 60) @@ -509,15 +527,9 @@ if __name__ == "__main__": python evaluate_expressions_llm_v6.py # 使用全部数据 python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据 python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据 - """ + """, ) - parser.add_argument( - "-n", "--count", - type=int, - default=None, - help="随机选取的数据条数(默认:使用全部数据)" - ) - + parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)") + args = parser.parse_args() asyncio.run(main(count=args.count)) - diff --git a/scripts/evaluate_expressions_manual.py b/scripts/evaluate_expressions_manual.py index 8221112b..ec139c0b 100644 --- a/scripts/evaluate_expressions_manual.py +++ b/scripts/evaluate_expressions_manual.py @@ -18,9 +18,9 @@ from datetime import datetime project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, project_root) -from src.common.database.database_model import Expression -from src.common.database.database import db -from src.common.logger import get_logger +from src.common.database.database_model import Expression # noqa: E402 +from src.common.database.database import db # noqa: E402 +from src.common.logger import get_logger # noqa: E402 logger = get_logger("expression_evaluator_manual") @@ -32,13 +32,13 @@ MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json") def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]: """ 加载已有的评估结果 - + Returns: (已有结果列表, 已评估的项目(situation, style)元组集合) """ if not os.path.exists(MANUAL_EVAL_FILE): return [], set() - + try: with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f: data = json.load(f) @@ -55,22 +55,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]: def save_results(manual_results: List[Dict]): """ 保存评估结果到文件 - + Args: manual_results: 评估结果列表 """ try: os.makedirs(TEMP_DIR, exist_ok=True) - + data = { "last_updated": datetime.now().isoformat(), "total_count": len(manual_results), - "manual_results": manual_results + "manual_results": manual_results, } - + with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) - + logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}") print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)") except Exception as e: @@ -81,45 +81,43 @@ def save_results(manual_results: List[Dict]): def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]: """ 获取未评估的表达方式 - + Args: evaluated_pairs: 已评估的项目(situation, style)元组集合 batch_size: 每次获取的数量 - + Returns: 未评估的表达方式列表 """ try: # 查询所有表达方式 all_expressions = list(Expression.select()) - + if not all_expressions: logger.warning("数据库中没有表达方式记录") return [] - + # 过滤出未评估的项目:匹配 situation 和 style 均一致 - unevaluated = [ - expr for expr in all_expressions - if (expr.situation, expr.style) not in evaluated_pairs - ] - + unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs] + if not unevaluated: logger.info("所有项目都已评估完成") return [] - + # 如果未评估数量少于请求数量,返回所有 if len(unevaluated) <= batch_size: logger.info(f"剩余 {len(unevaluated)} 条未评估项目,全部返回") return unevaluated - + # 随机选择指定数量 selected = random.sample(unevaluated, batch_size) logger.info(f"从 {len(unevaluated)} 条未评估项目中随机选择了 {len(selected)} 条") return selected - + except Exception as e: logger.error(f"获取未评估表达方式失败: {e}") import traceback + logger.error(traceback.format_exc()) return [] @@ -127,12 +125,12 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict: """ 人工评估单个表达方式 - + Args: expression: 表达方式对象 index: 当前索引(从1开始) total: 总数 - + Returns: 评估结果字典,如果用户退出则返回 None """ @@ -146,38 +144,38 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) - print(" 输入 'n' 或 'no' 或 '0' 表示不合适(不通过)") print(" 输入 'q' 或 'quit' 退出评估") print(" 输入 's' 或 'skip' 跳过当前项目") - + while True: user_input = input("\n您的评估 (y/n/q/s): ").strip().lower() - - if user_input in ['q', 'quit']: + + if user_input in ["q", "quit"]: print("退出评估") return None - - if user_input in ['s', 'skip']: + + if user_input in ["s", "skip"]: print("跳过当前项目") return "skip" - - if user_input in ['y', 'yes', '1', '是', '通过']: + + if user_input in ["y", "yes", "1", "是", "通过"]: suitable = True break - elif user_input in ['n', 'no', '0', '否', '不通过']: + elif user_input in ["n", "no", "0", "否", "不通过"]: suitable = False break else: print("输入无效,请重新输入 (y/n/q/s)") - + result = { "situation": expression.situation, "style": expression.style, "suitable": suitable, "reason": None, "evaluator": "manual", - "evaluated_at": datetime.now().isoformat() + "evaluated_at": datetime.now().isoformat(), } - + print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}") - + return result @@ -186,7 +184,7 @@ def main(): logger.info("=" * 60) logger.info("开始表达方式人工评估") logger.info("=" * 60) - + # 初始化数据库连接 try: db.connect(reuse_if_open=True) @@ -194,41 +192,41 @@ def main(): except Exception as e: logger.error(f"数据库连接失败: {e}") return - + # 加载已有评估结果 existing_results, evaluated_pairs = load_existing_results() manual_results = existing_results.copy() - + if evaluated_pairs: print(f"\n已加载 {len(existing_results)} 条已有评估结果") print(f"已评估项目数: {len(evaluated_pairs)}") - + print("\n" + "=" * 60) print("开始人工评估") print("=" * 60) print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目") print("评估结果会自动保存到文件\n") - + batch_size = 10 batch_count = 0 - + while True: # 获取未评估的项目 expressions = get_unevaluated_expressions(evaluated_pairs, batch_size) - + if not expressions: print("\n" + "=" * 60) print("所有项目都已评估完成!") print("=" * 60) break - + batch_count += 1 print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---") - + batch_results = [] for i, expression in enumerate(expressions, 1): manual_result = manual_evaluate_expression(expression, i, len(expressions)) - + if manual_result is None: # 用户退出 print("\n评估已中断") @@ -237,34 +235,34 @@ def main(): manual_results.extend(batch_results) save_results(manual_results) return - + if manual_result == "skip": # 跳过当前项目 continue - + batch_results.append(manual_result) # 使用 (situation, style) 作为唯一标识 evaluated_pairs.add((manual_result["situation"], manual_result["style"])) - + # 将当前批次结果添加到总结果中 manual_results.extend(batch_results) - + # 保存结果 save_results(manual_results) - + print(f"\n当前批次完成,已评估总数: {len(manual_results)} 条") - + # 询问是否继续 while True: continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower() - if continue_input in ['y', 'yes', '1', '是', '继续']: + if continue_input in ["y", "yes", "1", "是", "继续"]: break - elif continue_input in ['n', 'no', '0', '否', '退出']: + elif continue_input in ["n", "no", "0", "否", "退出"]: print("\n评估结束") return else: print("输入无效,请重新输入 (y/n)") - + # 关闭数据库连接 try: db.close() @@ -275,4 +273,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 4057a52c..b7ec7442 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -134,9 +134,7 @@ def handle_import_openie( # 在非交互模式下,不再询问用户,而是直接报错终止 logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") if non_interactive: - logger.error( - "检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。" - ) + logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。") sys.exit(1) logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") user_choice = input().strip().lower() @@ -189,9 +187,7 @@ def handle_import_openie( async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension # 新增确认提示 if non_interactive: - logger.warning( - "当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。" - ) + logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。") else: print("=== 重要操作确认 ===") print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") @@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d def main(argv: Optional[list[str]] = None) -> None: """主函数 - 解析参数并运行异步主流程。""" parser = argparse.ArgumentParser( - description=( - "OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次," - "将其导入到 LPMM 的向量库与知识图中。" - ) + description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。") ) parser.add_argument( "--non-interactive", diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index b5068b0f..bd52536e 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension ensure_dirs() # 确保目录存在 # 新增用户确认提示 if non_interactive: - logger.warning( - "当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。" - ) + logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。") else: print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") diff --git a/scripts/inspect_lpmm_global.py b/scripts/inspect_lpmm_global.py index 13b80e14..eb53259f 100644 --- a/scripts/inspect_lpmm_global.py +++ b/scripts/inspect_lpmm_global.py @@ -1,6 +1,5 @@ import os import sys -from typing import Set # 保证可以导入 src.* sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -32,7 +31,6 @@ def main() -> None: # KG 统计 nodes = kg.graph.get_node_list() edges = kg.graph.get_edge_list() - node_set: Set[str] = set(nodes) para_nodes = [n for n in nodes if n.startswith("paragraph-")] ent_nodes = [n for n in nodes if n.startswith("entity-")] @@ -68,4 +66,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/scripts/lpmm_interactive_manager.py b/scripts/lpmm_interactive_manager.py index 2d932d00..1dac5a3e 100644 --- a/scripts/lpmm_interactive_manager.py +++ b/scripts/lpmm_interactive_manager.py @@ -29,6 +29,7 @@ except ImportError as e: logger = get_logger("lpmm_interactive_manager") + async def interactive_add(): """交互式导入知识""" print("\n" + "=" * 40) @@ -38,7 +39,7 @@ async def interactive_add(): print(" - 支持多段落,段落间请保留空行。") print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。") print("-" * 40) - + lines = [] while True: try: @@ -48,7 +49,7 @@ async def interactive_add(): lines.append(line) except EOFError: break - + text = "\n".join(lines).strip() if not text: print("\n[!] 内容为空,操作已取消。") @@ -58,7 +59,7 @@ async def interactive_add(): try: # 使用 lpmm_ops.py 中的接口 result = await lpmm_ops.add_content(text) - + if result["status"] == "success": print(f"\n[√] 成功:{result['message']}") print(f" 实际新增段落数: {result.get('count', 0)}") @@ -68,6 +69,7 @@ async def interactive_add(): print(f"\n[×] 发生异常: {e}") logger.error(f"add_content 异常: {e}", exc_info=True) + async def interactive_delete(): """交互式删除知识""" print("\n" + "=" * 40) @@ -77,10 +79,10 @@ async def interactive_delete(): print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)") print(" 2. 完整文段匹配(删除完全匹配的段落)") print("-" * 40) - + mode = input("请选择删除模式 (1/2): ").strip() exact_match = False - + if mode == "2": exact_match = True print("\n[完整文段匹配模式]") @@ -102,14 +104,18 @@ async def interactive_delete(): print("\n[!] 无效选择,默认使用关键词模糊匹配模式。") print("\n[关键词模糊匹配模式]") keyword = input("请输入匹配关键词: ").strip() - + if not keyword: print("\n[!] 输入为空,操作已取消。") return - + print("-" * 40) - confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower() - if confirm != 'y': + confirm = ( + input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ") + .strip() + .lower() + ) + if confirm != "y": print("\n[!] 已取消删除操作。") return @@ -117,7 +123,7 @@ async def interactive_delete(): try: # 使用 lpmm_ops.py 中的接口 result = await lpmm_ops.delete(keyword, exact_match=exact_match) - + if result["status"] == "success": print(f"\n[√] 成功:{result['message']}") print(f" 删除条数: {result.get('deleted_count', 0)}") @@ -129,6 +135,7 @@ async def interactive_delete(): print(f"\n[×] 发生异常: {e}") logger.error(f"delete 异常: {e}", exc_info=True) + async def interactive_clear(): """交互式清空知识库""" print("\n" + "=" * 40) @@ -141,40 +148,45 @@ async def interactive_clear(): print(" - 整个知识图谱") print(" - 此操作不可恢复!") print("-" * 40) - + # 双重确认 confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip() if confirm1 != "YES": print("\n[!] 已取消清空操作。") return - + print("\n" + "=" * 40) confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip() if confirm2 != "CLEAR": print("\n[!] 已取消清空操作。") return - + print("\n[进度] 正在清空知识库...") try: # 使用 lpmm_ops.py 中的接口 result = await lpmm_ops.clear_all() - + if result["status"] == "success": print(f"\n[√] 成功:{result['message']}") stats = result.get("stats", {}) before = stats.get("before", {}) after = stats.get("after", {}) print("\n[统计信息]") - print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, " - f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}") - print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, " - f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}") + print( + f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, " + f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}" + ) + print( + f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, " + f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}" + ) else: print(f"\n[×] 失败:{result['message']}") except Exception as e: print(f"\n[×] 发生异常: {e}") logger.error(f"clear_all 异常: {e}", exc_info=True) + async def interactive_search(): """交互式查询知识""" print("\n" + "=" * 40) @@ -182,25 +194,25 @@ async def interactive_search(): print("=" * 40) print("说明:输入查询问题或关键词,系统会返回相关的知识段落。") print("-" * 40) - + # 确保 LPMM 已初始化 if not global_config.lpmm_knowledge.enable: print("\n[!] 警告:LPMM 知识库在配置中未启用。") return - + try: lpmm_start_up() except Exception as e: print(f"\n[!] LPMM 初始化失败: {e}") logger.error(f"LPMM 初始化失败: {e}", exc_info=True) return - + query = input("请输入查询问题或关键词: ").strip() - + if not query: print("\n[!] 查询内容为空,操作已取消。") return - + # 询问返回条数 print("-" * 40) limit_str = input("希望返回的相关知识条数(默认3,直接回车使用默认值): ").strip() @@ -210,11 +222,11 @@ async def interactive_search(): except ValueError: limit = 3 print("[!] 输入无效,使用默认值 3。") - + print("\n[进度] 正在查询知识库...") try: result = await query_lpmm_knowledge(query, limit=limit) - + print("\n" + "=" * 60) print("[查询结果]") print("=" * 60) @@ -224,6 +236,7 @@ async def interactive_search(): print(f"\n[×] 查询失败: {e}") logger.error(f"查询异常: {e}", exc_info=True) + async def main(): """主循环""" while True: @@ -236,9 +249,9 @@ async def main(): print("║ 4. 清空知识库 (Clear All) ⚠️ ║") print("║ 0. 退出 (Exit) ║") print("╚" + "═" * 38 + "╝") - + choice = input("请选择操作编号: ").strip() - + if choice == "1": await interactive_add() elif choice == "2": @@ -253,6 +266,7 @@ async def main(): else: print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。") + if __name__ == "__main__": try: # 运行主循环 @@ -262,4 +276,3 @@ if __name__ == "__main__": except Exception as e: print(f"\n[!] 程序运行出错: {e}") logger.error(f"Main loop 异常: {e}", exc_info=True) - diff --git a/scripts/lpmm_manager.py b/scripts/lpmm_manager.py index 60f32364..2f935c51 100644 --- a/scripts/lpmm_manager.py +++ b/scripts/lpmm_manager.py @@ -21,18 +21,18 @@ PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "..")) if PROJECT_ROOT not in sys.path: sys.path.append(PROJECT_ROOT) -from src.common.logger import get_logger # type: ignore -from src.config.config import global_config, model_config # type: ignore +from src.common.logger import get_logger # type: ignore # noqa: E402 +from src.config.config import global_config, model_config # type: ignore # noqa: E402 # 引入各功能脚本的入口函数 -from import_openie import main as import_openie_main # type: ignore -from info_extraction import main as info_extraction_main # type: ignore -from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore -from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore -from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore -from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore -from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore -from raw_data_preprocessor import load_raw_data # type: ignore +from import_openie import main as import_openie_main # type: ignore # noqa: E402 +from info_extraction import main as info_extraction_main # type: ignore # noqa: E402 +from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore # noqa: E402 +from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore # noqa: E402 +from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore # noqa: E402 +from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore # noqa: E402 +from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore # noqa: E402 +from raw_data_preprocessor import load_raw_data # type: ignore # noqa: E402 logger = get_logger("lpmm_manager") @@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool: raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data" txt_files = list(raw_dir.glob("*.txt")) if not txt_files: - msg = ( - f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件," - "info_extraction 可能立即退出或无数据可处理。" - ) + msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,info_extraction 可能立即退出或无数据可处理。" print(msg) if non_interactive: - logger.error( - "非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。" - ) + logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。") return False cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower() return cont == "y" @@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool: openie_dir = Path(PROJECT_ROOT) / "data" / "openie" json_files = list(openie_dir.glob("*.json")) if not json_files: - msg = ( - f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件," - "import_openie 可能会因为找不到批次而失败。" - ) + msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,import_openie 可能会因为找不到批次而失败。" print(msg) if non_interactive: - logger.error( - "非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。" - ) + logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。") return False cont = input("仍然继续执行导入吗?(y/n): ").strip().lower() return cont == "y" @@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None: """在部分操作前提醒 lpmm_knowledge.enable 状态。""" try: if not getattr(global_config.lpmm_knowledge, "enable", False): - print( - "[WARN] 当前配置 lpmm_knowledge.enable = false," - "刷新或检索测试可能无法在聊天侧真正启用 LPMM。" - ) + print("[WARN] 当前配置 lpmm_knowledge.enable = false,刷新或检索测试可能无法在聊天侧真正启用 LPMM。") except Exception: # 配置异常时不阻断主流程,仅忽略提示 pass @@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None: if action == "prepare_raw": logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...") sha_list, raw_data = load_raw_data() - print( - f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落," - f"去重后哈希数 {len(sha_list)}。" - ) + print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。") elif action == "info_extract": if not _check_before_info_extract("--non-interactive" in extra_args): print("已根据用户选择,取消执行信息提取。") @@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None: # 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新 logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新") sha_list, raw_data = load_raw_data() - print( - f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落," - f"去重后哈希数 {len(sha_list)}。" - ) + print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。") non_interactive = "--non-interactive" in extra_args if not _check_before_info_extract(non_interactive): print("已根据用户选择,取消 full_import(信息提取阶段被取消)。") @@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]: ) # 快速选项:按推荐方式清理所有相关实体/关系 - quick_all = input( - "是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): " - ).strip().lower() + quick_all = ( + input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower() + ) if quick_all in ("", "y", "yes"): args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"]) else: @@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]: def _interactive_build_batch_inspect_args() -> List[str]: """为 inspect_lpmm_batch 构造 --openie-file 参数。""" - path = _interactive_choose_openie_file( - "请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):" - ) + path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):") if not path: return [] return ["--openie-file", path] @@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]: def _interactive_build_test_args() -> List[str]: """为 test_lpmm_retrieval 构造自定义测试用例参数。""" - print( - "\n[TEST] 你可以:\n" - "- 直接回车使用内置的默认测试用例;\n" - "- 或者输入一条自定义问题,并指定期望命中的关键字。" - ) + print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。") query = input("请输入自定义测试问题(回车则使用默认用例):").strip() if not query: return [] @@ -422,9 +397,7 @@ def _run_embedding_helper() -> None: print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}") print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}") - new_dim = input( - "\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):" - ).strip() + new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip() if new_dim and not new_dim.isdigit(): print("输入的维度不是纯数字,已取消操作。") return @@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None: if __name__ == "__main__": main() - - diff --git a/scripts/test_memory_retrieval.py b/scripts/test_memory_retrieval.py index 5348bdc4..7519a306 100644 --- a/scripts/test_memory_retrieval.py +++ b/scripts/test_memory_retrieval.py @@ -28,53 +28,55 @@ from maim_message import UserInfo, GroupInfo logger = get_logger("test_memory_retrieval") + # 使用 importlib 动态导入,避免循环导入问题 def _import_memory_retrieval(): """使用 importlib 动态导入 memory_retrieval 模块,避免循环导入""" try: # 先导入 prompt_builder,检查 prompt 是否已经初始化 from src.chat.utils.prompt_builder import global_prompt_manager - + # 检查 memory_retrieval 相关的 prompt 是否已经注册 # 如果已经注册,说明模块可能已经通过其他路径初始化过了 prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts - + module_name = "src.memory_system.memory_retrieval" - + # 如果 prompt 已经初始化,尝试直接使用已加载的模块 if prompt_already_init and module_name in sys.modules: existing_module = sys.modules[module_name] - if hasattr(existing_module, 'init_memory_retrieval_prompt'): + if hasattr(existing_module, "init_memory_retrieval_prompt"): return ( existing_module.init_memory_retrieval_prompt, existing_module._react_agent_solve_question, existing_module._process_single_question, ) - + # 如果模块已经在 sys.modules 中但部分初始化,先移除它 if module_name in sys.modules: existing_module = sys.modules[module_name] - if not hasattr(existing_module, 'init_memory_retrieval_prompt'): + if not hasattr(existing_module, "init_memory_retrieval_prompt"): # 模块部分初始化,移除它 logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入") del sys.modules[module_name] # 清理可能相关的部分初始化模块 keys_to_remove = [] for key in sys.modules.keys(): - if key.startswith('src.memory_system.') and key != 'src.memory_system': + if key.startswith("src.memory_system.") and key != "src.memory_system": keys_to_remove.append(key) for key in keys_to_remove: try: del sys.modules[key] except KeyError: pass - + # 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载 # 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们 try: # 先导入可能触发循环导入的模块,让它们完成初始化 import src.config.config import src.chat.utils.prompt_builder + # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval) # 如果它们已经导入,就确保它们完全初始化 # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval) @@ -89,11 +91,11 @@ def _import_memory_retrieval(): pass # 如果导入失败,继续 except Exception as e: logger.warning(f"预加载依赖模块时出现警告: {e}") - + # 现在尝试导入 memory_retrieval # 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval memory_retrieval_module = importlib.import_module(module_name) - + return ( memory_retrieval_module.init_memory_retrieval_prompt, memory_retrieval_module._react_agent_solve_question, @@ -126,16 +128,16 @@ def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStrea def get_token_usage_since(start_time: float) -> Dict[str, Any]: """获取从指定时间开始的token使用情况 - + Args: start_time: 开始时间戳 - + Returns: 包含token使用统计的字典 """ try: start_datetime = datetime.fromtimestamp(start_time) - + # 查询从开始时间到现在的所有memory相关的token使用记录 records = ( LLMUsage.select() @@ -150,21 +152,21 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]: ) .order_by(LLMUsage.timestamp.asc()) ) - + total_prompt_tokens = 0 total_completion_tokens = 0 total_tokens = 0 total_cost = 0.0 request_count = 0 model_usage = {} # 按模型统计 - + for record in records: total_prompt_tokens += record.prompt_tokens or 0 total_completion_tokens += record.completion_tokens or 0 total_tokens += record.total_tokens or 0 total_cost += record.cost or 0.0 request_count += 1 - + # 按模型统计 model_name = record.model_name or "unknown" if model_name not in model_usage: @@ -180,7 +182,7 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]: model_usage[model_name]["total_tokens"] += record.total_tokens or 0 model_usage[model_name]["cost"] += record.cost or 0.0 model_usage[model_name]["request_count"] += 1 - + return { "total_prompt_tokens": total_prompt_tokens, "total_completion_tokens": total_completion_tokens, @@ -205,25 +207,25 @@ def format_thinking_steps(thinking_steps: list) -> str: """格式化思考步骤为可读字符串""" if not thinking_steps: return "无思考步骤" - + lines = [] for step in thinking_steps: iteration = step.get("iteration", "?") thought = step.get("thought", "") actions = step.get("actions", []) observations = step.get("observations", []) - + lines.append(f"\n--- 迭代 {iteration} ---") if thought: lines.append(f"思考: {thought[:200]}...") - + if actions: lines.append("行动:") for action in actions: action_type = action.get("action_type", "unknown") action_params = action.get("action_params", {}) lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}") - + if observations: lines.append("观察:") for obs in observations: @@ -231,7 +233,7 @@ def format_thinking_steps(thinking_steps: list) -> str: if len(str(obs)) > 200: obs_str += "..." lines.append(f" - {obs_str}") - + return "\n".join(lines) @@ -242,31 +244,32 @@ async def test_memory_retrieval( max_iterations: Optional[int] = None, ) -> Dict[str, Any]: """测试记忆检索功能 - + Args: question: 要查询的问题 chat_id: 聊天ID context: 上下文信息 max_iterations: 最大迭代次数 - + Returns: 包含测试结果的字典 """ print("\n" + "=" * 80) - print(f"[测试] 记忆检索测试") + print("[测试] 记忆检索测试") print(f"[问题] {question}") print("=" * 80) - + # 记录开始时间 start_time = time.time() - + # 延迟导入并初始化记忆检索prompt(这会自动加载 global_config) # 注意:必须在函数内部调用,避免在模块级别触发循环导入 try: init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval() - + # 检查 prompt 是否已经初始化,避免重复初始化 from src.chat.utils.prompt_builder import global_prompt_manager + if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts: init_memory_retrieval_prompt() else: @@ -274,24 +277,24 @@ async def test_memory_retrieval( except Exception as e: logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True) raise - + # 获取 global_config(此时应该已经加载) from src.config.config import global_config - + # 直接调用 _react_agent_solve_question 来获取详细的迭代信息 if max_iterations is None: max_iterations = global_config.memory.max_agent_iterations - + timeout = global_config.memory.agent_timeout_seconds - - print(f"\n[配置]") + + print("\n[配置]") print(f" 最大迭代次数: {max_iterations}") print(f" 超时时间: {timeout}秒") print(f" 聊天ID: {chat_id}") - + # 执行检索 print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}") - + found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( question=question, chat_id=chat_id, @@ -299,14 +302,14 @@ async def test_memory_retrieval( timeout=timeout, initial_info="", ) - + # 记录结束时间 end_time = time.time() elapsed_time = end_time - start_time - + # 获取token使用情况 token_usage = get_token_usage_since(start_time) - + # 构建结果 result = { "question": question, @@ -318,41 +321,41 @@ async def test_memory_retrieval( "iteration_count": len(thinking_steps), "token_usage": token_usage, } - + # 输出结果 print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}") - print(f"\n[结果]") + print("\n[结果]") print(f" 是否找到答案: {'是' if found_answer else '否'}") if found_answer and answer: print(f" 答案: {answer}") else: - print(f" 答案: (未找到答案)") + print(" 答案: (未找到答案)") print(f" 是否超时: {'是' if is_timeout else '否'}") print(f" 迭代次数: {len(thinking_steps)}") print(f" 总耗时: {elapsed_time:.2f}秒") - - print(f"\n[Token使用情况]") + + print("\n[Token使用情况]") print(f" 总请求数: {token_usage['request_count']}") print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}") print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}") print(f" 总Tokens: {token_usage['total_tokens']:,}") print(f" 总成本: ${token_usage['total_cost']:.6f}") - - if token_usage['model_usage']: - print(f"\n[按模型统计]") - for model_name, usage in token_usage['model_usage'].items(): + + if token_usage["model_usage"]: + print("\n[按模型统计]") + for model_name, usage in token_usage["model_usage"].items(): print(f" {model_name}:") print(f" 请求数: {usage['request_count']}") print(f" Prompt Tokens: {usage['prompt_tokens']:,}") print(f" Completion Tokens: {usage['completion_tokens']:,}") print(f" 总Tokens: {usage['total_tokens']:,}") print(f" 成本: ${usage['cost']:.6f}") - - print(f"\n[迭代详情]") + + print("\n[迭代详情]") print(format_thinking_steps(thinking_steps)) - + print("\n" + "=" * 80) - + return result @@ -375,12 +378,12 @@ def main() -> None: "-o", help="将结果保存到JSON文件(可选)", ) - + args = parser.parse_args() - + # 初始化日志(使用较低的详细程度,避免输出过多日志) initialize_logging(verbose=False) - + # 交互式输入问题 print("\n" + "=" * 80) print("记忆检索测试工具") @@ -389,7 +392,7 @@ def main() -> None: if not question: print("错误: 问题不能为空") return - + # 交互式输入最大迭代次数 max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip() max_iterations = None @@ -402,7 +405,7 @@ def main() -> None: except ValueError: print("警告: 无效的迭代次数,将使用配置默认值") max_iterations = None - + # 连接数据库 try: db.connect(reuse_if_open=True) @@ -410,7 +413,7 @@ def main() -> None: logger.error(f"数据库连接失败: {e}") print(f"错误: 数据库连接失败: {e}") return - + # 运行测试 try: result = asyncio.run( @@ -421,7 +424,7 @@ def main() -> None: max_iterations=max_iterations, ) ) - + # 如果指定了输出文件,保存结果 if args.output: # 将thinking_steps转换为可序列化的格式 @@ -429,7 +432,7 @@ def main() -> None: with open(args.output, "w", encoding="utf-8") as f: json.dump(output_result, f, ensure_ascii=False, indent=2) print(f"\n[结果已保存] {args.output}") - + except KeyboardInterrupt: print("\n\n[中断] 用户中断测试") except Exception as e: @@ -444,4 +447,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/bw_learner/expression_selector.py b/src/bw_learner/expression_selector.py index 50a7e627..4f77ad74 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/bw_learner/expression_selector.py @@ -455,6 +455,7 @@ class ExpressionSelector: expr_obj.save() logger.debug("表达方式激活: 更新last_active_time in db") + try: expression_selector = ExpressionSelector() except Exception as e: diff --git a/src/bw_learner/jargon_explainer.py b/src/bw_learner/jargon_explainer.py index df68de15..a99ecd84 100644 --- a/src/bw_learner/jargon_explainer.py +++ b/src/bw_learner/jargon_explainer.py @@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import ( logger = get_logger("jargon") + class JargonExplainer: """黑话解释器,用于在回复前识别和解释上下文中的黑话""" diff --git a/src/bw_learner/learner_utils.py b/src/bw_learner/learner_utils.py index fabf555e..ce3ea379 100644 --- a/src/bw_learner/learner_utils.py +++ b/src/bw_learner/learner_utils.py @@ -60,31 +60,31 @@ def calculate_style_similarity(style1: str, style2: str) -> float: """ 计算两个 style 的相似度,返回0-1之间的值 在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py) - + Args: style1: 第一个 style style2: 第二个 style - + Returns: float: 相似度值,范围0-1 """ if not style1 or not style2: return 0.0 - + # 移除"使用"和"句式"这两个词 def remove_ignored_words(text: str) -> str: """移除需要忽略的词""" text = text.replace("使用", "") text = text.replace("句式", "") return text.strip() - + cleaned_style1 = remove_ignored_words(style1) cleaned_style2 = remove_ignored_words(style2) - + # 如果清理后文本为空,返回0 if not cleaned_style1 or not cleaned_style2: return 0.0 - + return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio() @@ -495,4 +495,4 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]] if content and source_id: jargon_entries.append((content, source_id)) - return expressions, jargon_entries \ No newline at end of file + return expressions, jargon_entries diff --git a/src/bw_learner/message_recorder.py b/src/bw_learner/message_recorder.py index 8e15ab43..39be834f 100644 --- a/src/bw_learner/message_recorder.py +++ b/src/bw_learner/message_recorder.py @@ -2,7 +2,6 @@ import time import asyncio from typing import List, Any from src.common.logger import get_logger -from src.config.config import global_config from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.common_utils import TempMethodsExpression @@ -119,9 +118,7 @@ class MessageRecorder: # 触发 expression_learner 和 jargon_miner 的处理 if self.enable_expression_learning: - asyncio.create_task( - self._trigger_expression_learning(messages) - ) + asyncio.create_task(self._trigger_expression_learning(messages)) except Exception as e: logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") @@ -130,9 +127,7 @@ class MessageRecorder: traceback.print_exc() # 即使失败也保持时间戳更新,避免频繁重试 - async def _trigger_expression_learning( - self, messages: List[Any] - ) -> None: + async def _trigger_expression_learning(self, messages: List[Any]) -> None: """ 触发 expression 学习,使用指定的消息列表 diff --git a/src/chat/brain_chat/PFC/action_planner.py b/src/chat/brain_chat/PFC/action_planner.py index 13b76f8b..7c8bbd79 100644 --- a/src/chat/brain_chat/PFC/action_planner.py +++ b/src/chat/brain_chat/PFC/action_planner.py @@ -1,5 +1,5 @@ import time -from typing import Tuple, Optional, Dict, Any # 增加了 Optional +from typing import Tuple, Optional # 增加了 Optional from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config @@ -120,7 +120,7 @@ class ActionPlanner: def _get_personality_prompt(self) -> str: """获取个性提示信息""" prompt_personality = global_config.personality.personality - + # 检查是否需要随机替换为状态 if ( global_config.personality.states @@ -128,7 +128,7 @@ class ActionPlanner: and random.random() < global_config.personality.state_probability ): prompt_personality = random.choice(global_config.personality.states) - + bot_name = global_config.bot.nickname return f"你的名字是{bot_name},你{prompt_personality};" @@ -170,13 +170,10 @@ class ActionPlanner: ) break else: - logger.debug( - f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。" - ) + logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。") except Exception as e: logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}") - # --- 获取超时提示信息 --- # (这部分逻辑不变) timeout_context = "" diff --git a/src/chat/brain_chat/PFC/conversation.py b/src/chat/brain_chat/PFC/conversation.py index c0ddd285..4fe2f168 100644 --- a/src/chat/brain_chat/PFC/conversation.py +++ b/src/chat/brain_chat/PFC/conversation.py @@ -112,10 +112,10 @@ class Conversation: "user_nickname": msg.user_info.user_nickname if msg.user_info else "", "user_cardname": msg.user_info.user_cardname if msg.user_info else None, "platform": msg.user_info.platform if msg.user_info else "", - } + }, } initial_messages_dict.append(msg_dict) - + # 将加载的消息填充到 ObservationInfo 的 chat_history self.observation_info.chat_history = initial_messages_dict self.observation_info.chat_history_str = chat_talking_prompt + "\n" diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py index 8febe763..a35576cc 100644 --- a/src/chat/brain_chat/PFC/message_sender.py +++ b/src/chat/brain_chat/PFC/message_sender.py @@ -66,9 +66,9 @@ class DirectMessageSender: # 发送消息(直接调用底层 API) from src.chat.message_receive.uni_message_sender import _send_message - + sent = await _send_message(message, show_log=True) - + if sent: # 存储消息 await self.storage.store_message(message, chat_stream) diff --git a/src/chat/brain_chat/PFC/observation_info.py b/src/chat/brain_chat/PFC/observation_info.py index 296505a5..d04d1c0a 100644 --- a/src/chat/brain_chat/PFC/observation_info.py +++ b/src/chat/brain_chat/PFC/observation_info.py @@ -5,7 +5,7 @@ from src.common.logger import get_logger from .chat_observer import ChatObserver from .chat_states import NotificationHandler, NotificationType, Notification from src.chat.utils.chat_message_builder import build_readable_messages -from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo +from src.common.data_models.database_data_model import DatabaseMessages import traceback # 导入 traceback 用于调试 logger = get_logger("observation_info") @@ -13,15 +13,15 @@ logger = get_logger("observation_info") def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages: """Convert PFC dict format to DatabaseMessages object - + Args: msg_dict: Message in PFC dict format with nested user_info - + Returns: DatabaseMessages object compatible with build_readable_messages() """ user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {}) - + return DatabaseMessages( message_id=msg_dict.get("message_id", ""), time=msg_dict.get("time", 0.0), diff --git a/src/chat/brain_chat/PFC/pfc.py b/src/chat/brain_chat/PFC/pfc.py index c618df93..7de6589b 100644 --- a/src/chat/brain_chat/PFC/pfc.py +++ b/src/chat/brain_chat/PFC/pfc.py @@ -42,9 +42,7 @@ class GoalAnalyzer: """对话目标分析器""" def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest( - model_set=model_config.model_task_config.planner, request_type="conversation_goal" - ) + self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal") self.personality_info = self._get_personality_prompt() self.name = global_config.bot.nickname @@ -60,7 +58,7 @@ class GoalAnalyzer: def _get_personality_prompt(self) -> str: """获取个性提示信息""" prompt_personality = global_config.personality.personality - + # 检查是否需要随机替换为状态 if ( global_config.personality.states @@ -68,7 +66,7 @@ class GoalAnalyzer: and random.random() < global_config.personality.state_probability ): prompt_personality = random.choice(global_config.personality.states) - + bot_name = global_config.bot.nickname return f"你的名字是{bot_name},你{prompt_personality};" diff --git a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py index afdc6943..67509bd5 100644 --- a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py +++ b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py @@ -1,13 +1,11 @@ from typing import List, Tuple, Dict, Any from src.common.logger import get_logger + # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned # from src.plugins.memory_system.Hippocampus import HippocampusManager from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.chat.message_receive.message import Message +from src.config.config import model_config from src.chat.knowledge import qa_manager -from src.chat.utils.chat_message_builder import build_readable_messages -from src.chat.brain_chat.PFC.observation_info import dict_to_database_message logger = get_logger("knowledge_fetcher") @@ -16,9 +14,7 @@ class KnowledgeFetcher: """知识调取器""" def __init__(self, private_name: str): - self.llm = LLMRequest( - model_set=model_config.model_task_config.utils - ) + self.llm = LLMRequest(model_set=model_config.model_task_config.utils) self.private_name = private_name def _lpmm_get_knowledge(self, query: str) -> str: @@ -50,13 +46,7 @@ class KnowledgeFetcher: Returns: Tuple[str, str]: (获取的知识, 知识来源) """ - db_messages = [dict_to_database_message(m) for m in chat_history] - chat_history_text = build_readable_messages( - db_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - ) + _ = chat_history # NOTE: Hippocampus memory system was redesigned in v0.12.2 # The old get_memory_from_text API no longer exists @@ -64,7 +54,7 @@ class KnowledgeFetcher: # TODO: Integrate with new memory system if needed knowledge_text = "" sources_text = "无记忆匹配" # 默认值 - + # # 从记忆中获取相关知识 (DISABLED - old Hippocampus API) # related_memory = await HippocampusManager.get_instance().get_memory_from_text( # text=f"{query}\n{chat_history_text}", diff --git a/src/chat/brain_chat/PFC/reply_checker.py b/src/chat/brain_chat/PFC/reply_checker.py index 0ab88495..c6304b30 100644 --- a/src/chat/brain_chat/PFC/reply_checker.py +++ b/src/chat/brain_chat/PFC/reply_checker.py @@ -14,10 +14,7 @@ class ReplyChecker: """回复检查器""" def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="reply_check" - ) + self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check") self.personality_info = self._get_personality_prompt() self.name = global_config.bot.nickname self.private_name = private_name @@ -27,7 +24,7 @@ class ReplyChecker: def _get_personality_prompt(self) -> str: """获取个性提示信息""" prompt_personality = global_config.personality.personality - + # 检查是否需要随机替换为状态 if ( global_config.personality.states @@ -35,7 +32,7 @@ class ReplyChecker: and random.random() < global_config.personality.state_probability ): prompt_personality = random.choice(global_config.personality.states) - + bot_name = global_config.bot.nickname return f"你的名字是{bot_name},你{prompt_personality};" diff --git a/src/chat/brain_chat/PFC/reply_generator.py b/src/chat/brain_chat/PFC/reply_generator.py index ad801de1..3636fe72 100644 --- a/src/chat/brain_chat/PFC/reply_generator.py +++ b/src/chat/brain_chat/PFC/reply_generator.py @@ -99,7 +99,7 @@ class ReplyGenerator: def _get_personality_prompt(self) -> str: """获取个性提示信息""" prompt_personality = global_config.personality.personality - + # 检查是否需要随机替换为状态 if ( global_config.personality.states @@ -107,7 +107,7 @@ class ReplyGenerator: and random.random() < global_config.personality.state_probability ): prompt_personality = random.choice(global_config.personality.states) - + bot_name = global_config.bot.nickname return f"你的名字是{bot_name},你{prompt_personality};" diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 5be8be63..d1b3c535 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -704,10 +704,7 @@ class BrainChatting: # 等待指定时间,但可被新消息打断 try: - await asyncio.wait_for( - self._new_message_event.wait(), - timeout=wait_seconds - ) + await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds) # 如果事件被触发,说明有新消息到达 logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待") except asyncio.TimeoutError: @@ -731,7 +728,9 @@ class BrainChatting: # 使用默认等待时间 wait_seconds = 3 - logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)") + logger.info( + f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)" + ) # 清除事件状态,准备等待新消息 self._new_message_event.clear() @@ -749,10 +748,7 @@ class BrainChatting: # 等待指定时间,但可被新消息打断 try: - await asyncio.wait_for( - self._new_message_event.wait(), - timeout=wait_seconds - ) + await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds) # 如果事件被触发,说明有新消息到达 logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待") except asyncio.TimeoutError: diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 520b55d1..59fd2394 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -431,15 +431,21 @@ class BrainPlanner: except Exception as req_e: logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}" - return extracted_reasoning, [ - ActionPlannerInfo( - action_type="complete_talk", - reasoning=extracted_reasoning, - action_data={}, - action_message=None, - available_actions=available_actions, - ) - ], llm_content, llm_reasoning, llm_duration_ms + return ( + extracted_reasoning, + [ + ActionPlannerInfo( + action_type="complete_talk", + reasoning=extracted_reasoning, + action_data={}, + action_message=None, + available_actions=available_actions, + ) + ], + llm_content, + llm_reasoning, + llm_duration_ms, + ) # 解析LLM响应 if llm_content: diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 9c460a0d..6d041af6 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -105,7 +105,7 @@ class EmbeddingStore: self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.index_file_path = f"{dir_path}/{namespace}.index" self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json" - + self.dirty = False # 标记是否有新增数据需要重建索引 # 多线程配置参数验证和设置 diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index 6d345ec9..d7413bdc 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -1,7 +1,7 @@ import asyncio import json import time -from typing import List, Union, Dict, Any +from typing import List, Union from .global_logger import logger from . import prompt_template @@ -192,17 +192,15 @@ class IEProcess: results = [] total = len(paragraphs) - + for i, pg in enumerate(paragraphs, start=1): # 打印进度日志,让用户知道没有卡死 logger.info(f"[IEProcess] 正在处理第 {i}/{total} 段文本 (长度: {len(pg)})...") - + # 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁 # 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行 try: - entities, triples = await asyncio.to_thread( - info_extract_from_str, self.llm_ner, self.llm_rdf, pg - ) + entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg) if entities is not None: results.append( diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index 8108c296..7bd48a26 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -395,8 +395,7 @@ class KGManager: appear_cnt = self.ent_appear_cnt.get(ent_hash) if not appear_cnt or appear_cnt <= 0: logger.debug( - f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0," - f"将使用 1.0 作为默认出现次数参与权重计算" + f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,将使用 1.0 作为默认出现次数参与权重计算" ) appear_cnt = 1.0 ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt) diff --git a/src/chat/knowledge/lpmm_ops.py b/src/chat/knowledge/lpmm_ops.py index a794c5ea..acaac4ca 100644 --- a/src/chat/knowledge/lpmm_ops.py +++ b/src/chat/knowledge/lpmm_ops.py @@ -11,31 +11,30 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up logger = get_logger("LPMM-Plugin-API") + class LPMMOperations: """ LPMM 内部操作接口。 封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。 """ - + def __init__(self): self._initialized = False - async def _run_cancellable_executor( - self, func: Callable, *args, **kwargs - ) -> Any: + async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any: """ 在线程池中执行可取消的同步操作。 当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。 注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。 - + Args: func: 要执行的同步函数 *args: 函数的位置参数 **kwargs: 函数的关键字参数 - + Returns: 函数的返回值 - + Raises: asyncio.CancelledError: 当任务被取消时 """ @@ -51,42 +50,42 @@ class LPMMOperations: # 如果全局没初始化,尝试初始化 if not global_config.lpmm_knowledge.enable: logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。") - + lpmm_start_up() qa_mgr = get_qa_manager() - + if qa_mgr is None: raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。") - + return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr async def add_content(self, text: str, auto_split: bool = True) -> dict: """ 向知识库添加新内容。 - + Args: text: 原始文本。 auto_split: 是否自动按双换行符分割段落。 - True: 自动分割(默认),支持多段文本(用双换行分隔) - False: 不分割,将整个文本作为完整一段处理 - + Returns: dict: {"status": "success/error", "count": 导入段落数, "message": "描述"} """ try: embed_mgr, kg_mgr, _ = await self._get_managers() - + # 1. 分段处理 if auto_split: # 自动按双换行符分割 - paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()] + paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] else: # 不分割,作为完整一段 text_stripped = text.strip() if not text_stripped: return {"status": "error", "message": "文本内容为空"} paragraphs = [text_stripped] - + if not paragraphs: return {"status": "error", "message": "文本内容为空"} @@ -94,14 +93,16 @@ class LPMMOperations: from src.chat.knowledge.ie_process import IEProcess from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - - llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract") + + llm_ner = LLMRequest( + model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" + ) llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") ie_process = IEProcess(llm_ner, llm_rdf) - + logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...") extracted_docs = await ie_process.process_paragraphs(paragraphs) - + # 3. 构造并导入数据 # 这里我们手动实现导入逻辑,不依赖外部脚本 # a. 准备段落 @@ -115,7 +116,7 @@ class LPMMOperations: # store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本 new_raw_paragraphs = {} new_triple_list_data = {} - + for pg_hash, passage in raw_paragraphs.items(): key = f"paragraph-{pg_hash}" if key not in embed_mgr.stored_pg_hashes: @@ -128,26 +129,22 @@ class LPMMOperations: # 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入 # store_new_data_set 会自动处理嵌入生成和存储 # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - await self._run_cancellable_executor( - embed_mgr.store_new_data_set, - new_raw_paragraphs, - new_triple_list_data - ) + await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data) # 3. 构建知识图谱(只需要三元组数据和embedding_manager) - await self._run_cancellable_executor( - kg_mgr.build_kg, - new_triple_list_data, - embed_mgr - ) + await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr) # 4. 持久化 await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) await self._run_cancellable_executor(embed_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file) - - return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"} - + + return { + "status": "success", + "count": len(new_raw_paragraphs), + "message": f"成功导入 {len(new_raw_paragraphs)} 条知识", + } + except asyncio.CancelledError: logger.warning("[Plugin API] 导入操作被用户中断") return {"status": "cancelled", "message": "导入操作已被用户中断"} @@ -158,11 +155,11 @@ class LPMMOperations: async def search(self, query: str, top_k: int = 3) -> List[str]: """ 检索知识库。 - + Args: query: 查询问题。 top_k: 返回最相关的条目数。 - + Returns: List[str]: 相关文段列表。 """ @@ -179,21 +176,21 @@ class LPMMOperations: async def delete(self, keyword: str, exact_match: bool = False) -> dict: """ 根据关键词或完整文段删除知识库内容。 - + Args: keyword: 匹配关键词或完整文段。 exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。 - + Returns: dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"} """ try: embed_mgr, kg_mgr, _ = await self._get_managers() - + # 1. 查找匹配的段落 to_delete_keys = [] to_delete_hashes = [] - + for key, item in embed_mgr.paragraphs_embedding_store.store.items(): if exact_match: # 完整文段匹配 @@ -205,29 +202,25 @@ class LPMMOperations: if keyword in item.str: to_delete_keys.append(key) to_delete_hashes.append(key.replace("paragraph-", "", 1)) - + if not to_delete_keys: match_type = "完整文段" if exact_match else "关键词" return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"} # 2. 执行删除 # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - + # a. 从向量库删除 deleted_count, _ = await self._run_cancellable_executor( - embed_mgr.paragraphs_embedding_store.delete_items, - to_delete_keys + embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys ) embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys()) - + # b. 从知识图谱删除 # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs delete_func = partial( - kg_mgr.delete_paragraphs, - to_delete_hashes, - ent_hashes=None, - remove_orphan_entities=True + kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True ) await self._run_cancellable_executor(delete_func) @@ -235,9 +228,13 @@ class LPMMOperations: await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) await self._run_cancellable_executor(embed_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file) - + match_type = "完整文段" if exact_match else "关键词" - return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"} + return { + "status": "success", + "deleted_count": deleted_count, + "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)", + } except asyncio.CancelledError: logger.warning("[Plugin API] 删除操作被用户中断") @@ -249,13 +246,13 @@ class LPMMOperations: async def clear_all(self) -> dict: """ 清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。 - + Returns: dict: {"status": "success/error", "message": "描述", "stats": {...}} """ try: embed_mgr, kg_mgr, _ = await self._get_managers() - + # 记录清空前的统计信息 before_stats = { "paragraphs": len(embed_mgr.paragraphs_embedding_store.store), @@ -264,40 +261,37 @@ class LPMMOperations: "kg_nodes": len(kg_mgr.graph.get_node_list()), "kg_edges": len(kg_mgr.graph.get_edge_list()), } - + # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - + # 1. 清空所有向量库 # 获取所有keys para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys()) ent_keys = list(embed_mgr.entities_embedding_store.store.keys()) rel_keys = list(embed_mgr.relation_embedding_store.store.keys()) - + # 删除所有段落向量 para_deleted, _ = await self._run_cancellable_executor( - embed_mgr.paragraphs_embedding_store.delete_items, - para_keys + embed_mgr.paragraphs_embedding_store.delete_items, para_keys ) embed_mgr.stored_pg_hashes.clear() - + # 删除所有实体向量 if ent_keys: ent_deleted, _ = await self._run_cancellable_executor( - embed_mgr.entities_embedding_store.delete_items, - ent_keys + embed_mgr.entities_embedding_store.delete_items, ent_keys ) else: ent_deleted = 0 - + # 删除所有关系向量 if rel_keys: rel_deleted, _ = await self._run_cancellable_executor( - embed_mgr.relation_embedding_store.delete_items, - rel_keys + embed_mgr.relation_embedding_store.delete_items, rel_keys ) else: rel_deleted = 0 - + # 2. 清空所有 embedding store 的索引和映射 # 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件 def _clear_embedding_indices(): @@ -310,7 +304,7 @@ class LPMMOperations: os.remove(embed_mgr.paragraphs_embedding_store.index_file_path) if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path): os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path) - + # 清空实体索引 embed_mgr.entities_embedding_store.faiss_index = None embed_mgr.entities_embedding_store.idx2hash = None @@ -320,7 +314,7 @@ class LPMMOperations: os.remove(embed_mgr.entities_embedding_store.index_file_path) if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path): os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path) - + # 清空关系索引 embed_mgr.relation_embedding_store.faiss_index = None embed_mgr.relation_embedding_store.idx2hash = None @@ -330,9 +324,9 @@ class LPMMOperations: os.remove(embed_mgr.relation_embedding_store.index_file_path) if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path): os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path) - + await self._run_cancellable_executor(_clear_embedding_indices) - + # 3. 清空知识图谱 # 获取所有段落hash all_pg_hashes = list(kg_mgr.stored_paragraph_hashes) @@ -341,24 +335,22 @@ class LPMMOperations: # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs delete_func = partial( - kg_mgr.delete_paragraphs, - all_pg_hashes, - ent_hashes=None, - remove_orphan_entities=True + kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True ) await self._run_cancellable_executor(delete_func) - + # 完全清空KG:创建新的空图(无论是否有段落hash都要执行) from quick_algo import di_graph + kg_mgr.graph = di_graph.DiGraph() kg_mgr.stored_paragraph_hashes.clear() kg_mgr.ent_appear_cnt.clear() - + # 4. 保存所有数据(此时所有store都是空的,索引也是None) # 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的 await self._run_cancellable_executor(embed_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file) - + after_stats = { "paragraphs": len(embed_mgr.paragraphs_embedding_store.store), "entities": len(embed_mgr.entities_embedding_store.store), @@ -366,14 +358,14 @@ class LPMMOperations: "kg_nodes": len(kg_mgr.graph.get_node_list()), "kg_edges": len(kg_mgr.graph.get_edge_list()), } - + return { "status": "success", "message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)", "stats": { "before": before_stats, "after": after_stats, - } + }, } except asyncio.CancelledError: @@ -383,6 +375,6 @@ class LPMMOperations: logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True) return {"status": "error", "message": str(e)} + # 内部使用的单例 lpmm_ops = LPMMOperations() - diff --git a/src/chat/logger/plan_reply_logger.py b/src/chat/logger/plan_reply_logger.py index d81101f0..c0891196 100644 --- a/src/chat/logger/plan_reply_logger.py +++ b/src/chat/logger/plan_reply_logger.py @@ -136,4 +136,3 @@ class PlanReplyLogger: return str(value) # Fallback to string for other complex types return str(value) - diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 13a538bb..7dbf3688 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -85,17 +85,17 @@ class ChatBot: async def _create_pfc_chat(self, message: MessageRecv): """创建或获取PFC对话实例 - + Args: message: 消息对象 """ try: chat_id = str(message.chat_stream.stream_id) private_name = str(message.message_info.user_info.user_nickname) - + logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}") await self.pfc_manager.get_or_create_conversation(chat_id, private_name) - + except Exception as e: logger.error(f"创建PFC聊天失败: {e}") logger.error(traceback.format_exc()) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 7769b022..86122b9a 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -96,7 +96,7 @@ class Message(MessageBase): if processed_text: return f"{global_config.bot.nickname}: {processed_text}" return None - + tasks = [process_forward_node(node_dict) for node_dict in segment.data] results = await asyncio.gather(*tasks, return_exceptions=True) segments_text = [] diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 9d939395..dfc23e4b 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: # 如果未开启 API Server,直接跳过 Fallback if not global_config.maim_message.enable_api_server: - logger.debug(f"[API Server Fallback] API Server未开启,跳过fallback") + logger.debug("[API Server Fallback] API Server未开启,跳过fallback") if legacy_exception: raise legacy_exception return False @@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: extra_server = getattr(global_api, "extra_server", None) if not extra_server: - logger.warning(f"[API Server Fallback] extra_server不存在") + logger.warning("[API Server Fallback] extra_server不存在") if legacy_exception: raise legacy_exception return False if not extra_server.is_running(): - logger.warning(f"[API Server Fallback] extra_server未运行") + logger.warning("[API Server Fallback] extra_server未运行") if legacy_exception: raise legacy_exception return False @@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: ) # 直接调用 Server 的 send_message 接口,它会自动处理路由 - logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...") + logger.debug("[API Server Fallback] 正在通过extra_server发送消息...") results = await extra_server.send_message(api_message) logger.debug(f"[API Server Fallback] 发送结果: {results}") diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index cff0bd55..87c42c84 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -35,6 +35,7 @@ logger = get_logger("planner") install(extra_lines=3) + class ActionPlanner: def __init__(self, chat_id: str, action_manager: ActionManager): self.chat_id = chat_id @@ -48,7 +49,7 @@ class ActionPlanner: self.last_obs_time_mark = 0.0 self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = [] - + # 黑话缓存:使用 OrderedDict 实现 LRU,最多缓存10个 self.unknown_words_cache: OrderedDict[str, None] = OrderedDict() self.unknown_words_cache_limit = 10 @@ -111,20 +112,29 @@ class ActionPlanner: # 替换 [picid:xxx] 为 [图片:描述] pic_pattern = r"\[picid:([^\]]+)\]" + def replace_pic_id(pic_match: re.Match) -> str: pic_id = pic_match.group(1) description = translate_pid_to_description(pic_id) return f"[图片:{description}]" + msg_text = re.sub(pic_pattern, replace_pic_id, msg_text) # 替换用户引用格式:回复 和 @ - platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq" + platform = ( + getattr(message, "user_info", None) + and message.user_info.platform + or getattr(message, "chat_info", None) + and message.chat_info.platform + or "qq" + ) msg_text = replace_user_references(msg_text, platform, replace_bot_name=True) # 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式) # 匹配所有 格式,由于 replace_user_references 已经替换了回复<和@<格式, # 这里匹配到的应该都是单独的格式 user_ref_pattern = r"<([^:<>]+):([^:<>]+)>" + def replace_user_ref(user_match: re.Match) -> str: user_name = user_match.group(1) user_id = user_match.group(2) @@ -137,6 +147,7 @@ class ActionPlanner: except Exception: # 如果解析失败,使用原始昵称 return user_name + msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text) preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..." @@ -165,7 +176,7 @@ class ActionPlanner: else: reasoning = "未提供原因" action_data = {key: value for key, value in action_json.items() if key not in ["action"]} - + # 非no_reply动作需要target_message_id target_message = None @@ -244,7 +255,7 @@ class ActionPlanner: def _update_unknown_words_cache(self, new_words: List[str]) -> None: """ 更新黑话缓存,将新的黑话加入缓存 - + Args: new_words: 新提取的黑话列表 """ @@ -254,7 +265,7 @@ class ActionPlanner: word = word.strip() if not word: continue - + # 如果已存在,移到末尾(LRU) if word in self.unknown_words_cache: self.unknown_words_cache.move_to_end(word) @@ -269,10 +280,10 @@ class ActionPlanner: def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]: """ 合并新提取的黑话和缓存中的黑话 - + Args: new_words: 新提取的黑话列表(可能为None) - + Returns: 合并后的黑话列表(去重) """ @@ -284,31 +295,29 @@ class ActionPlanner: word = word.strip() if word: cleaned_new_words.append(word) - + # 获取缓存中的黑话列表 cached_words = list(self.unknown_words_cache.keys()) - + # 合并并去重(保留顺序:新提取的在前,缓存的在后) merged_words: List[str] = [] seen = set() - + # 先添加新提取的 for word in cleaned_new_words: if word not in seen: merged_words.append(word) seen.add(word) - + # 再添加缓存的(如果不在新提取的列表中) for word in cached_words: if word not in seen: merged_words.append(word) seen.add(word) - + return merged_words - def _process_unknown_words_cache( - self, actions: List[ActionPlannerInfo] - ) -> None: + def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None: """ 处理黑话缓存逻辑: 1. 检查是否有 reply action 提取了 unknown_words @@ -316,7 +325,7 @@ class ActionPlanner: 3. 如果缓存数量大于5,移除最老的2个 4. 对于每个 reply action,合并缓存和新提取的黑话 5. 更新缓存 - + Args: actions: 解析后的动作列表 """ @@ -330,7 +339,7 @@ class ActionPlanner: removed_count += 1 if removed_count > 0: logger.debug(f"{self.log_prefix}缓存数量大于5,移除最老的{removed_count}个缓存") - + # 检查是否有 reply action 提取了 unknown_words has_extracted_unknown_words = False for action in actions: @@ -340,22 +349,22 @@ class ActionPlanner: if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0: has_extracted_unknown_words = True break - + # 如果当前 plan 的 reply 没有提取,移除最老的1个 if not has_extracted_unknown_words: if len(self.unknown_words_cache) > 0: self.unknown_words_cache.popitem(last=False) logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存") - + # 对于每个 reply action,合并缓存和新提取的黑话 for action in actions: if action.action_type == "reply": action_data = action.action_data or {} new_words = action_data.get("unknown_words") - + # 合并新提取的和缓存的黑话列表 merged_words = self._merge_unknown_words_with_cache(new_words) - + # 更新 action_data if merged_words: action_data["unknown_words"] = merged_words @@ -366,7 +375,7 @@ class ActionPlanner: else: # 如果没有合并后的黑话,移除 unknown_words 字段 action_data.pop("unknown_words", None) - + # 更新缓存(将新提取的黑话加入缓存) if new_words: self._update_unknown_words_cache(new_words) @@ -442,15 +451,19 @@ class ActionPlanner: # 检查是否已经有回复该消息的 action has_reply_to_force_message = False for action in actions: - if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id: + if ( + action.action_type == "reply" + and action.action_message + and action.action_message.message_id == force_reply_message.message_id + ): has_reply_to_force_message = True break - + # 如果没有回复该消息,强制添加回复 action if not has_reply_to_force_message: # 移除所有 no_reply action(如果有) actions = [a for a in actions if a.action_type != "no_reply"] - + # 创建强制回复 action available_actions_dict = dict(current_available_actions) force_reply_action = ActionPlannerInfo( @@ -577,10 +590,11 @@ class ActionPlanner: if global_config.chat.think_mode == "classic": reply_action_example = "" if global_config.chat.llm_quote: - reply_action_example += "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n" + reply_action_example += ( + "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n" + ) reply_action_example += ( - '{{"action":"reply", "target_message_id":"消息id(m+数字)", ' - '"unknown_words":["词语1","词语2"]' + '{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]' ) if global_config.chat.llm_quote: reply_action_example += ', "quote":"如果需要引用该message,设置为true"' @@ -590,7 +604,9 @@ class ActionPlanner: "5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n" ) if global_config.chat.llm_quote: - reply_action_example += "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n" + reply_action_example += ( + "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n" + ) reply_action_example += ( '{{"action":"reply", "think_level":数值等级(0或1), ' '"target_message_id":"消息id(m+数字)", ' @@ -741,15 +757,21 @@ class ActionPlanner: except Exception as req_e: logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") - return f"LLM 请求失败,模型出现问题: {req_e}", [ - ActionPlannerInfo( - action_type="no_reply", - reasoning=f"LLM 请求失败,模型出现问题: {req_e}", - action_data={}, - action_message=None, - available_actions=available_actions, - ) - ], llm_content, llm_reasoning, llm_duration_ms + return ( + f"LLM 请求失败,模型出现问题: {req_e}", + [ + ActionPlannerInfo( + action_type="no_reply", + reasoning=f"LLM 请求失败,模型出现问题: {req_e}", + action_data={}, + action_message=None, + available_actions=available_actions, + ) + ], + llm_content, + llm_reasoning, + llm_duration_ms, + ) # 解析LLM响应 extracted_reasoning = "" diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 2fff9773..98ecfcbb 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -1071,7 +1071,6 @@ class DefaultReplyer: chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2") chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt) - # 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换 reply_style = global_config.personality.reply_style multi_styles = global_config.personality.multiple_reply_style diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 9c3d78d5..30ed14a9 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import ( ) from src.bw_learner.expression_selector import expression_selector from src.plugin_system.apis.message_api import translate_pid_to_description + # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person, is_person_known from src.plugin_system.base.component_types import ActionInfo, EventType @@ -807,7 +808,7 @@ class PrivateReplyer: reply_style = global_config.personality.reply_style # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI) - + if is_bot_self(platform, user_id): prompt_template = prompt_manager.get_prompt("private_replyer_self") prompt_template.add_context("target", target) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 11113f97..86ac7da4 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -519,7 +519,7 @@ def _build_readable_messages_internal( output_lines: List[str] = [] prev_timestamp: Optional[float] = None - for timestamp, name, content, is_action in detailed_message: + for timestamp, name, content, _is_action in detailed_message: # 检查是否需要插入长时间间隔提示 if long_time_notice and prev_timestamp is not None: time_diff = timestamp - prev_timestamp diff --git a/src/chat/utils/common_utils.py b/src/chat/utils/common_utils.py index e751381e..bffd5557 100644 --- a/src/chat/utils/common_utils.py +++ b/src/chat/utils/common_utils.py @@ -5,6 +5,7 @@ from src.common.logger import get_logger logger = get_logger("common_utils") + class TempMethodsExpression: """用于临时存放一些方法的类""" diff --git a/src/common/data_models/chat_session_data_model.py b/src/common/data_models/chat_session_data_model.py index fe9a04c7..33fb6cde 100644 --- a/src/common/data_models/chat_session_data_model.py +++ b/src/common/data_models/chat_session_data_model.py @@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession from . import BaseDatabaseDataModel + class MaiChatSession(BaseDatabaseDataModel[ChatSession]): def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None): self.session_id = session_id @@ -33,4 +34,4 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]): platform=self.platform, user_id=self.user_id, group_id=self.group_id, - ) \ No newline at end of file + ) diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py index 938e0013..63a30dc0 100644 --- a/src/common/data_models/message_data_model.py +++ b/src/common/data_models/message_data_model.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum from typing import Any, Iterable, List, Optional, Tuple, Union diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index f6df0d4c..ee70b4b7 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -221,5 +221,7 @@ if not supports_truecolor(): CONVERTED_MODULE_COLORS[name] = escape_str else: for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items(): - escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold) - CONVERTED_MODULE_COLORS[name] = escape_str \ No newline at end of file + escape_str = rgb_pair_to_ansi_truecolor( + hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold + ) + CONVERTED_MODULE_COLORS[name] = escape_str diff --git a/src/common/message_server/api.py b/src/common/message_server/api.py index 67816d81..b280b66e 100644 --- a/src/common/message_server/api.py +++ b/src/common/message_server/api.py @@ -9,6 +9,7 @@ from .server import get_global_server global_api = None + def get_global_api() -> MessageServer: # sourcery skip: extract-method """获取全局MessageServer实例""" global global_api @@ -80,12 +81,12 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method api_logger.warning(f"Rejected connection with invalid API Key: {api_key}") return False - server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了 + server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了 # 3. Setup Message Bridge # Initialize refined route map if not exists if not hasattr(global_api, "platform_map"): - global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法 + global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法 async def bridge_message_handler(message: APIMessageBase, metadata: dict): # 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase @@ -108,7 +109,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'") if platform: - global_api.platform_map[platform] = api_key # type: ignore + global_api.platform_map[platform] = api_key # type: ignore api_logger.info(f"Updated platform_map: {platform} -> {api_key}") except Exception as e: api_logger.warning(f"Failed to update platform map: {e}") @@ -117,21 +118,21 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method if "raw_message" not in msg_dict: msg_dict["raw_message"] = None - await global_api.process_message(msg_dict) # type: ignore + await global_api.process_message(msg_dict) # type: ignore - server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了 + server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了 # 3.5. Register custom message handlers (bridge to Legacy handlers) # message_id_echo: handles message ID echo from adapters # 兼容新旧两个版本的 maim_message: # - 旧版: handler(payload) # - 新版: handler(payload, metadata) - async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore + async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore # Bridge to the Legacy custom handler registered in main.py try: # The Legacy handler expects the payload format directly if hasattr(global_api, "_custom_message_handlers"): - handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了 + handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了 if handler: await handler(payload) api_logger.debug(f"Processed message_id_echo: {payload}") @@ -140,7 +141,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method except Exception as e: api_logger.warning(f"Failed to process message_id_echo: {e}") - server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了 + server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了 # 4. Initialize Server extra_server = WebSocketServer(config=server_config) @@ -167,7 +168,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method global_api.stop = patched_stop # Attach for reference - global_api.extra_server = extra_server # type: ignore # 这是什么 + global_api.extra_server = extra_server # type: ignore # 这是什么 except ImportError: get_logger("maim_message").error( diff --git a/src/common/utils/utils_file.py b/src/common/utils/utils_file.py index 1967e1f1..58fab698 100644 --- a/src/common/utils/utils_file.py +++ b/src/common/utils/utils_file.py @@ -9,6 +9,7 @@ from src.common.database.database import get_db_session logger = get_logger("file_utils") + class FileUtils: @staticmethod def save_binary_to_file(file_path: Path, data: bytes): @@ -35,7 +36,7 @@ class FileUtils: except Exception as e: logger.error(f"保存文件 {file_path} 失败: {e}") raise e - + @staticmethod def get_file_path_by_hash(data_hash: str) -> Path: """ @@ -52,4 +53,4 @@ class FileUtils: if binary_data := session.exec(statement).first(): return Path(binary_data.full_path) else: - raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录") \ No newline at end of file + raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录") diff --git a/src/config/legacy_migration.py b/src/config/legacy_migration.py index ccbfb26a..a90c9e2e 100644 --- a/src/config/legacy_migration.py +++ b/src/config/legacy_migration.py @@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult: reason = ",".join(reasons) return MigrationResult(data=data, migrated=migrated_any, reason=reason) - diff --git a/src/dream/dream_agent.py b/src/dream/dream_agent.py index ece9e65a..79acceeb 100644 --- a/src/dream/dream_agent.py +++ b/src/dream/dream_agent.py @@ -86,8 +86,8 @@ def init_dream_tools(chat_id: str) -> None: finish_maintenance = make_finish_maintenance(chat_id) search_jargon = make_search_jargon(chat_id) - delete_jargon = make_delete_jargon(chat_id) - update_jargon = make_update_jargon(chat_id) + _delete_jargon = make_delete_jargon(chat_id) + _update_jargon = make_update_jargon(chat_id) _dream_tool_registry.register_tool( DreamTool( diff --git a/src/dream/dream_generator.py b/src/dream/dream_generator.py index dd73a85e..316cac99 100644 --- a/src/dream/dream_generator.py +++ b/src/dream/dream_generator.py @@ -54,8 +54,6 @@ async def generate_dream_summary( ) -> None: """生成梦境总结,输出到日志,并根据配置可选地推送给指定用户""" try: - - # 第一步:建立工具调用结果映射 (call_id -> result) tool_results_map: dict[str, str] = {} for msg in conversation_messages: diff --git a/src/dream/tools/__init__.py b/src/dream/tools/__init__.py index da5c237e..cd784b02 100644 --- a/src/dream/tools/__init__.py +++ b/src/dream/tools/__init__.py @@ -4,4 +4,3 @@ dream agent 工具实现模块。 每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数 生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。 """ - diff --git a/src/dream/tools/create_chat_history_tool.py b/src/dream/tools/create_chat_history_tool.py index 9d423b45..30e1b0e4 100644 --- a/src/dream/tools/create_chat_history_tool.py +++ b/src/dream/tools/create_chat_history_tool.py @@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str): return f"create_chat_history 执行失败: {e}" return create_chat_history - diff --git a/src/dream/tools/delete_chat_history_tool.py b/src/dream/tools/delete_chat_history_tool.py index 6a14d57f..18c32f27 100644 --- a/src/dream/tools/delete_chat_history_tool.py +++ b/src/dream/tools/delete_chat_history_tool.py @@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用, return f"delete_chat_history 执行失败: {e}" return delete_chat_history - diff --git a/src/dream/tools/delete_jargon_tool.py b/src/dream/tools/delete_jargon_tool.py index cfab51aa..8edd3245 100644 --- a/src/dream/tools/delete_jargon_tool.py +++ b/src/dream/tools/delete_jargon_tool.py @@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留 return f"delete_jargon 执行失败: {e}" return delete_jargon - diff --git a/src/dream/tools/finish_maintenance_tool.py b/src/dream/tools/finish_maintenance_tool.py index c2dda545..403b6c6e 100644 --- a/src/dream/tools/finish_maintenance_tool.py +++ b/src/dream/tools/finish_maintenance_tool.py @@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用, return msg return finish_maintenance - diff --git a/src/dream/tools/get_chat_history_detail_tool.py b/src/dream/tools/get_chat_history_detail_tool.py index 5cf01955..81d63284 100644 --- a/src/dream/tools/get_chat_history_detail_tool.py +++ b/src/dream/tools/get_chat_history_detail_tool.py @@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用 return f"get_chat_history_detail 执行失败: {e}" return get_chat_history_detail - diff --git a/src/dream/tools/search_chat_history_tool.py b/src/dream/tools/search_chat_history_tool.py index 1de0f253..5d216f00 100644 --- a/src/dream/tools/search_chat_history_tool.py +++ b/src/dream/tools/search_chat_history_tool.py @@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str): return f"search_chat_history 执行失败: {e}" return search_chat_history - diff --git a/src/dream/tools/update_chat_history_tool.py b/src/dream/tools/update_chat_history_tool.py index 1797c714..acb98d45 100644 --- a/src/dream/tools/update_chat_history_tool.py +++ b/src/dream/tools/update_chat_history_tool.py @@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用, return f"update_chat_history 执行失败: {e}" return update_chat_history - diff --git a/src/dream/tools/update_jargon_tool.py b/src/dream/tools/update_jargon_tool.py index 59ea1230..1d559cf6 100644 --- a/src/dream/tools/update_jargon_tool.py +++ b/src/dream/tools/update_jargon_tool.py @@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留 return f"update_jargon 执行失败: {e}" return update_jargon - diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 65af170e..07dd66a2 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -458,8 +458,8 @@ def _default_normal_response_parser( if not isinstance(arguments, dict): # 此时为了调试方便,建议打印出 arguments 的类型 raise RespParseException( - resp, - f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}" + resp, + f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}", ) api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments)) except json.JSONDecodeError as e: diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index b356663a..8d85964a 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -2,7 +2,7 @@ import time import json import asyncio from datetime import datetime -from typing import List, Dict, Any, Optional, Tuple, Callable, cast +from typing import List, Dict, Any, Optional, Tuple, Callable from src.common.logger import get_logger from src.config.config import global_config, model_config from src.prompt.prompt_manager import prompt_manager @@ -34,7 +34,7 @@ def _cleanup_stale_not_found_thinking_back() -> None: try: with get_db_session() as session: statement = select(ThinkingQuestion).where( - (ThinkingQuestion.found_answer == False) + col(ThinkingQuestion.found_answer).is_(False) & (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time)) ) records = session.exec(statement).all() @@ -786,8 +786,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) str: 格式化的查询历史字符串 """ try: - current_time = time.time() - start_time = current_time - time_window_seconds + _current_time = time.time() with get_db_session() as session: statement = ( @@ -838,15 +837,14 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) List[str]: 格式化的答案列表,每个元素格式为 "问题:xxx\n答案:xxx" """ try: - current_time = time.time() - start_time = current_time - time_window_seconds + _current_time = time.time() # 查询最近时间窗口内已找到答案的记录,按更新时间倒序 with get_db_session() as session: statement = ( select(ThinkingQuestion) .where(col(ThinkingQuestion.context) == chat_id) - .where(col(ThinkingQuestion.found_answer) == True) + .where(col(ThinkingQuestion.found_answer)) .where(col(ThinkingQuestion.answer).is_not(None)) .where(col(ThinkingQuestion.answer) != "") .order_by(col(ThinkingQuestion.updated_timestamp).desc()) diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index 4a7adfa1..0a9f502f 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -105,25 +105,27 @@ async def search_chat_history( # 检查参数 if not keyword and not participant and not start_time and not end_time: return "未指定查询参数(需要提供keyword、participant、start_time或end_time之一)" - + # 解析时间参数 start_timestamp = None end_timestamp = None - + if start_time: try: from src.memory_system.memory_utils import parse_datetime_to_timestamp + start_timestamp = parse_datetime_to_timestamp(start_time) except ValueError as e: return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'" - + if end_time: try: from src.memory_system.memory_utils import parse_datetime_to_timestamp + end_timestamp = parse_datetime_to_timestamp(end_time) except ValueError as e: return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'" - + # 验证时间范围 if start_timestamp and end_timestamp and start_timestamp > end_timestamp: return "开始时间不能晚于结束时间" @@ -158,23 +160,20 @@ async def search_chat_history( f"search_chat_history 当前聊天流在黑名单中,强制使用本地查询,chat_id={chat_id}, keyword={keyword}, participant={participant}" ) query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) - + # 添加时间过滤条件 if start_timestamp is not None and end_timestamp is not None: # 查询指定时间段内的记录(记录的时间范围与查询时间段有交集) # 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段 query = query.where( ( - (ChatHistory.start_time >= start_timestamp) - & (ChatHistory.start_time <= end_timestamp) + (ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp) ) # 记录开始时间在查询时间段内 | ( - (ChatHistory.end_time >= start_timestamp) - & (ChatHistory.end_time <= end_timestamp) + (ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp) ) # 记录结束时间在查询时间段内 | ( - (ChatHistory.start_time <= start_timestamp) - & (ChatHistory.end_time >= end_timestamp) + (ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp) ) # 记录完全包含查询时间段 ) logger.debug( @@ -302,7 +301,7 @@ async def search_chat_history( time_desc = f"时间<='{end_str}'" if time_desc: conditions.append(time_desc) - + if conditions: conditions_str = "且".join(conditions) return f"未找到满足条件({conditions_str})的聊天记录" diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py index 2d7a0c60..9bdf8ba2 100644 --- a/src/memory_system/retrieval_tools/query_words.py +++ b/src/memory_system/retrieval_tools/query_words.py @@ -30,7 +30,7 @@ async def query_words(chat_id: str, words: str) -> str: if separator in words: words_list = [w.strip() for w in words.split(separator) if w.strip()] break - + # 如果没有找到分隔符,整个字符串作为一个词语 if not words_list: words_list = [words.strip()] @@ -76,4 +76,3 @@ def register_tool(): ], execute_func=query_words, ) - diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index deaccc4b..af985b96 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -123,7 +123,7 @@ async def generate_reply( # 如果 reply_time_point 未传入,设置为当前时间戳 if reply_time_point is None: reply_time_point = time.time() - + # 获取回复器 logger.debug("[GeneratorAPI] 开始生成回复") replyer = get_replyer(chat_stream, chat_id, request_type=request_type) diff --git a/src/plugin_system/apis/plugin_service_api.py b/src/plugin_system/apis/plugin_service_api.py index 3e64722b..4c783f04 100644 --- a/src/plugin_system/apis/plugin_service_api.py +++ b/src/plugin_system/apis/plugin_service_api.py @@ -39,6 +39,18 @@ def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> return plugin_service_registry.unregister_service(service_name, plugin_name) -async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: +async def call_service( + service_name: str, + *args: Any, + plugin_name: Optional[str] = None, + caller_plugin: Optional[str] = None, + **kwargs: Any, +) -> Any: """调用插件服务。""" - return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs) + return await plugin_service_registry.call_service( + service_name, + *args, + plugin_name=plugin_name, + caller_plugin=caller_plugin, + **kwargs, + ) diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 521ca137..7076af4d 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -558,7 +558,9 @@ class PluginBase(ABC): if version_spec: is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec) if not is_ok: - logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})") + logger.error( + f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})" + ) return False if min_version or max_version: diff --git a/src/plugin_system/base/service_types.py b/src/plugin_system/base/service_types.py index 7d23f7ca..0a789311 100644 --- a/src/plugin_system/base/service_types.py +++ b/src/plugin_system/base/service_types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict +from typing import Any, Dict, List @dataclass @@ -11,6 +11,8 @@ class PluginServiceInfo: version: str = "1.0.0" description: str = "" enabled: bool = True + public: bool = False + allowed_callers: List[str] = field(default_factory=list) params_schema: Dict[str, Any] = field(default_factory=dict) return_schema: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index af2f852c..7320f09a 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -274,6 +274,23 @@ class ComponentRegistry: logger.error(f"移除组件 {component_name} 时发生错误: {e}") return False + async def remove_components_by_plugin(self, plugin_name: str) -> int: + """移除某插件注册的所有组件。""" + targets = [ + (component_info.name, component_info.component_type) + for component_info in self._components.values() + if component_info.plugin_name == plugin_name + ] + + removed_count = 0 + for component_name, component_type in targets: + if await self.remove_component(component_name, component_type, plugin_name): + removed_count += 1 + + if removed_count: + logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}") + return removed_count + def remove_plugin_registry(self, plugin_name: str) -> bool: """移除插件注册信息 @@ -734,9 +751,7 @@ class ComponentRegistry: "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), "workflow_steps": workflow_step_count, "enabled_workflow_steps": enabled_workflow_step_count, - "workflow_steps_by_stage": { - stage.value: len(steps) for stage, steps in self._workflow_steps.items() - }, + "workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()}, } diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 9c8664ba..62803ee6 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -200,13 +200,43 @@ class PluginManager: """ 重载插件模块 """ + old_instance = self.loaded_plugins.get(plugin_name) + if not old_instance: + logger.warning(f"插件 {plugin_name} 未加载,无法重载") + return False + if not await self.remove_registered_plugin(plugin_name): return False + if not self.load_registered_plugin_classes(plugin_name)[0]: + logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例") + rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance) + if rollback_ok: + logger.info(f"插件 {plugin_name} 已回滚到旧版本实例") + else: + logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用") return False + logger.debug(f"插件 {plugin_name} 重载成功") return True + async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool: + """重载失败后回滚旧实例。""" + try: + await component_registry.remove_components_by_plugin(plugin_name) + component_registry.remove_plugin_registry(plugin_name) + plugin_service_registry.remove_services_by_plugin(plugin_name) + + if not old_instance.register_plugin(): + logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败") + return False + + self.loaded_plugins[plugin_name] = old_instance + return True + except Exception as e: + logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True) + return False + def rescan_plugin_directory(self) -> Tuple[int, int]: """ 重新扫描插件根目录 @@ -399,7 +429,9 @@ class PluginManager: def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]: """根据依赖图计算加载顺序,并检测循环依赖。""" - indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()} + indegree: Dict[str, int] = { + plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items() + } reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph} for plugin_name, dependencies in dependency_graph.items(): diff --git a/src/plugin_system/core/plugin_service_registry.py b/src/plugin_system/core/plugin_service_registry.py index 1dbf1233..d94e5b3a 100644 --- a/src/plugin_system/core/plugin_service_registry.py +++ b/src/plugin_system/core/plugin_service_registry.py @@ -26,6 +26,9 @@ class PluginServiceRegistry: if "." in service_info.plugin_name: logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") return False + if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]: + logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}") + return False full_name = service_info.full_name if full_name in self._services: @@ -52,7 +55,9 @@ class PluginServiceRegistry: full_name = self._resolve_full_name(service_name, plugin_name) return self._service_handlers.get(full_name) if full_name else None - def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]: + def list_services( + self, plugin_name: Optional[str] = None, enabled_only: bool = False + ) -> Dict[str, PluginServiceInfo]: """列出插件服务。""" services = self._services.copy() if plugin_name: @@ -103,12 +108,33 @@ class PluginServiceRegistry: logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}") return removed_count - async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: + async def call_service( + self, + service_name: str, + *args: Any, + plugin_name: Optional[str] = None, + caller_plugin: Optional[str] = None, + **kwargs: Any, + ) -> Any: """调用插件服务(支持同步/异步handler)。""" service_info = self.get_service(service_name, plugin_name) if not service_info: target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name raise ValueError(f"插件服务未注册: {target_name}") + + if ( + "." not in service_name + and plugin_name is None + and caller_plugin + and service_info.plugin_name != caller_plugin + ): + raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name") + + if not self._is_call_authorized(service_info, caller_plugin): + raise PermissionError( + f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}" + ) + if not service_info.enabled: raise RuntimeError(f"插件服务已禁用: {service_info.full_name}") @@ -116,8 +142,93 @@ class PluginServiceRegistry: if not handler: raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}") + self._validate_input_contract(service_info, args, kwargs) + result = handler(*args, **kwargs) - return await result if inspect.isawaitable(result) else result + resolved_result = await result if inspect.isawaitable(result) else result + self._validate_output_contract(service_info, resolved_result) + return resolved_result + + def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool: + """检查服务调用权限。""" + if caller_plugin is None: + return service_info.public + if caller_plugin == service_info.plugin_name: + return True + if service_info.public: + return True + allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()} + return "*" in allowed_callers or caller_plugin in allowed_callers + + def _validate_input_contract( + self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: + """校验服务入参契约。""" + schema = service_info.params_schema + if not schema: + return + + properties = schema.get("properties", {}) if isinstance(schema, dict) else {} + is_invocation_schema = "args" in properties or "kwargs" in properties + + if is_invocation_schema: + payload = {"args": list(args), "kwargs": kwargs} + self._validate_by_schema(payload, schema, path="params") + return + + if args: + raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数") + self._validate_by_schema(kwargs, schema, path="params") + + def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None: + """校验服务返回值契约。""" + if not service_info.return_schema: + return + self._validate_by_schema(value, service_info.return_schema, path="return") + + def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None: + """基于简化JSON-Schema校验数据。""" + expected_type = schema.get("type") + if expected_type: + self._validate_type(value, expected_type, path) + + enum_values = schema.get("enum") + if enum_values is not None and value not in enum_values: + raise ValueError(f"{path} 不在枚举范围内: {value}") + + if expected_type == "object": + properties = schema.get("properties", {}) + required = schema.get("required", []) + + for field in required: + if field not in value: + raise ValueError(f"{path}.{field} 为必填字段") + + for field, field_value in value.items(): + if field in properties: + self._validate_by_schema(field_value, properties[field], f"{path}.{field}") + elif schema.get("additionalProperties", True) is False: + raise ValueError(f"{path}.{field} 不允许额外字段") + + if expected_type == "array": + if item_schema := schema.get("items"): + for index, item in enumerate(value): + self._validate_by_schema(item, item_schema, f"{path}[{index}]") + + def _validate_type(self, value: Any, expected_type: str, path: str) -> None: + """校验基础类型。""" + type_checkers: Dict[str, Callable[[Any], bool]] = { + "string": lambda item: isinstance(item, str), + "number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool), + "integer": lambda item: isinstance(item, int) and not isinstance(item, bool), + "boolean": lambda item: isinstance(item, bool), + "object": lambda item: isinstance(item, dict), + "array": lambda item: isinstance(item, list), + "null": lambda item: item is None, + } + checker = type_checkers.get(expected_type) + if checker and not checker(value): + raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}") def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]: """解析服务全名。""" diff --git a/src/plugin_system/core/workflow_engine.py b/src/plugin_system/core/workflow_engine.py index 089d0498..1a3fbcf7 100644 --- a/src/plugin_system/core/workflow_engine.py +++ b/src/plugin_system/core/workflow_engine.py @@ -1,7 +1,8 @@ from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union +import asyncio +import inspect import time import uuid -import inspect from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType, MaiMessages @@ -95,7 +96,9 @@ class WorkflowEngine: except Exception as e: workflow_context.timings[stage_key] = time.perf_counter() - stage_start workflow_context.errors.append(f"{stage_key}: {e}") - logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True) + logger.error( + f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True + ) self._execution_history[workflow_context.trace_id]["status"] = "failed" self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() return ( @@ -144,11 +147,19 @@ class WorkflowEngine: step_timing_key = f"{stage.value}:{step_info.full_name}" step_start = time.perf_counter() + timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None try: - result = handler(context, message) - if inspect.isawaitable(result): - result = await result + if inspect.iscoroutinefunction(handler): + coroutine = handler(context, message) + result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine + else: + if timeout_seconds: + result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds) + else: + result = handler(context, message) + if inspect.isawaitable(result): + result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result context.timings[step_timing_key] = time.perf_counter() - step_start normalized_result = self._normalize_step_result(result) @@ -165,10 +176,30 @@ class WorkflowEngine: normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value) return normalized_result + except asyncio.TimeoutError: + context.timings[step_timing_key] = time.perf_counter() - step_start + timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms" + context.errors.append(f"{step_info.full_name}: {timeout_message}") + logger.error( + f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}" + ) + return WorkflowStepResult( + status="failed", + return_message=timeout_message, + diagnostics={ + "stage": stage.value, + "step": step_info.full_name, + "trace_id": context.trace_id, + "error_code": WorkflowErrorCode.STEP_TIMEOUT.value, + }, + ) + except Exception as e: context.timings[step_timing_key] = time.perf_counter() - step_start context.errors.append(f"{step_info.full_name}: {e}") - logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True) + logger.error( + f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True + ) return WorkflowStepResult( status="failed", return_message=str(e), diff --git a/src/prompt/prompt_manager.py b/src/prompt/prompt_manager.py index 8137c16b..45ef7188 100644 --- a/src/prompt/prompt_manager.py +++ b/src/prompt/prompt_manager.py @@ -117,7 +117,7 @@ class PromptManager: def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None: """ 添加一个上下文构造函数 - + Args: name (str): 上下文名称 func (Callable[[str], str | Coroutine[Any, Any, str]]): 构造函数,接受 Prompt 名称作为参数,返回字符串或返回字符串的协程 @@ -144,7 +144,7 @@ class PromptManager: def get_prompt(self, prompt_name: str) -> Prompt: """ 获取指定名称的 Prompt 实例的克隆 - + Args: prompt_name (str): 要获取的 Prompt 名称 Returns: @@ -161,7 +161,7 @@ class PromptManager: async def render_prompt(self, prompt: Prompt) -> str: """ 渲染一个 Prompt 实例 - + Args: prompt (Prompt): 要渲染的 Prompt 实例 Returns: diff --git a/src/webui/api/planner.py b/src/webui/api/planner.py index b28e7a0d..981cc9d4 100644 --- a/src/webui/api/planner.py +++ b/src/webui/api/planner.py @@ -7,6 +7,7 @@ 2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容 3. 详情按需加载 """ + import json from pathlib import Path from typing import List, Dict, Optional @@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan") class ChatSummary(BaseModel): """聊天摘要 - 轻量级,不读取文件内容""" + chat_id: str plan_count: int latest_timestamp: float @@ -29,6 +31,7 @@ class ChatSummary(BaseModel): class PlanLogSummary(BaseModel): """规划日志摘要""" + chat_id: str timestamp: float filename: str @@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel): class PlanLogDetail(BaseModel): """规划日志详情""" + type: str chat_id: str timestamp: float @@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel): class PlannerOverview(BaseModel): """规划器总览 - 轻量级统计""" + total_chats: int total_plans: int chats: List[ChatSummary] @@ -61,6 +66,7 @@ class PlannerOverview(BaseModel): class PaginatedChatLogs(BaseModel): """分页的聊天日志列表""" + data: List[PlanLogSummary] total: int page: int @@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel): def parse_timestamp_from_filename(filename: str) -> float: """从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220""" try: - timestamp_str = filename.split('_')[0] + timestamp_str = filename.split("_")[0] # 时间戳是毫秒级,需要转换为秒 return float(timestamp_str) / 1000 except (ValueError, IndexError): @@ -86,41 +92,39 @@ async def get_planner_overview(): """ if not PLAN_LOG_DIR.exists(): return PlannerOverview(total_chats=0, total_plans=0, chats=[]) - + chats = [] total_plans = 0 - + for chat_dir in PLAN_LOG_DIR.iterdir(): if not chat_dir.is_dir(): continue - + # 只统计json文件数量 json_files = list(chat_dir.glob("*.json")) plan_count = len(json_files) total_plans += plan_count - + if plan_count == 0: continue - + # 从文件名获取最新时间戳 latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name)) latest_timestamp = parse_timestamp_from_filename(latest_file.name) - - chats.append(ChatSummary( - chat_id=chat_dir.name, - plan_count=plan_count, - latest_timestamp=latest_timestamp, - latest_filename=latest_file.name - )) - + + chats.append( + ChatSummary( + chat_id=chat_dir.name, + plan_count=plan_count, + latest_timestamp=latest_timestamp, + latest_filename=latest_file.name, + ) + ) + # 按最新时间戳排序 chats.sort(key=lambda x: x.latest_timestamp, reverse=True) - - return PlannerOverview( - total_chats=len(chats), - total_plans=total_plans, - chats=chats - ) + + return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats) @router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs) @@ -128,7 +132,7 @@ async def get_chat_plan_logs( chat_id: str, page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), - search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容") + search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"), ): """ 获取指定聊天的规划日志列表(分页) @@ -137,73 +141,69 @@ async def get_chat_plan_logs( """ chat_dir = PLAN_LOG_DIR / chat_id if not chat_dir.exists(): - return PaginatedChatLogs( - data=[], total=0, page=page, page_size=page_size, chat_id=chat_id - ) - + return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id) + # 先获取所有文件并按时间戳排序 json_files = list(chat_dir.glob("*.json")) json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True) - + # 如果有搜索关键词,需要过滤文件 if search: search_lower = search.lower() filtered_files = [] for log_file in json_files: try: - with open(log_file, 'r', encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: data = json.load(f) - prompt = data.get('prompt', '') + prompt = data.get("prompt", "") if search_lower in prompt.lower(): filtered_files.append(log_file) except Exception: continue json_files = filtered_files - + total = len(json_files) - + # 分页 - 只读取当前页的文件 offset = (page - 1) * page_size - page_files = json_files[offset:offset + page_size] - + page_files = json_files[offset : offset + page_size] + logs = [] for log_file in page_files: try: - with open(log_file, 'r', encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: data = json.load(f) - reasoning = data.get('reasoning', '') - actions = data.get('actions', []) - action_types = [a.get('action_type', '') for a in actions if a.get('action_type')] - logs.append(PlanLogSummary( - chat_id=data.get('chat_id', chat_id), - timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)), - filename=log_file.name, - action_count=len(actions), - action_types=action_types, - total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0), - llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0), - reasoning_preview=reasoning[:100] if reasoning else '' - )) + reasoning = data.get("reasoning", "") + actions = data.get("actions", []) + action_types = [a.get("action_type", "") for a in actions if a.get("action_type")] + logs.append( + PlanLogSummary( + chat_id=data.get("chat_id", chat_id), + timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)), + filename=log_file.name, + action_count=len(actions), + action_types=action_types, + total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0), + llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0), + reasoning_preview=reasoning[:100] if reasoning else "", + ) + ) except Exception: # 文件读取失败时使用文件名信息 - logs.append(PlanLogSummary( - chat_id=chat_id, - timestamp=parse_timestamp_from_filename(log_file.name), - filename=log_file.name, - action_count=0, - action_types=[], - total_plan_ms=0, - llm_duration_ms=0, - reasoning_preview='[读取失败]' - )) - - return PaginatedChatLogs( - data=logs, - total=total, - page=page, - page_size=page_size, - chat_id=chat_id - ) + logs.append( + PlanLogSummary( + chat_id=chat_id, + timestamp=parse_timestamp_from_filename(log_file.name), + filename=log_file.name, + action_count=0, + action_types=[], + total_plan_ms=0, + llm_duration_ms=0, + reasoning_preview="[读取失败]", + ) + ) + + return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id) @router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail) @@ -212,22 +212,23 @@ async def get_log_detail(chat_id: str, filename: str): log_file = PLAN_LOG_DIR / chat_id / filename if not log_file.exists(): raise HTTPException(status_code=404, detail="日志文件不存在") - + try: - with open(log_file, 'r', encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: data = json.load(f) return PlanLogDetail(**data) except Exception as e: - raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e # ========== 兼容旧接口 ========== + @router.get("/stats") async def get_planner_stats(): """获取规划器统计信息 - 兼容旧接口""" overview = await get_planner_overview() - + # 获取最近10条计划的摘要 recent_plans = [] for chat in overview.chats[:5]: # 从最近5个聊天中获取 @@ -236,17 +237,17 @@ async def get_planner_stats(): recent_plans.extend(chat_logs.data) except Exception: continue - + # 按时间排序取前10 recent_plans.sort(key=lambda x: x.timestamp, reverse=True) recent_plans = recent_plans[:10] - + return { "total_chats": overview.total_chats, "total_plans": overview.total_plans, "avg_plan_time_ms": 0, "avg_llm_time_ms": 0, - "recent_plans": recent_plans + "recent_plans": recent_plans, } @@ -258,44 +259,43 @@ async def get_chat_list(): @router.get("/all-logs") -async def get_all_logs( - page: int = Query(1, ge=1), - page_size: int = Query(20, ge=1, le=100) -): +async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)): """获取所有规划日志 - 兼容旧接口""" if not PLAN_LOG_DIR.exists(): return {"data": [], "total": 0, "page": page, "page_size": page_size} - + # 收集所有文件 all_files = [] for chat_dir in PLAN_LOG_DIR.iterdir(): if chat_dir.is_dir(): for log_file in chat_dir.glob("*.json"): all_files.append((chat_dir.name, log_file)) - + # 按时间戳排序 all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True) - + total = len(all_files) offset = (page - 1) * page_size - page_files = all_files[offset:offset + page_size] - + page_files = all_files[offset : offset + page_size] + logs = [] for chat_id, log_file in page_files: try: - with open(log_file, 'r', encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: data = json.load(f) - reasoning = data.get('reasoning', '') - logs.append({ - "chat_id": data.get('chat_id', chat_id), - "timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)), - "filename": log_file.name, - "action_count": len(data.get('actions', [])), - "total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0), - "llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0), - "reasoning_preview": reasoning[:100] if reasoning else '' - }) + reasoning = data.get("reasoning", "") + logs.append( + { + "chat_id": data.get("chat_id", chat_id), + "timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)), + "filename": log_file.name, + "action_count": len(data.get("actions", [])), + "total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0), + "llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0), + "reasoning_preview": reasoning[:100] if reasoning else "", + } + ) except Exception: continue - - return {"data": logs, "total": total, "page": page, "page_size": page_size} \ No newline at end of file + + return {"data": logs, "total": total, "page": page, "page_size": page_size} diff --git a/src/webui/api/replier.py b/src/webui/api/replier.py index 3ea71286..0643ceb4 100644 --- a/src/webui/api/replier.py +++ b/src/webui/api/replier.py @@ -7,6 +7,7 @@ 2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容 3. 详情按需加载 """ + import json from pathlib import Path from typing import List, Dict, Optional @@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply") class ReplierChatSummary(BaseModel): """聊天摘要 - 轻量级,不读取文件内容""" + chat_id: str reply_count: int latest_timestamp: float @@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel): class ReplyLogSummary(BaseModel): """回复日志摘要""" + chat_id: str timestamp: float filename: str @@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel): class ReplyLogDetail(BaseModel): """回复日志详情""" + type: str chat_id: str timestamp: float @@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel): class ReplierOverview(BaseModel): """回复器总览 - 轻量级统计""" + total_chats: int total_replies: int chats: List[ReplierChatSummary] @@ -64,6 +69,7 @@ class ReplierOverview(BaseModel): class PaginatedReplyLogs(BaseModel): """分页的回复日志列表""" + data: List[ReplyLogSummary] total: int page: int @@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel): def parse_timestamp_from_filename(filename: str) -> float: """从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220""" try: - timestamp_str = filename.split('_')[0] + timestamp_str = filename.split("_")[0] # 时间戳是毫秒级,需要转换为秒 return float(timestamp_str) / 1000 except (ValueError, IndexError): @@ -89,41 +95,39 @@ async def get_replier_overview(): """ if not REPLY_LOG_DIR.exists(): return ReplierOverview(total_chats=0, total_replies=0, chats=[]) - + chats = [] total_replies = 0 - + for chat_dir in REPLY_LOG_DIR.iterdir(): if not chat_dir.is_dir(): continue - + # 只统计json文件数量 json_files = list(chat_dir.glob("*.json")) reply_count = len(json_files) total_replies += reply_count - + if reply_count == 0: continue - + # 从文件名获取最新时间戳 latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name)) latest_timestamp = parse_timestamp_from_filename(latest_file.name) - - chats.append(ReplierChatSummary( - chat_id=chat_dir.name, - reply_count=reply_count, - latest_timestamp=latest_timestamp, - latest_filename=latest_file.name - )) - + + chats.append( + ReplierChatSummary( + chat_id=chat_dir.name, + reply_count=reply_count, + latest_timestamp=latest_timestamp, + latest_filename=latest_file.name, + ) + ) + # 按最新时间戳排序 chats.sort(key=lambda x: x.latest_timestamp, reverse=True) - - return ReplierOverview( - total_chats=len(chats), - total_replies=total_replies, - chats=chats - ) + + return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats) @router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs) @@ -131,7 +135,7 @@ async def get_chat_reply_logs( chat_id: str, page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), - search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容") + search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"), ): """ 获取指定聊天的回复日志列表(分页) @@ -140,71 +144,67 @@ async def get_chat_reply_logs( """ chat_dir = REPLY_LOG_DIR / chat_id if not chat_dir.exists(): - return PaginatedReplyLogs( - data=[], total=0, page=page, page_size=page_size, chat_id=chat_id - ) - + return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id) + # 先获取所有文件并按时间戳排序 json_files = list(chat_dir.glob("*.json")) json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True) - + # 如果有搜索关键词,需要过滤文件 if search: search_lower = search.lower() filtered_files = [] for log_file in json_files: try: - with open(log_file, 'r', encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: data = json.load(f) - prompt = data.get('prompt', '') + prompt = data.get("prompt", "") if search_lower in prompt.lower(): filtered_files.append(log_file) except Exception: continue json_files = filtered_files - + total = len(json_files) - + # 分页 - 只读取当前页的文件 offset = (page - 1) * page_size - page_files = json_files[offset:offset + page_size] - + page_files = json_files[offset : offset + page_size] + logs = [] for log_file in page_files: try: - with open(log_file, 'r', encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: data = json.load(f) - output = data.get('output', '') - logs.append(ReplyLogSummary( - chat_id=data.get('chat_id', chat_id), - timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)), - filename=log_file.name, - model=data.get('model', ''), - success=data.get('success', True), - llm_ms=data.get('timing', {}).get('llm_ms', 0), - overall_ms=data.get('timing', {}).get('overall_ms', 0), - output_preview=output[:100] if output else '' - )) + output = data.get("output", "") + logs.append( + ReplyLogSummary( + chat_id=data.get("chat_id", chat_id), + timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)), + filename=log_file.name, + model=data.get("model", ""), + success=data.get("success", True), + llm_ms=data.get("timing", {}).get("llm_ms", 0), + overall_ms=data.get("timing", {}).get("overall_ms", 0), + output_preview=output[:100] if output else "", + ) + ) except Exception: # 文件读取失败时使用文件名信息 - logs.append(ReplyLogSummary( - chat_id=chat_id, - timestamp=parse_timestamp_from_filename(log_file.name), - filename=log_file.name, - model='', - success=False, - llm_ms=0, - overall_ms=0, - output_preview='[读取失败]' - )) - - return PaginatedReplyLogs( - data=logs, - total=total, - page=page, - page_size=page_size, - chat_id=chat_id - ) + logs.append( + ReplyLogSummary( + chat_id=chat_id, + timestamp=parse_timestamp_from_filename(log_file.name), + filename=log_file.name, + model="", + success=False, + llm_ms=0, + overall_ms=0, + output_preview="[读取失败]", + ) + ) + + return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id) @router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail) @@ -213,35 +213,36 @@ async def get_reply_log_detail(chat_id: str, filename: str): log_file = REPLY_LOG_DIR / chat_id / filename if not log_file.exists(): raise HTTPException(status_code=404, detail="日志文件不存在") - + try: - with open(log_file, 'r', encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: data = json.load(f) return ReplyLogDetail( - type=data.get('type', 'reply'), - chat_id=data.get('chat_id', chat_id), - timestamp=data.get('timestamp', 0), - prompt=data.get('prompt', ''), - output=data.get('output', ''), - processed_output=data.get('processed_output', []), - model=data.get('model', ''), - reasoning=data.get('reasoning', ''), - think_level=data.get('think_level', 0), - timing=data.get('timing', {}), - error=data.get('error'), - success=data.get('success', True) + type=data.get("type", "reply"), + chat_id=data.get("chat_id", chat_id), + timestamp=data.get("timestamp", 0), + prompt=data.get("prompt", ""), + output=data.get("output", ""), + processed_output=data.get("processed_output", []), + model=data.get("model", ""), + reasoning=data.get("reasoning", ""), + think_level=data.get("think_level", 0), + timing=data.get("timing", {}), + error=data.get("error"), + success=data.get("success", True), ) except Exception as e: - raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e # ========== 兼容接口 ========== + @router.get("/stats") async def get_replier_stats(): """获取回复器统计信息""" overview = await get_replier_overview() - + # 获取最近10条回复的摘要 recent_replies = [] for chat in overview.chats[:5]: # 从最近5个聊天中获取 @@ -250,15 +251,15 @@ async def get_replier_stats(): recent_replies.extend(chat_logs.data) except Exception: continue - + # 按时间排序取前10 recent_replies.sort(key=lambda x: x.timestamp, reverse=True) recent_replies = recent_replies[:10] - + return { "total_chats": overview.total_chats, "total_replies": overview.total_replies, - "recent_replies": recent_replies + "recent_replies": recent_replies, } @@ -266,4 +267,4 @@ async def get_replier_stats(): async def get_replier_chat_list(): """获取所有聊天ID列表""" overview = await get_replier_overview() - return [chat.chat_id for chat in overview.chats] \ No newline at end of file + return [chat.chat_id for chat in overview.chats] diff --git a/src/webui/dependencies.py b/src/webui/dependencies.py index a7395522..b7f14348 100644 --- a/src/webui/dependencies.py +++ b/src/webui/dependencies.py @@ -1,6 +1,6 @@ from typing import Optional -from fastapi import Depends, Cookie, Header, Request, HTTPException -from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit +from fastapi import Depends, Cookie, Header, Request +from .core import get_current_token, get_token_manager, check_auth_rate_limit async def require_auth( diff --git a/src/webui/middleware/anti_crawler.py b/src/webui/middleware/anti_crawler.py index b489179b..8cb0335f 100644 --- a/src/webui/middleware/anti_crawler.py +++ b/src/webui/middleware/anti_crawler.py @@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = { # loose: 宽松模式(较宽松的检测,较高的频率限制) # basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP) + # IP白名单配置(从配置文件读取,逗号分隔) # 支持格式: # - 精确IP:127.0.0.1, 192.168.1.100 @@ -151,7 +152,7 @@ def _parse_allowed_ips(ip_string: str) -> list: ip_entry = ip_entry.strip() # 去除空格 if not ip_entry: continue - + # 跳过注释行(以#开头) if ip_entry.startswith("#"): continue @@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]: def _get_anti_crawler_config(): """获取防爬虫配置""" from src.config.config import global_config + return { - 'mode': global_config.webui.anti_crawler_mode, - 'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips), - 'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies), - 'trust_xff': global_config.webui.trust_xff + "mode": global_config.webui.anti_crawler_mode, + "allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips), + "trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies), + "trust_xff": global_config.webui.trust_xff, } + # 初始化配置(将在模块加载时执行) _config = _get_anti_crawler_config() -ANTI_CRAWLER_MODE = _config['mode'] -ALLOWED_IPS = _config['allowed_ips'] -TRUSTED_PROXIES = _config['trusted_proxies'] -TRUST_XFF = _config['trust_xff'] +ANTI_CRAWLER_MODE = _config["mode"] +ALLOWED_IPS = _config["allowed_ips"] +TRUSTED_PROXIES = _config["trusted_proxies"] +TRUST_XFF = _config["trust_xff"] def _get_mode_config(mode: str) -> dict: diff --git a/src/webui/routers/annual_report.py b/src/webui/routers/annual_report.py index a277c070..b9abc54d 100644 --- a/src/webui/routers/annual_report.py +++ b/src/webui/routers/annual_report.py @@ -333,7 +333,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData: statement = select(func.count()).where( col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.is_at) == True, + col(Messages.is_at), ) data.at_count = int(session.exec(statement).first() or 0) @@ -342,7 +342,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData: statement = select(func.count()).where( col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.is_mentioned) == True, + col(Messages.is_mentioned), ) data.mentioned_count = int(session.exec(statement).first() or 0) @@ -552,7 +552,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: # 1. 表情包之王 - 使用次数最多的表情包 with get_db_session() as session: statement = ( - select(Images).where(col(Images.is_registered) == True).order_by(desc(col(Images.query_count))).limit(5) + select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5) ) top_emojis = session.exec(statement).all() if top_emojis: @@ -636,7 +636,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: statement = select(func.count()).where( col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.is_picture) == True, + col(Messages.is_picture), ) data.image_processed_count = int(session.exec(statement).first() or 0) @@ -781,12 +781,12 @@ async def get_achievements(year: int = 2025) -> AchievementData: # 1. 新学到的黑话数量 # Jargon 表没有时间字段,统计全部已确认的黑话 with get_db_session() as session: - statement = select(func.count()).where(col(Jargon.is_jargon) == True) + statement = select(func.count()).where(col(Jargon.is_jargon)) data.new_jargon_count = int(session.exec(statement).first() or 0) # 2. 代表性黑话示例 with get_db_session() as session: - statement = select(Jargon).where(col(Jargon.is_jargon) == True).order_by(desc(col(Jargon.count))).limit(5) + statement = select(Jargon).where(col(Jargon.is_jargon)).order_by(desc(col(Jargon.count))).limit(5) jargon_samples = session.exec(statement).all() data.sample_jargons = [ { diff --git a/src/webui/routers/emoji.py b/src/webui/routers/emoji.py index 98e8c588..4f882d09 100644 --- a/src/webui/routers/emoji.py +++ b/src/webui/routers/emoji.py @@ -532,7 +532,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz .select_from(Images) .where( col(Images.image_type) == ImageType.EMOJI, - col(Images.is_registered) == True, + col(Images.is_registered), ) ) banned_statement = ( @@ -540,7 +540,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz .select_from(Images) .where( col(Images.image_type) == ImageType.EMOJI, - col(Images.is_banned) == True, + col(Images.is_banned), ) ) @@ -1283,7 +1283,7 @@ async def preheat_thumbnail_cache( select(Images) .where( col(Images.image_type) == ImageType.EMOJI, - col(Images.is_banned) == False, + col(Images.is_banned).is_(False), ) .order_by(col(Images.query_count).desc()) .limit(limit * 2) diff --git a/src/webui/routers/jargon.py b/src/webui/routers/jargon.py index d1f97181..bee2d276 100644 --- a/src/webui/routers/jargon.py +++ b/src/webui/routers/jargon.py @@ -315,15 +315,15 @@ async def get_jargon_stats(): total = session.exec(select(fn.count()).select_from(Jargon)).one() confirmed_jargon = session.exec( - select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True) + select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon)) ).one() confirmed_not_jargon = session.exec( - select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False) + select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False)) ).one() pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one() complete_count = session.exec( - select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True) + select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete)) ).one() chat_count = session.exec( diff --git a/src/webui/routers/knowledge.py b/src/webui/routers/knowledge.py index 5959e0ac..e43bdb02 100644 --- a/src/webui/routers/knowledge.py +++ b/src/webui/routers/knowledge.py @@ -17,36 +17,36 @@ _paragraph_store_cache = None def _get_paragraph_store(): """延迟加载段落 embedding store(只读模式,轻量级) - + Returns: EmbeddingStore | None: 如果配置启用则返回store,否则返回None """ # 检查配置是否启用 if not global_config.webui.enable_paragraph_content: return None - + global _paragraph_store_cache if _paragraph_store_cache is not None: return _paragraph_store_cache - + try: from src.chat.knowledge.embedding_store import EmbeddingStore import os - + # 获取数据路径 current_dir = os.path.dirname(os.path.abspath(__file__)) root_path = os.path.abspath(os.path.join(current_dir, "..", "..")) embedding_dir = os.path.join(root_path, "data/embedding") - + # 只加载段落 embedding store(轻量级) paragraph_store = EmbeddingStore( namespace="paragraph", dir_path=embedding_dir, max_workers=1, # 只读不需要多线程 - chunk_size=100 + chunk_size=100, ) paragraph_store.load_from_file() - + _paragraph_store_cache = paragraph_store logger.info(f"成功加载段落 embedding store,包含 {len(paragraph_store.store)} 个段落") return paragraph_store @@ -57,10 +57,10 @@ def _get_paragraph_store(): def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]: """从 embedding store 获取段落完整内容 - + Args: node_id: 段落节点ID,格式为 'paragraph-{hash}' - + Returns: tuple[str | None, bool]: (段落完整内容或None, 是否启用了功能) """ @@ -69,12 +69,12 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]: if paragraph_store is None: # 功能未启用 return None, False - + # 从 store 中获取完整内容 paragraph_item = paragraph_store.store.get(node_id) if paragraph_item is not None: # paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本 - content: str = getattr(paragraph_item, 'str', '') + content: str = getattr(paragraph_item, "str", "") if content: return content, True return None, True @@ -156,14 +156,18 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph: node_data = graph[node_id] # 节点类型: "ent" -> "entity", "pg" -> "paragraph" node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" - + # 对于段落节点,尝试从 embedding store 获取完整内容 if node_type == "paragraph": full_content, _ = _get_paragraph_content(node_id) - content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) + content = ( + full_content + if full_content is not None + else (node_data["content"] if "content" in node_data else node_id) + ) else: content = node_data["content"] if "content" in node_data else node_id - + create_time = node_data["create_time"] if "create_time" in node_data else None nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time)) @@ -245,14 +249,18 @@ async def get_knowledge_graph( try: node_data = graph[node_id] node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" - + # 对于段落节点,尝试从 embedding store 获取完整内容 if node_type_val == "paragraph": full_content, _ = _get_paragraph_content(node_id) - content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) + content = ( + full_content + if full_content is not None + else (node_data["content"] if "content" in node_data else node_id) + ) else: content = node_data["content"] if "content" in node_data else node_id - + create_time = node_data["create_time"] if "create_time" in node_data else None nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time)) @@ -368,11 +376,15 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo try: node_data = graph[node_id] node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" - + # 对于段落节点,尝试从 embedding store 获取完整内容 if node_type == "paragraph": full_content, _ = _get_paragraph_content(node_id) - content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) + content = ( + full_content + if full_content is not None + else (node_data["content"] if "content" in node_data else node_id) + ) else: content = node_data["content"] if "content" in node_data else node_id diff --git a/src/webui/routers/person.py b/src/webui/routers/person.py index d1b86a02..0b896f69 100644 --- a/src/webui/routers/person.py +++ b/src/webui/routers/person.py @@ -370,7 +370,7 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori with get_db_session() as session: total = len(session.exec(select(PersonInfo.id)).all()) - known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known) == True)).all()) + known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known))).all()) unknown = total - known # 按平台统计 diff --git a/src/webui/routers/plugin.py b/src/webui/routers/plugin.py index c3ddc956..ab6ca479 100644 --- a/src/webui/routers/plugin.py +++ b/src/webui/routers/plugin.py @@ -1762,7 +1762,7 @@ async def update_plugin_config_raw( try: tomlkit.loads(request.config) except Exception as e: - raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") + raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e # 备份旧配置 import shutil diff --git a/src/webui/services/git_mirror_service.py b/src/webui/services/git_mirror_service.py index 83e6be01..a6a9b1bc 100644 --- a/src/webui/services/git_mirror_service.py +++ b/src/webui/services/git_mirror_service.py @@ -659,4 +659,4 @@ def get_git_mirror_service() -> GitMirrorService: global _git_mirror_service if _git_mirror_service is None: _git_mirror_service = GitMirrorService() - return _git_mirror_service \ No newline at end of file + return _git_mirror_service