MaiBot/plugins/MaiBot_MCPBridgePlugin/mcp_client.py

1486 lines
56 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
MCP 客户端封装模块
负责与 MCP 服务器建立连接、获取工具列表、执行工具调用
v1.7.0 稳定性优化:
- 断路器模式:连续失败 5 次后熔断60 秒后试探恢复
- 熔断期间快速失败,避免等待超时
- 连接成功时自动重置断路器
v1.5.2 性能优化:
- 智能心跳间隔:根据服务器稳定性动态调整心跳频率
- 稳定服务器逐渐增加间隔(最高 3x减少不必要的检测
- 断开的服务器使用较短间隔快速重连
v1.1.0 新增功能:
- 调用统计(次数、成功率、耗时)
- 心跳检测
- 自动重连
- 更好的错误处理
v1.2.0 新增功能:
- Resources 支持(资源读取)
- Prompts 支持(提示模板)
- 新增配置项: enable_resources, enable_prompts
"""
import asyncio
import time
import logging
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
# Prefer MaiBot's own logger; when it cannot be imported, fall back to the
# standard logging module so this file still works outside the MaiBot tree.
try:
    from src.common.logger import get_logger

    logger = get_logger("mcp_client")
except ImportError:
    # Fallback: plain standard-library logger.
    logger = logging.getLogger("mcp_client")
    # Only attach a handler when none is present, so repeated imports do not
    # produce duplicated output.
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
class TransportType(Enum):
    """Transport used to talk to an MCP server."""

    # Local process communication over stdio.
    STDIO = "stdio"
    # Server-Sent Events (legacy HTTP transport).
    SSE = "sse"
    # HTTP Streamable (newer transport, recommended).
    HTTP = "http"
    # Alias of HTTP Streamable.
    STREAMABLE_HTTP = "streamable_http"
@dataclass
class MCPToolInfo:
    """Description of one tool exposed by an MCP server."""

    # Tool name as reported by the server.
    name: str
    # Human-readable description of the tool.
    description: str
    # Schema describing the tool's arguments.
    input_schema: Dict[str, Any]
    # Name of the server that owns the tool.
    server_name: str
@dataclass
class MCPResourceInfo:
    """Description of one resource exposed by an MCP server."""

    # Resource URI.
    uri: str
    # Display name of the resource.
    name: str
    # Human-readable description.
    description: str
    # MIME type, or None when not provided.
    mime_type: Optional[str]
    # Name of the server that owns the resource.
    server_name: str
@dataclass
class MCPPromptInfo:
    """Description of one prompt template exposed by an MCP server."""

    # Prompt template name.
    name: str
    # Human-readable description.
    description: str
    # Argument descriptors: [{name, description, required}].
    arguments: List[Dict[str, Any]]
    # Name of the server that owns the prompt template.
    server_name: str
@dataclass
class MCPServerConfig:
    """Configuration of a single MCP server."""

    # Server name.
    name: str
    # Whether this server should be connected at all.
    enabled: bool = True
    # Which transport to use for the connection.
    transport: TransportType = TransportType.STDIO

    # --- stdio settings ---
    # Command to launch.
    command: str = ""
    # Command-line arguments for the command.
    args: List[str] = field(default_factory=list)
    # Environment variables for the launched process.
    env: Dict[str, str] = field(default_factory=dict)

    # --- http / sse settings ---
    # Server URL.
    url: str = ""
    # v1.4.2: request headers, used for authentication.
    headers: Dict[str, str] = field(default_factory=dict)
@dataclass
class MCPCallResult:
    """Outcome of an MCP call (tool call, resource read or prompt fetch)."""

    # True when the call succeeded.
    success: bool
    # Returned content; None on failure.
    content: Any
    # Error message; None on success.
    error: Optional[str] = None
    # How long the call took, in milliseconds.
    duration_ms: float = 0.0
    # v1.7.0: True when the circuit breaker rejected the call.
    circuit_broken: bool = False
class CircuitState(Enum):
    """States of the circuit breaker."""

    # Normal operation: requests are allowed.
    CLOSED = "closed"
    # Tripped: requests are rejected.
    OPEN = "open"
    # Half-open: a small number of probe requests are allowed.
    HALF_OPEN = "half_open"
@dataclass
class CircuitBreaker:
    """Circuit breaker (v1.7.0) that stops repeated requests to a failing server.

    State transitions:
    - CLOSED -> OPEN: consecutive failures reach ``failure_threshold``
    - OPEN -> HALF_OPEN: ``recovery_timeout`` seconds have passed since the last failure
    - HALF_OPEN -> CLOSED: a probe request succeeds
    - HALF_OPEN -> OPEN: a probe request fails
    """

    # --- configuration ---
    # Number of consecutive failures that trips the breaker.
    failure_threshold: int = 5
    # Seconds to wait after the last failure before probing again.
    recovery_timeout: float = 60.0
    # Maximum number of probe calls allowed while half-open.
    half_open_max_calls: int = 1

    # --- runtime state ---
    state: CircuitState = field(default=CircuitState.CLOSED)
    failure_count: int = 0
    success_count: int = 0
    # Wall-clock time of the most recent failure; 0.0 means "no failure recorded".
    last_failure_time: float = 0.0
    last_state_change: float = field(default_factory=time.time)
    # Probe calls issued in the current half-open period.
    half_open_calls: int = 0

    def can_execute(self) -> Tuple[bool, Optional[str]]:
        """Check whether a request may be executed right now.

        An OPEN breaker whose recovery timeout has elapsed is moved to HALF_OPEN
        as a side effect. The probe counter is not advanced here; the code that
        actually issues the probe increments ``half_open_calls``.

        Returns:
            ``(allowed, reject_reason)``; ``reject_reason`` is None when allowed.
        """
        now = time.time()

        if self.state == CircuitState.CLOSED:
            return True, None

        if self.state == CircuitState.OPEN:
            elapsed = now - self.last_failure_time
            if elapsed >= self.recovery_timeout:
                # Cool-down is over: let a probe through.
                self._transition_to(CircuitState.HALF_OPEN)
                return True, None
            remaining = self.recovery_timeout - elapsed
            return False, f"断路器熔断中,{remaining:.0f}秒后重试"

        if self.state == CircuitState.HALF_OPEN:
            # Allow a probe only while there is quota left.
            if self.half_open_calls < self.half_open_max_calls:
                return True, None
            return False, "断路器半开状态,等待试探结果"

        return True, None

    def record_success(self) -> None:
        """Record a successful call."""
        self.success_count += 1
        if self.state == CircuitState.HALF_OPEN:
            # The probe succeeded: go back to normal operation.
            self._transition_to(CircuitState.CLOSED)
            logger.info("断路器恢复正常(试探成功)")
        elif self.state == CircuitState.CLOSED:
            # A success breaks the run of consecutive failures.
            self.failure_count = 0

    def record_failure(self) -> None:
        """Record a failed call, tripping the breaker when appropriate."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.state == CircuitState.HALF_OPEN:
            # The probe failed: trip again.
            self._transition_to(CircuitState.OPEN)
            logger.warning("断路器重新熔断(试探失败)")
        elif self.state == CircuitState.CLOSED:
            if self.failure_count >= self.failure_threshold:
                self._transition_to(CircuitState.OPEN)
                logger.warning(f"断路器熔断(连续失败 {self.failure_count} 次)")

    def _transition_to(self, new_state: CircuitState) -> None:
        """Switch to ``new_state`` and reset the counters tied to that state."""
        old_state = self.state
        self.state = new_state
        self.last_state_change = time.time()
        if new_state == CircuitState.CLOSED:
            self.failure_count = 0
            self.half_open_calls = 0
        elif new_state == CircuitState.HALF_OPEN:
            self.half_open_calls = 0
        logger.debug(f"断路器状态: {old_state.value} -> {new_state.value}")

    def reset(self) -> None:
        """Return the breaker to its pristine CLOSED state.

        All counters are cleared, including ``last_failure_time``, so that
        ``get_status`` does not report the age of a failure that predates the
        reset.
        """
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.half_open_calls = 0
        self.last_failure_time = 0.0
        self.last_state_change = time.time()

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of the breaker as a plain dict.

        ``time_since_last_failure`` is None when no failure has been recorded.
        """
        if self.last_failure_time > 0:
            since_failure = time.time() - self.last_failure_time
        else:
            since_failure = None
        return {
            "state": self.state.value,
            "failure_count": self.failure_count,
            "success_count": self.success_count,
            "failure_threshold": self.failure_threshold,
            "recovery_timeout": self.recovery_timeout,
            "time_since_last_failure": since_failure,
        }
@dataclass
class ToolCallStats:
    """Call counters and timing for a single tool."""

    tool_key: str
    total_calls: int = 0
    success_calls: int = 0
    failed_calls: int = 0
    # Accumulated duration of successful calls only.
    total_duration_ms: float = 0.0
    last_call_time: Optional[float] = None
    last_error: Optional[str] = None

    @property
    def success_rate(self) -> float:
        """Share of successful calls as a percentage (0-100); 0.0 before any call."""
        if not self.total_calls:
            return 0.0
        return (self.success_calls / self.total_calls) * 100

    @property
    def avg_duration_ms(self) -> float:
        """Mean duration of successful calls in milliseconds; 0.0 when there are none."""
        if not self.success_calls:
            return 0.0
        return self.total_duration_ms / self.success_calls

    def record_call(self, success: bool, duration_ms: float, error: Optional[str] = None) -> None:
        """Record one call.

        Args:
            success: Whether the call succeeded.
            duration_ms: Call duration; only added to the total for successes.
            error: Error message, stored as ``last_error`` for failures.
        """
        self.total_calls += 1
        self.last_call_time = time.time()
        if not success:
            self.failed_calls += 1
            self.last_error = error
            return
        self.success_calls += 1
        self.total_duration_ms += duration_ms

    def to_dict(self) -> Dict[str, Any]:
        """Return the statistics as a dict, with the derived values rounded to 2 places."""
        summary: Dict[str, Any] = {
            "tool_key": self.tool_key,
            "total_calls": self.total_calls,
            "success_calls": self.success_calls,
            "failed_calls": self.failed_calls,
        }
        summary["success_rate"] = round(self.success_rate, 2)
        summary["avg_duration_ms"] = round(self.avg_duration_ms, 2)
        summary["last_call_time"] = self.last_call_time
        summary["last_error"] = self.last_error
        return summary
@dataclass
class ServerStats:
    """Connection statistics for a single server."""

    server_name: str
    # Number of connects.
    connect_count: int = 0
    # Number of disconnects.
    disconnect_count: int = 0
    # Number of reconnects.
    reconnect_count: int = 0
    last_connect_time: Optional[float] = None
    last_disconnect_time: Optional[float] = None
    last_heartbeat_time: Optional[float] = None
    # Number of consecutive failures.
    consecutive_failures: int = 0

    def record_connect(self) -> None:
        """Count a connect, stamp its time and clear the failure streak."""
        self.connect_count += 1
        self.last_connect_time = time.time()
        self.consecutive_failures = 0

    def record_disconnect(self) -> None:
        """Count a disconnect and stamp its time."""
        self.disconnect_count += 1
        self.last_disconnect_time = time.time()

    def record_reconnect(self) -> None:
        """Count a reconnect and clear the failure streak."""
        self.reconnect_count += 1
        self.consecutive_failures = 0

    def record_failure(self) -> None:
        """Extend the streak of consecutive failures by one."""
        self.consecutive_failures += 1

    def record_heartbeat(self) -> None:
        """Stamp the time of the latest heartbeat."""
        self.last_heartbeat_time = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict."""
        field_names = (
            "server_name",
            "connect_count",
            "disconnect_count",
            "reconnect_count",
            "last_connect_time",
            "last_disconnect_time",
            "last_heartbeat_time",
            "consecutive_failures",
        )
        return {field_name: getattr(self, field_name) for field_name in field_names}
class MCPClientSession:
    """MCP client session that manages the connection to a single MCP server.

    Covers the connection lifecycle (stdio / SSE / streamable HTTP), discovery
    of tools, resources and prompts, tool calls guarded by a circuit breaker,
    and per-tool / per-server statistics.
    """

    def __init__(self, config: MCPServerConfig, call_timeout: float = 60.0):
        """Create a disconnected session.

        Args:
            config: Server configuration (transport, command/url, headers, ...).
            call_timeout: Timeout in seconds applied to individual requests.
        """
        self.config = config
        self.call_timeout = call_timeout
        self._session = None
        self._read_stream = None
        self._write_stream = None
        # Not assigned anywhere else in this class.
        self._process: Optional[asyncio.subprocess.Process] = None
        # Async context managers entered by the _connect_* methods. They are
        # created here so that _cleanup can always inspect them.
        self._session_context = None
        self._stdio_context = None
        self._sse_context = None
        self._http_context = None
        self._get_session_id = None
        self._tools: List[MCPToolInfo] = []
        self._resources: List[MCPResourceInfo] = []  # v1.2.0: Resources support
        self._prompts: List[MCPPromptInfo] = []  # v1.2.0: Prompts support
        self._connected = False
        self._lock = asyncio.Lock()
        # Capability flags (a server may not support some features).
        self._supports_resources: bool = False
        self._supports_prompts: bool = False
        # Statistics.
        self.stats = ServerStats(server_name=config.name)
        self._tool_stats: Dict[str, ToolCallStats] = {}
        # v1.7.0: circuit breaker.
        self._circuit_breaker = CircuitBreaker()

    @property
    def is_connected(self) -> bool:
        """Whether the session is currently marked as connected."""
        return self._connected

    @property
    def tools(self) -> List[MCPToolInfo]:
        """Copy of the discovered tool list."""
        return self._tools.copy()

    @property
    def resources(self) -> List[MCPResourceInfo]:
        """v1.2.0: copy of the discovered resource list."""
        return self._resources.copy()

    @property
    def prompts(self) -> List[MCPPromptInfo]:
        """v1.2.0: copy of the discovered prompt template list."""
        return self._prompts.copy()

    @property
    def supports_resources(self) -> bool:
        """v1.2.0: whether the server supports Resources."""
        return self._supports_resources

    @property
    def supports_prompts(self) -> bool:
        """v1.2.0: whether the server supports Prompts."""
        return self._supports_prompts

    @property
    def server_name(self) -> str:
        """Name of the server, taken from the configuration."""
        return self.config.name

    def get_tool_stats(self, tool_name: str) -> Optional[ToolCallStats]:
        """Return the statistics of one tool, or None when the tool is unknown."""
        return self._tool_stats.get(tool_name)

    def get_circuit_breaker_status(self) -> Dict[str, Any]:
        """v1.7.0: return the circuit breaker status."""
        return self._circuit_breaker.get_status()

    def reset_circuit_breaker(self) -> None:
        """v1.7.0: reset the circuit breaker."""
        self._circuit_breaker.reset()
        logger.info(f"[{self.server_name}] 断路器已重置")

    def get_all_tool_stats(self) -> Dict[str, ToolCallStats]:
        """Return a shallow copy of the per-tool statistics."""
        return self._tool_stats.copy()

    async def connect(self) -> bool:
        """Connect to the MCP server using the configured transport.

        Returns:
            True when connected (or already connected), False otherwise.
        """
        async with self._lock:
            if self._connected:
                return True
            try:
                success = False
                if self.config.transport == TransportType.STDIO:
                    success = await self._connect_stdio()
                elif self.config.transport == TransportType.SSE:
                    success = await self._connect_sse()
                elif self.config.transport in (TransportType.HTTP, TransportType.STREAMABLE_HTTP):
                    success = await self._connect_http()
                else:
                    logger.error(f"[{self.server_name}] 不支持的传输类型: {self.config.transport}")
                    return False
                if success:
                    self.stats.record_connect()
                    # v1.7.0: a successful connection resets the circuit breaker.
                    self._circuit_breaker.reset()
                else:
                    self.stats.record_failure()
                return success
            except Exception as e:
                logger.error(f"[{self.server_name}] 连接失败: {e}")
                self._connected = False
                self.stats.record_failure()
                return False

    async def _connect_stdio(self) -> bool:
        """Connect to the MCP server over stdio."""
        try:
            try:
                from mcp import ClientSession, StdioServerParameters
                from mcp.client.stdio import stdio_client
            except ImportError:
                logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp")
                return False
            server_params = StdioServerParameters(
                command=self.config.command,
                args=self.config.args,
                env=self.config.env if self.config.env else None,
            )
            self._stdio_context = stdio_client(server_params)
            self._read_stream, self._write_stream = await self._stdio_context.__aenter__()
            self._session_context = ClientSession(self._read_stream, self._write_stream)
            self._session = await self._session_context.__aenter__()
            await self._session.initialize()
            await self._fetch_tools()
            self._connected = True
            logger.info(f"[{self.server_name}] stdio 连接成功,发现 {len(self._tools)} 个工具")
            return True
        except Exception as e:
            logger.error(f"[{self.server_name}] stdio 连接失败: {e}")
            await self._cleanup()
            return False

    async def _connect_sse(self) -> bool:
        """Connect to the MCP server over SSE."""
        try:
            try:
                from mcp import ClientSession
                from mcp.client.sse import sse_client
            except ImportError:
                logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp")
                return False
            if not self.config.url:
                logger.error(f"[{self.server_name}] SSE 传输需要配置 url")
                return False
            logger.debug(f"[{self.server_name}] 正在连接 SSE MCP 服务器: {self.config.url}")
            # v1.4.2: optional headers for authentication.
            sse_kwargs = {
                "url": self.config.url,
                "timeout": 60.0,
                "sse_read_timeout": 300.0,
            }
            if self.config.headers:
                sse_kwargs["headers"] = self.config.headers
            self._sse_context = sse_client(**sse_kwargs)
            self._read_stream, self._write_stream = await self._sse_context.__aenter__()
            self._session_context = ClientSession(self._read_stream, self._write_stream)
            self._session = await self._session_context.__aenter__()
            await self._session.initialize()
            await self._fetch_tools()
            self._connected = True
            logger.info(f"[{self.server_name}] SSE 连接成功,发现 {len(self._tools)} 个工具")
            return True
        except Exception as e:
            logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
            import traceback

            logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
            await self._cleanup()
            return False

    async def _connect_http(self) -> bool:
        """Connect to the MCP server over HTTP Streamable."""
        try:
            try:
                from mcp import ClientSession
                from mcp.client.streamable_http import streamablehttp_client
            except ImportError:
                logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp")
                return False
            if not self.config.url:
                logger.error(f"[{self.server_name}] HTTP 传输需要配置 url")
                return False
            logger.debug(f"[{self.server_name}] 正在连接 HTTP MCP 服务器: {self.config.url}")
            # v1.4.2: optional headers for authentication.
            http_kwargs = {
                "url": self.config.url,
                "timeout": 60.0,
                "sse_read_timeout": 300.0,
            }
            if self.config.headers:
                http_kwargs["headers"] = self.config.headers
            self._http_context = streamablehttp_client(**http_kwargs)
            self._read_stream, self._write_stream, self._get_session_id = await self._http_context.__aenter__()
            self._session_context = ClientSession(self._read_stream, self._write_stream)
            self._session = await self._session_context.__aenter__()
            await self._session.initialize()
            await self._fetch_tools()
            self._connected = True
            logger.info(f"[{self.server_name}] HTTP 连接成功,发现 {len(self._tools)} 个工具")
            return True
        except Exception as e:
            logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
            import traceback

            logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
            await self._cleanup()
            return False

    async def _fetch_tools(self) -> None:
        """Fetch the server's tool list and make sure each tool has a stats entry.

        The request is bounded by ``call_timeout`` (like the resource and prompt
        listings) so that an unresponsive server cannot block ``connect``
        forever. Any failure, including a timeout, leaves the tool list empty.
        """
        if not self._session:
            return
        try:
            result = await asyncio.wait_for(self._session.list_tools(), timeout=self.call_timeout)
            self._tools = []
            for tool in result.tools:
                tool_info = MCPToolInfo(
                    name=tool.name,
                    description=tool.description or f"MCP tool: {tool.name}",
                    input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
                    server_name=self.server_name,
                )
                self._tools.append(tool_info)
                # Keep existing statistics across reconnects.
                if tool.name not in self._tool_stats:
                    self._tool_stats[tool.name] = ToolCallStats(tool_key=tool.name)
                logger.debug(f"[{self.server_name}] 发现工具: {tool.name}")
        except Exception as e:
            logger.error(f"[{self.server_name}] 获取工具列表失败: {e}")
            self._tools = []

    async def fetch_resources(self) -> bool:
        """v1.2.0: fetch the server's resource list.

        Returns:
            True on success; False on failure or when the server does not
            support resources.
        """
        if not self._session:
            return False
        try:
            result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
            self._resources = []
            for resource in result.resources:
                resource_info = MCPResourceInfo(
                    uri=str(resource.uri),
                    name=resource.name or str(resource.uri),
                    description=resource.description or "",
                    mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
                    server_name=self.server_name,
                )
                self._resources.append(resource_info)
                logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
            self._supports_resources = True
            logger.info(f"[{self.server_name}] 获取到 {len(self._resources)} 个资源")
            return True
        except Exception as e:
            # A server without resources support is not an error; log it quietly.
            error_str = str(e).lower()
            if "not supported" in error_str or "not implemented" in error_str or "method not found" in error_str:
                logger.debug(f"[{self.server_name}] 服务器不支持 Resources 功能")
            else:
                logger.warning(f"[{self.server_name}] 获取资源列表失败: {e}")
            self._supports_resources = False
            self._resources = []
            return False

    async def fetch_prompts(self) -> bool:
        """v1.2.0: fetch the server's prompt template list.

        Returns:
            True on success; False on failure or when the server does not
            support prompts.
        """
        if not self._session:
            return False
        try:
            result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
            self._prompts = []
            for prompt in result.prompts:
                # Flatten the argument descriptors into plain dicts.
                arguments = []
                if hasattr(prompt, "arguments") and prompt.arguments:
                    for arg in prompt.arguments:
                        arguments.append(
                            {
                                "name": arg.name,
                                "description": arg.description or "",
                                "required": arg.required if hasattr(arg, "required") else False,
                            }
                        )
                prompt_info = MCPPromptInfo(
                    name=prompt.name,
                    description=prompt.description or f"MCP prompt: {prompt.name}",
                    arguments=arguments,
                    server_name=self.server_name,
                )
                self._prompts.append(prompt_info)
                logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
            self._supports_prompts = True
            logger.info(f"[{self.server_name}] 获取到 {len(self._prompts)} 个提示模板")
            return True
        except Exception as e:
            # A server without prompts support is not an error; log it quietly.
            error_str = str(e).lower()
            if "not supported" in error_str or "not implemented" in error_str or "method not found" in error_str:
                logger.debug(f"[{self.server_name}] 服务器不支持 Prompts 功能")
            else:
                logger.warning(f"[{self.server_name}] 获取提示模板列表失败: {e}")
            self._supports_prompts = False
            self._prompts = []
            return False

    async def read_resource(self, uri: str) -> MCPCallResult:
        """v1.2.0: read the content of a resource.

        Text parts are returned as-is. Binary parts are returned as
        ``[base64]<data>`` when shorter than 10000 units, otherwise as a size
        placeholder. Parts are joined with newlines.

        Args:
            uri: Resource URI.

        Returns:
            MCPCallResult carrying the resource content, or the error.
        """
        start_time = time.time()
        if not self._connected or not self._session:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
        if not self._supports_resources:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
        try:
            result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
            duration_ms = (time.time() - start_time) * 1000
            content_parts = []
            for content in result.contents:
                if hasattr(content, "text"):
                    content_parts.append(content.text)
                elif hasattr(content, "blob"):
                    import base64

                    blob_data = content.blob
                    # The MCP SDK hands blobs over as an already base64-encoded
                    # string; encoding is only needed for raw bytes. Passing a
                    # str to b64encode would raise TypeError.
                    if isinstance(blob_data, (bytes, bytearray)):
                        encoded = base64.b64encode(blob_data).decode()
                    else:
                        encoded = str(blob_data)
                    if len(blob_data) < 10000:
                        content_parts.append(f"[base64]{encoded}")
                    else:
                        content_parts.append(f"[二进制数据: {len(blob_data)} bytes]")
                else:
                    content_parts.append(str(content))
            return MCPCallResult(
                success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
            )
        except asyncio.TimeoutError:
            duration_ms = (time.time() - start_time) * 1000
            return MCPCallResult(
                success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
            )
        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
            return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)

    async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
        """v1.2.0: fetch the rendered content of a prompt template.

        Each returned message is formatted as ``[role]: text``; messages are
        joined with blank lines.

        Args:
            name: Prompt template name.
            arguments: Template arguments; an empty dict is sent when omitted.

        Returns:
            MCPCallResult carrying the prompt content, or the error.
        """
        start_time = time.time()
        if not self._connected or not self._session:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
        if not self._supports_prompts:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
        try:
            result = await asyncio.wait_for(
                self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
            )
            duration_ms = (time.time() - start_time) * 1000
            messages = []
            for msg in result.messages:
                role = msg.role if hasattr(msg, "role") else "unknown"
                content_text = ""
                if hasattr(msg, "content"):
                    if hasattr(msg.content, "text"):
                        content_text = msg.content.text
                    elif isinstance(msg.content, str):
                        content_text = msg.content
                    else:
                        content_text = str(msg.content)
                messages.append(f"[{role}]: {content_text}")
            return MCPCallResult(
                success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
            )
        except asyncio.TimeoutError:
            duration_ms = (time.time() - start_time) * 1000
            return MCPCallResult(
                success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
            )
        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
            return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)

    async def check_health(self) -> bool:
        """Heartbeat check: verify the connection by calling ``list_tools``.

        On failure the session is marked as disconnected and a disconnect is
        recorded in the statistics.

        Returns:
            True when the server answered within 10 seconds.
        """
        if not self._connected or not self._session:
            return False
        try:
            await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
            self.stats.record_heartbeat()
            return True
        except Exception as e:
            logger.warning(f"[{self.server_name}] 心跳检测失败: {e}")
            self._connected = False
            self.stats.record_disconnect()
            return False

    def _record_tool_call(
        self, tool_name: str, success: bool, duration_ms: float, error: Optional[str] = None
    ) -> None:
        """Update the statistics of ``tool_name`` if the tool is tracked."""
        stats = self._tool_stats.get(tool_name)
        if stats is not None:
            stats.record_call(success, duration_ms, error)

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPCallResult:
        """Call an MCP tool, guarded by the circuit breaker.

        Args:
            tool_name: Tool name as known to the server.
            arguments: Arguments passed to the tool.

        Returns:
            MCPCallResult. ``circuit_broken`` is True when the breaker rejected
            the call without contacting the server. A result the server flags
            with ``isError`` is returned as a failure carrying the tool's own
            error text.
        """
        start_time = time.time()
        breaker = self._circuit_breaker

        # v1.7.0: fail fast while the circuit is open.
        can_execute, reject_reason = breaker.can_execute()
        if not can_execute:
            return MCPCallResult(success=False, content=None, error=f"{reject_reason}", circuit_broken=True)

        # In half-open state this call consumes one probe slot.
        is_probe = breaker.state == CircuitState.HALF_OPEN
        if is_probe:
            breaker.half_open_calls += 1

        if not self._connected or not self._session:
            error_msg = f"服务器 {self.server_name} 未连接"
            self._record_tool_call(tool_name, False, 0, error_msg)
            breaker.record_failure()
            return MCPCallResult(success=False, content=None, error=error_msg)

        try:
            result = await asyncio.wait_for(
                self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
            )
            duration_ms = (time.time() - start_time) * 1000
            content_parts = []
            for content in result.content:
                if hasattr(content, "text"):
                    content_parts.append(content.text)
                elif hasattr(content, "data"):
                    content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
                else:
                    content_parts.append(str(content))

            if getattr(result, "isError", False):
                # The tool itself reported an error. The server did answer, so
                # the breaker sees a healthy connection, but the call must not
                # be reported as a success.
                error_msg = "\n".join(content_parts) if content_parts else "工具返回错误(无错误详情)"
                self._record_tool_call(tool_name, False, duration_ms, error_msg)
                breaker.record_success()
                return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms)

            self._record_tool_call(tool_name, True, duration_ms)
            # v1.7.0: report the success to the circuit breaker.
            breaker.record_success()
            return MCPCallResult(
                success=True,
                content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
                duration_ms=duration_ms,
            )
        except asyncio.TimeoutError:
            duration_ms = (time.time() - start_time) * 1000
            error_msg = f"工具调用超时({self.call_timeout}秒)"
            self._record_tool_call(tool_name, False, duration_ms, error_msg)
            # v1.7.0: report the failure to the circuit breaker.
            breaker.record_failure()
            return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms)
        except asyncio.CancelledError:
            # A cancelled probe produces neither a success nor a failure.
            # Give the slot back, otherwise the breaker would stay half-open
            # with no quota and reject every later call.
            if is_probe and breaker.state == CircuitState.HALF_OPEN and breaker.half_open_calls > 0:
                breaker.half_open_calls -= 1
            raise
        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            error_msg = str(e)
            logger.error(f"[{self.server_name}] 调用工具 {tool_name} 失败: {e}")
            self._record_tool_call(tool_name, False, duration_ms, error_msg)
            # v1.7.0: report the failure to the circuit breaker.
            breaker.record_failure()
            # Treat errors that mention the connection as a disconnect.
            if "connection" in error_msg.lower() or "closed" in error_msg.lower():
                self._connected = False
                self.stats.record_disconnect()
            return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms)

    async def disconnect(self) -> None:
        """Disconnect from the server and release all connection resources."""
        async with self._lock:
            if self._connected:
                self.stats.record_disconnect()
            await self._cleanup()

    async def _cleanup(self) -> None:
        """Reset the session state and close every open context.

        Errors raised while closing a context are logged at debug level and do
        not stop the remaining contexts from being closed.
        """
        self._connected = False
        self._tools = []
        self._resources = []  # v1.2.0
        self._prompts = []  # v1.2.0
        self._supports_resources = False  # v1.2.0
        self._supports_prompts = False  # v1.2.0

        # Close the session first, then whichever transport context was opened.
        closers = (
            ("_session_context", "关闭会话时出错"),
            ("_stdio_context", "关闭 stdio 连接时出错"),
            ("_http_context", "关闭 HTTP 连接时出错"),
            ("_sse_context", "关闭 SSE 连接时出错"),
        )
        for attr, failure_text in closers:
            context = getattr(self, attr, None)
            if not context:
                continue
            try:
                await context.__aexit__(None, None, None)
            except Exception as e:
                logger.debug(f"[{self.server_name}] {failure_text}: {e}")

        self._session = None
        self._session_context = None
        self._stdio_context = None
        self._http_context = None
        self._sse_context = None
        self._get_session_id = None
        self._read_stream = None
        self._write_stream = None
        logger.debug(f"[{self.server_name}] 连接已关闭")
class MCPClientManager:
"""MCP 客户端管理器,管理多个 MCP 服务器连接
功能:
- 管理多个 MCP 服务器连接
- 心跳检测和自动重连
- 调用统计
"""
_instance: Optional["MCPClientManager"] = None
def __new__(cls):
    """Return the shared instance, creating it on first use.

    A freshly created instance is flagged as not yet initialized so that
    ``__init__`` runs its set-up exactly once.
    """
    instance = cls._instance
    if instance is None:
        instance = super().__new__(cls)
        instance._initialized = False
        cls._instance = instance
    return instance
def __init__(self):
    """Set up the manager state; later constructions of the shared instance are no-ops."""
    if self._initialized:
        return
    self._initialized = True

    # Registered servers, keyed by server name.
    self._clients: Dict[str, MCPClientSession] = {}
    # Registries mapping a registration key to (info, owning client).
    self._all_tools: Dict[str, Tuple[MCPToolInfo, MCPClientSession]] = {}
    self._all_resources: Dict[str, Tuple[MCPResourceInfo, MCPClientSession]] = {}  # v1.2.0
    self._all_prompts: Dict[str, Tuple[MCPPromptInfo, MCPClientSession]] = {}  # v1.2.0
    self._settings: Dict[str, Any] = {}
    self._lock = asyncio.Lock()

    # Heartbeat task.
    self._heartbeat_task: Optional[asyncio.Task] = None
    self._heartbeat_running = False

    # Status-change callback.
    self._on_status_change: Optional[callable] = None

    # Global statistics.
    self._global_stats = {
        "total_tool_calls": 0,
        "successful_calls": 0,
        "failed_calls": 0,
        "start_time": time.time(),
    }
def configure(self, settings: Dict[str, Any]) -> None:
    """Replace the manager settings with ``settings`` (stored by reference, not copied)."""
    self._settings = settings
def set_status_change_callback(self, callback: callable) -> None:
    """Register the function to invoke, without arguments, when the status changes."""
    self._on_status_change = callback
def _notify_status_change(self) -> None:
    """Invoke the status-change callback, if one is set.

    Exceptions raised by the callback are logged at debug level and swallowed,
    so a faulty callback cannot break the caller.
    """
    callback = self._on_status_change
    if not callback:
        return
    try:
        callback()
    except Exception as e:
        logger.debug(f"状态变化回调出错: {e}")
@property
def all_tools(self) -> Dict[str, Tuple[MCPToolInfo, MCPClientSession]]:
    """Shallow copy of the tool registry: tool key -> (tool info, owning client)."""
    return dict(self._all_tools)
@property
def all_resources(self) -> Dict[str, Tuple[MCPResourceInfo, MCPClientSession]]:
    """v1.2.0: shallow copy of the resource registry: resource key -> (resource info, owning client)."""
    return dict(self._all_resources)
@property
def all_prompts(self) -> Dict[str, Tuple[MCPPromptInfo, MCPClientSession]]:
    """v1.2.0: shallow copy of the prompt registry: prompt key -> (prompt info, owning client)."""
    return dict(self._all_prompts)
@property
def connected_servers(self) -> List[str]:
    """Names of the servers that are currently connected."""
    names = []
    for name, client in self._clients.items():
        if client.is_connected:
            names.append(name)
    return names
@property
def disconnected_servers(self) -> List[str]:
    """Names of the enabled servers that are currently not connected."""
    names = []
    for name, client in self._clients.items():
        if client.is_connected:
            continue
        # Disabled servers are not connected on purpose, so they are left out.
        if client.config.enabled:
            names.append(name)
    return names
async def add_server(self, config: MCPServerConfig) -> bool:
    """Add an MCP server and, when it is enabled, connect to it.

    The connection is attempted up to ``retry_attempts`` times (setting,
    default 3) with ``retry_interval`` seconds (setting, default 5.0) between
    attempts. A server that fails to connect stays registered so that it can
    be reconnected later.

    Args:
        config: Configuration of the server to add.

    Returns:
        True when the server was added disabled or connected successfully;
        False when the name is already taken or every attempt failed.
    """
    async with self._lock:
        if config.name in self._clients:
            logger.warning(f"服务器 {config.name} 已存在")
            return False

        client = MCPClientSession(config, self._settings.get("call_timeout", 60.0))
        self._clients[config.name] = client

        if not config.enabled:
            logger.info(f"服务器 {config.name} 已添加但未启用")
            return True

        retry_attempts = self._settings.get("retry_attempts", 3)
        retry_interval = self._settings.get("retry_interval", 5.0)
        for attempt in range(1, retry_attempts + 1):
            connected = await client.connect()
            if connected:
                self._register_tools(client)
                return True
            # Wait before the next attempt, but not after the last one.
            if attempt < retry_attempts:
                logger.warning(
                    f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
                )
                await asyncio.sleep(retry_interval)

        logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
        # The client is kept in _clients so that it can be reconnected later.
        return False
def _register_tools(self, client: MCPClientSession) -> None:
    """Register every tool of ``client`` in the global tool registry.

    Keys have the form ``{tool_prefix}_{server}_{tool}``. A tool whose name
    already starts with ``{tool_prefix}_{server}_`` is registered under its
    own name instead of being prefixed a second time.
    """
    tool_prefix = self._settings.get("tool_prefix", "mcp")
    key_prefix = f"{tool_prefix}_{client.server_name}_"
    for tool in client.tools:
        tool_key = tool.name if tool.name.startswith(key_prefix) else f"{key_prefix}{tool.name}"
        self._all_tools[tool_key] = (tool, client)
        logger.debug(f"注册 MCP 工具: {tool_key}")
def _unregister_tools(self, server_name: str) -> List[str]:
    """Unregister the tools owned by ``server_name``.

    Tools are selected by the server name of the client they were registered
    with, not by key prefix. Prefix matching would also remove the tools of
    any server whose name merely extends this one (``a`` vs ``a_b``), and
    would miss keys created under a different ``tool_prefix`` setting.

    Returns:
        The tool keys that were removed.
    """
    keys_to_remove = [
        key for key, (_, client) in self._all_tools.items() if client.server_name == server_name
    ]
    for key in keys_to_remove:
        del self._all_tools[key]
        logger.debug(f"注销 MCP 工具: {key}")
    return keys_to_remove
def _register_resources(self, client: MCPClientSession) -> None:
    """v1.2.0: register every resource of ``client`` in the global resource registry.

    Keys have the form ``{tool_prefix}_{server}_res_{safe_uri}``, where
    ``safe_uri`` is the URI with ``://``, ``/`` and ``.`` replaced by ``_``.
    """
    tool_prefix = self._settings.get("tool_prefix", "mcp")
    for resource in client.resources:
        # Turn the URI into a safe key fragment, one replacement at a time.
        safe_uri = resource.uri
        for old in ("://", "/", "."):
            safe_uri = safe_uri.replace(old, "_")
        resource_key = f"{tool_prefix}_{client.server_name}_res_{safe_uri}"
        self._all_resources[resource_key] = (resource, client)
        logger.debug(f"注册 MCP 资源: {resource_key}")
def _unregister_resources(self, server_name: str) -> List[str]:
    """v1.2.0: unregister the resources owned by ``server_name``.

    Resources are selected by the server name of the client they were
    registered with, not by key prefix. Prefix matching would also remove
    entries of a server whose name extends ``{server_name}_res_``, and would
    miss keys created under a different ``tool_prefix`` setting.

    Returns:
        The resource keys that were removed.
    """
    keys_to_remove = [
        key for key, (_, client) in self._all_resources.items() if client.server_name == server_name
    ]
    for key in keys_to_remove:
        del self._all_resources[key]
        logger.debug(f"注销 MCP 资源: {key}")
    return keys_to_remove
def _register_prompts(self, client: MCPClientSession) -> None:
    """v1.2.0: register every prompt template of ``client`` in the global prompt registry.

    Keys have the form ``{tool_prefix}_{server}_prompt_{name}``.
    """
    tool_prefix = self._settings.get("tool_prefix", "mcp")
    key_prefix = f"{tool_prefix}_{client.server_name}_prompt_"
    for prompt in client.prompts:
        prompt_key = f"{key_prefix}{prompt.name}"
        self._all_prompts[prompt_key] = (prompt, client)
        logger.debug(f"注册 MCP 提示模板: {prompt_key}")
def _unregister_prompts(self, server_name: str) -> List[str]:
    """v1.2.0: unregister the prompt templates owned by ``server_name``.

    Prompts are selected by the server name of the client they were registered
    with, not by key prefix. Prefix matching would also remove entries of a
    server whose name extends ``{server_name}_prompt_``, and would miss keys
    created under a different ``tool_prefix`` setting.

    Returns:
        The prompt keys that were removed.
    """
    keys_to_remove = [
        key for key, (_, client) in self._all_prompts.items() if client.server_name == server_name
    ]
    for key in keys_to_remove:
        del self._all_prompts[key]
        logger.debug(f"注销 MCP 提示模板: {key}")
    return keys_to_remove
async def remove_server(self, server_name: str) -> bool:
    """Disconnect a server and remove it together with everything it registered.

    Returns:
        True when the server existed and was removed, False when it is unknown.
    """
    async with self._lock:
        client = self._clients.get(server_name)
        if client is None:
            return False

        await client.disconnect()
        # Drop the server's tools, resources (v1.2.0) and prompts (v1.2.0).
        for unregister in (self._unregister_tools, self._unregister_resources, self._unregister_prompts):
            unregister(server_name)
        del self._clients[server_name]

        logger.info(f"服务器 {server_name} 已移除")
        return True
async def reconnect_server(self, server_name: str) -> bool:
    """Drop a server's connection and registrations, then connect again.

    The connection is attempted up to ``retry_attempts`` times (setting,
    default 3) with ``retry_interval`` seconds (setting, default 5.0) between
    attempts. After a successful connect the tools are registered again, and
    resources / prompts are fetched and registered when ``enable_resources`` /
    ``enable_prompts`` are set.

    Returns:
        True when the reconnect succeeded; False when the server is unknown or
        every attempt failed.
    """
    client = self._clients.get(server_name)
    if client is None:
        return False

    async with self._lock:
        self._unregister_tools(server_name)
        self._unregister_resources(server_name)  # v1.2.0
        self._unregister_prompts(server_name)  # v1.2.0
        await client.disconnect()

    retry_attempts = self._settings.get("retry_attempts", 3)
    retry_interval = self._settings.get("retry_interval", 5.0)
    for attempt in range(1, retry_attempts + 1):
        connected = await client.connect()
        if connected:
            async with self._lock:
                self._register_tools(client)
                # v1.2.0: also pick up resources and prompts after a reconnect.
                if self._settings.get("enable_resources", False):
                    await client.fetch_resources()
                    self._register_resources(client)
                if self._settings.get("enable_prompts", False):
                    await client.fetch_prompts()
                    self._register_prompts(client)
            client.stats.record_reconnect()
            logger.info(f"服务器 {server_name} 重连成功")
            return True
        # Wait before the next attempt, but not after the last one.
        if attempt < retry_attempts:
            logger.warning(f"服务器 {server_name} 重连失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
            await asyncio.sleep(retry_interval)

    logger.error(f"服务器 {server_name} 重连失败")
    return False
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
    """Call a registered MCP tool by its registry key and update the global counters.

    Args:
        tool_key: Key under which the tool is registered.
        arguments: Arguments passed to the tool.

    Returns:
        The client's call result, or a failure result when the key is unknown.
    """
    entry = self._all_tools.get(tool_key)
    if entry is None:
        return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")

    tool_info, client = entry
    # Count the attempt before the call and the outcome after it.
    self._global_stats["total_tool_calls"] += 1
    result = await client.call_tool(tool_info.name, arguments)
    outcome_key = "successful_calls" if result.success else "failed_calls"
    self._global_stats[outcome_key] += 1
    return result
async def fetch_resources_for_server(self, server_name: str) -> bool:
    """v1.2.0: fetch the resource list of one server and register the result.

    Returns:
        True when the list was fetched; False when the server is unknown, not
        connected, or the fetch failed.
    """
    client = self._clients.get(server_name)
    if client is None or not client.is_connected:
        return False

    fetched = await client.fetch_resources()
    if fetched:
        # Only the registry update needs the manager lock.
        async with self._lock:
            self._register_resources(client)
    return fetched
async def fetch_prompts_for_server(self, server_name: str) -> bool:
    """v1.2.0: Refresh the prompt-template list of one connected server.

    Args:
        server_name: Name of the server to query.

    Returns:
        False when the server is unknown or not connected; otherwise the
        value returned by the client's ``fetch_prompts()``. On success the
        fetched prompts are registered while holding ``self._lock``.
    """
    try:
        client = self._clients[server_name]
    except KeyError:
        return False

    if client.is_connected:
        fetched = await client.fetch_prompts()
        if fetched:
            async with self._lock:
                self._register_prompts(client)
        return fetched
    return False
async def read_resource(self, uri: str, server_name: Optional[str] = None) -> MCPCallResult:
    """v1.2.0: Read the content of a resource.

    Lookup order:
      1. If ``server_name`` is given, ask exactly that server.
      2. Otherwise use the server that registered a resource with this URI.
      3. Otherwise probe every connected server that supports resources and
         return the first successful read.

    Args:
        uri: Resource URI.
        server_name: Optional server to read from; when omitted the owning
            server is located automatically.

    Returns:
        The client's result object, or a failed ``MCPCallResult`` when the
        named server does not exist or no server could provide the resource.
    """
    # An explicit server was requested.
    if server_name:
        client = self._clients.get(server_name)
        if client is None:
            return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
        return await client.read_resource(uri)

    # Find the server that registered this resource.
    for resource_info, client in self._all_resources.values():
        if resource_info.uri == uri:
            return await client.read_resource(uri)

    # Probe every server that supports resources. Iterate over a snapshot:
    # each read awaits, and servers may be added or removed meanwhile, which
    # would otherwise raise "dictionary changed size during iteration".
    for client in list(self._clients.values()):
        if client.is_connected and client.supports_resources:
            result = await client.read_resource(uri)
            if result.success:
                return result

    return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
async def get_prompt(
    self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
) -> MCPCallResult:
    """v1.2.0: Fetch the content of a prompt template.

    Args:
        name: Prompt template name.
        arguments: Template arguments, forwarded unchanged to the client.
        server_name: Optional server to query; when omitted, the first
            registered prompt whose name matches decides the server.

    Returns:
        The client's result object, or a failed ``MCPCallResult`` when the
        named server does not exist or no registered prompt matches ``name``.
    """
    # An explicit server was requested.
    if server_name:
        if server_name in self._clients:
            return await self._clients[server_name].get_prompt(name, arguments)
        return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")

    # Locate the server that registered a prompt with this name.
    for info, owner in self._all_prompts.values():
        if info.name == name:
            return await owner.get_prompt(name, arguments)

    return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
# ==================== 心跳检测 ====================
async def start_heartbeat(self) -> None:
    """Start the background heartbeat task.

    Does nothing (apart from logging) when the heartbeat is already running
    or when the ``heartbeat_enabled`` setting (default True) is false.
    Otherwise marks the heartbeat as running and schedules
    ``self._heartbeat_loop()`` as an asyncio task.
    """
    if self._heartbeat_running:
        logger.warning("心跳检测任务已在运行")
    elif not self._settings.get("heartbeat_enabled", True):
        logger.info("心跳检测已禁用")
    else:
        self._heartbeat_running = True
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        logger.info("心跳检测任务已启动")
async def stop_heartbeat(self) -> None:
    """Stop the background heartbeat task.

    Clears the running flag, cancels the task if there is one, waits for it
    to finish, and resets ``self._heartbeat_task`` to None.
    """
    self._heartbeat_running = False

    pending = self._heartbeat_task
    if pending:
        pending.cancel()
        try:
            await pending
        except asyncio.CancelledError:
            # Expected outcome of the cancel() above.
            pass
        self._heartbeat_task = None

    logger.info("心跳检测任务已停止")
async def _heartbeat_loop(self) -> None:
    """Background heartbeat loop (v1.5.2: adaptive per-server intervals).

    Runs while ``self._heartbeat_running`` is true. Each pass sleeps, then
    walks every enabled server: connected servers are health-checked via
    ``client.check_health()``; disconnected ones are handed to
    ``self._try_reconnect`` when ``auto_reconnect`` is on and their
    consecutive-failure count is below ``max_reconnect_attempts``.

    Settings read (with defaults): ``heartbeat_interval`` (60.0),
    ``auto_reconnect`` (True), ``max_reconnect_attempts`` (3),
    ``heartbeat_adaptive`` (True), ``heartbeat_max_multiplier`` (3.0).

    The loop exits on cancellation; any other exception is logged and the
    loop resumes after a 5 second pause.
    """
    base_interval = self._settings.get("heartbeat_interval", 60.0)
    auto_reconnect = self._settings.get("auto_reconnect", True)
    max_reconnect_attempts = self._settings.get("max_reconnect_attempts", 3)

    # v1.5.2: adaptive heartbeat configuration
    adaptive_enabled = self._settings.get("heartbeat_adaptive", True)
    max_multiplier = self._settings.get("heartbeat_max_multiplier", 3.0)

    # Independent heartbeat interval per server (adjusted by stability)
    server_intervals: Dict[str, float] = {}
    min_interval = max(base_interval * 0.5, 30.0)  # lower bound, never below 30s
    max_interval = base_interval * max_multiplier  # upper bound

    mode_str = "智能" if adaptive_enabled else "固定"
    logger.info(f"心跳检测循环启动,{mode_str}模式,基准间隔: {base_interval}")

    while self._heartbeat_running:
        try:
            # Sleep for the smallest per-server interval, clamped to min_interval;
            # before any server has been seen, fall back to the base interval.
            current_interval = min(server_intervals.values()) if server_intervals else base_interval
            current_interval = max(current_interval, min_interval)

            await asyncio.sleep(current_interval)

            # The flag may have been cleared while sleeping.
            if not self._heartbeat_running:
                break

            current_time = time.time()

            # Check every enabled server. Iterates over a snapshot because the
            # body awaits and self._clients may change in the meantime.
            for server_name, client in list(self._clients.items()):
                if not client.config.enabled:
                    continue

                # Initialise this server's interval on first sight
                if server_name not in server_intervals:
                    server_intervals[server_name] = base_interval

                # Skip servers whose heartbeat is not due yet (90% of the
                # interval counts as due); a missing timestamp is treated as 0.
                last_heartbeat = client.stats.last_heartbeat_time or 0
                if current_time - last_heartbeat < server_intervals[server_name] * 0.9:
                    continue  # not due yet

                if client.is_connected:
                    # Check health
                    healthy = await client.check_health()

                    if healthy:
                        # v1.5.2: adaptive heartbeat - grow the interval by 20%
                        # per healthy check, capped at max_interval
                        if adaptive_enabled and client.stats.consecutive_failures == 0:
                            new_interval = min(server_intervals[server_name] * 1.2, max_interval)
                            if new_interval != server_intervals[server_name]:
                                server_intervals[server_name] = new_interval
                                logger.debug(f"[{server_name}] 稳定,心跳间隔调整为 {new_interval:.0f}s")
                    else:
                        logger.warning(f"[{server_name}] 心跳检测失败,连接可能已断开")
                        # Reset to the base interval after a failure
                        if adaptive_enabled:
                            server_intervals[server_name] = base_interval
                        # NOTE(review): nesting reconstructed from flattened source -
                        # confirm this call is not meant to sit under `if adaptive_enabled`.
                        self._notify_status_change()

                        if auto_reconnect:
                            await self._try_reconnect(server_name, max_reconnect_attempts)
                else:
                    # Server is not connected: try to reconnect
                    if adaptive_enabled:
                        # Adaptive heartbeat: disconnected servers use the short interval
                        server_intervals[server_name] = min_interval

                    if auto_reconnect and client.stats.consecutive_failures < max_reconnect_attempts:
                        logger.info(f"[{server_name}] 检测到断开,尝试重连...")
                        await self._try_reconnect(server_name, max_reconnect_attempts)
                    elif client.stats.consecutive_failures >= max_reconnect_attempts:
                        if adaptive_enabled:
                            # Reconnect limit reached: lower the check frequency
                            server_intervals[server_name] = max_interval
                        # NOTE(review): nesting reconstructed from flattened source -
                        # confirm this log is not meant to sit under `if adaptive_enabled`.
                        logger.debug(f"[{server_name}] 已达最大重连次数,降低检测频率")

        except asyncio.CancelledError:
            # The task was cancelled: leave the loop.
            break
        except Exception as e:
            # Log, pause 5 seconds, then continue with the next pass.
            logger.error(f"心跳检测循环出错: {e}")
            await asyncio.sleep(5)
async def _try_reconnect(self, server_name: str, max_attempts: int) -> bool:
    """Attempt one reconnect of ``server_name`` unless its failure limit is hit.

    Args:
        server_name: Name of the server to reconnect.
        max_attempts: Consecutive-failure count at which reconnects are paused.

    Returns:
        The result of ``self.reconnect_server``; False when the server is
        unknown or its consecutive failures already reached ``max_attempts``.
    """
    client = self._clients.get(server_name)
    if not client:
        return False

    limit_reached = client.stats.consecutive_failures >= max_attempts
    if limit_reached:
        logger.warning(f"[{server_name}] 连续失败次数已达上限 ({max_attempts}),暂停重连")
        return False

    logger.info(f"[{server_name}] 尝试重连 (失败次数: {client.stats.consecutive_failures}/{max_attempts})")

    reconnected = await self.reconnect_server(server_name)
    if not reconnected:
        client.stats.record_failure()

    # Publish the new state after the reconnect attempt.
    self._notify_status_change()
    return reconnected
# ==================== 统计和状态 ====================
def get_tool_stats(self, tool_key: str) -> Optional[Dict[str, Any]]:
    """Return the statistics of one registered tool as a dict.

    Args:
        tool_key: Key of the tool in ``self._all_tools``.

    Returns:
        ``stats.to_dict()`` for the tool, or None when the key is not
        registered or the owning client has no statistics for it.
    """
    try:
        info, owner = self._all_tools[tool_key]
    except KeyError:
        return None

    stats = owner.get_tool_stats(info.name)
    if not stats:
        return None
    return stats.to_dict()
def get_all_stats(self) -> Dict[str, Any]:
    """Collect global, per-server and per-tool statistics.

    Returns:
        A dict with three sections:
          * ``"global"``: a copy of ``self._global_stats`` plus
            ``uptime_seconds`` and ``calls_per_minute``;
          * ``"servers"``: ``client.stats.to_dict()`` keyed by server name;
          * ``"tools"``: each tool's ``to_dict()`` keyed by
            ``<tool_prefix>_<server>_<tool>`` (prefix setting defaults to
            ``"mcp"``).
    """
    per_server = {}
    per_tool = {}

    for server_name, client in self._clients.items():
        per_server[server_name] = client.stats.to_dict()
        per_tool.update(
            (f"{self._settings.get('tool_prefix', 'mcp')}_{server_name}_{tool_name}", stats.to_dict())
            for tool_name, stats in client.get_all_tool_stats().items()
        )

    uptime = time.time() - self._global_stats["start_time"]

    # Guard against a zero or negative uptime before dividing.
    if uptime > 0:
        calls_per_minute = round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
    else:
        calls_per_minute = 0

    global_section = dict(self._global_stats)
    global_section["uptime_seconds"] = round(uptime, 2)
    global_section["calls_per_minute"] = calls_per_minute

    return {
        "global": global_section,
        "servers": per_server,
        "tools": per_tool,
    }
async def shutdown(self) -> None:
    """Close every connection and empty all registries.

    Stops the heartbeat first, then, while holding ``self._lock``,
    disconnects each client and clears the client, tool, resource and
    prompt registries.
    """
    # Stop the heartbeat before touching the connections.
    await self.stop_heartbeat()

    async with self._lock:
        for client in self._clients.values():
            await client.disconnect()

        registries = (
            self._clients,
            self._all_tools,
            self._all_resources,  # v1.2.0
            self._all_prompts,  # v1.2.0
        )
        for registry in registries:
            registry.clear()

    logger.info("MCP 客户端管理器已关闭")
def get_status(self) -> Dict[str, Any]:
    """Build a snapshot of the manager's current state.

    Returns:
        A dict with aggregate counts (servers, tools, resources, prompts),
        the heartbeat flag, a per-server detail dict keyed by server name,
        and the global statistics dict.
    """
    servers = {}
    for name, client in self._clients.items():
        servers[name] = {
            "connected": client.is_connected,
            "enabled": client.config.enabled,
            "tools_count": len(client.tools),
            "resources_count": len(client.resources),  # v1.2.0
            "prompts_count": len(client.prompts),  # v1.2.0
            "supports_resources": client.supports_resources,  # v1.2.0
            "supports_prompts": client.supports_prompts,  # v1.2.0
            "transport": client.config.transport.value,
            "consecutive_failures": client.stats.consecutive_failures,
            "circuit_breaker": client.get_circuit_breaker_status(),  # v1.7.0
        }

    status = {
        "total_servers": len(self._clients),
        "connected_servers": len(self.connected_servers),
        "disconnected_servers": len(self.disconnected_servers),
        "total_tools": len(self._all_tools),
        "total_resources": len(self._all_resources),  # v1.2.0
        "total_prompts": len(self._all_prompts),  # v1.2.0
        "heartbeat_running": self._heartbeat_running,
        "servers": servers,
        "global_stats": self._global_stats,
    }
    return status
# Global singleton: the module-level MCPClientManager instance.
mcp_manager = MCPClientManager()