mirror of https://github.com/Mai-with-u/MaiBot.git
1486 lines
56 KiB
Python
1486 lines
56 KiB
Python
"""
MCP 客户端封装模块

负责与 MCP 服务器建立连接、获取工具列表、执行工具调用

v1.7.0 稳定性优化:
- 断路器模式:连续失败 5 次后熔断,60 秒后试探恢复
- 熔断期间快速失败,避免等待超时
- 连接成功时自动重置断路器

v1.5.2 性能优化:
- 智能心跳间隔:根据服务器稳定性动态调整心跳频率
- 稳定服务器逐渐增加间隔(最高 3x),减少不必要的检测
- 断开的服务器使用较短间隔快速重连

v1.1.0 新增功能:
- 调用统计(次数、成功率、耗时)
- 心跳检测
- 自动重连
- 更好的错误处理

v1.2.0 新增功能:
- Resources 支持(资源读取)
- Prompts 支持(提示模板)
- 新增配置项: enable_resources, enable_prompts
"""
|
||
|
||
import asyncio
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||
|
||
# Prefer MaiBot's project logger; fall back to the standard library when the
# project logger cannot be imported (e.g. this module is used standalone).
try:
    from src.common.logger import get_logger

    logger = get_logger("mcp_client")
except ImportError:
    # Fallback: plain stdlib logger with a "[LEVEL] name: message" format.
    logger = logging.getLogger("mcp_client")
    # Attach a handler only once so running this again does not duplicate output.
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
|
||
|
||
|
||
class TransportType(Enum):
    """Transport used to reach an MCP server."""

    STDIO = "stdio"  # local subprocess, talking over stdin/stdout
    SSE = "sse"  # Server-Sent Events (legacy HTTP transport)
    HTTP = "http"  # Streamable HTTP (newer, recommended)
    STREAMABLE_HTTP = "streamable_http"  # alternate spelling, handled like HTTP by the client
|
||
|
||
|
||
@dataclass
class MCPToolInfo:
    """Description of one tool exposed by an MCP server."""

    name: str  # tool name as reported by the server
    description: str
    input_schema: Dict[str, Any]  # the tool's input schema as provided by the server
    server_name: str  # name of the server that provides the tool
|
||
|
||
|
||
@dataclass
class MCPResourceInfo:
    """Description of one resource exposed by an MCP server."""

    uri: str
    name: str
    description: str
    mime_type: Optional[str]  # None when the server did not report a MIME type
    server_name: str  # name of the server that provides the resource
|
||
|
||
|
||
@dataclass
class MCPPromptInfo:
    """Description of one prompt template exposed by an MCP server."""

    name: str
    description: str
    arguments: List[Dict[str, Any]]  # [{"name", "description", "required"}, ...]
    server_name: str  # name of the server that provides the prompt
|
||
|
||
|
||
@dataclass
class MCPServerConfig:
    """Configuration for one MCP server."""

    name: str
    enabled: bool = True
    transport: TransportType = TransportType.STDIO
    # stdio settings
    command: str = ""
    args: List[str] = field(default_factory=list)
    env: Dict[str, str] = field(default_factory=dict)
    # http / sse settings
    url: str = ""
    headers: Dict[str, str] = field(default_factory=dict)  # v1.4.2: auth headers

    def __post_init__(self) -> None:
        """Accept ``transport`` given as a plain string such as ``"sse"``.

        A plain string never equals a TransportType member, so it would
        otherwise be rejected as unsupported at connect time. Valid values
        (case-insensitive, surrounding whitespace ignored) are converted.
        Unknown strings are deliberately left unchanged: construction still
        succeeds, and the connection attempt reports the transport as unsupported.
        """
        if isinstance(self.transport, str):
            try:
                self.transport = TransportType(self.transport.strip().lower())
            except ValueError:
                pass  # intentional: keep the raw value; connect() rejects it with a log message
|
||
|
||
|
||
@dataclass
class MCPCallResult:
    """Result of a tool call, resource read or prompt fetch."""

    success: bool
    content: Any  # payload on success (None for most failures)
    error: Optional[str] = None  # human-readable error message on failure
    duration_ms: float = 0.0  # call duration in milliseconds
    circuit_broken: bool = False  # v1.7.0: True when the circuit breaker rejected the call
|
||
|
||
|
||
class CircuitState(Enum):
    """Circuit breaker states."""

    CLOSED = "closed"  # normal: requests are allowed
    OPEN = "open"  # tripped: requests are rejected
    HALF_OPEN = "half_open"  # probing: a limited number of trial requests is allowed
|
||
|
||
|
||
@dataclass
class CircuitBreaker:
    """v1.7.0: Circuit breaker that stops hammering a failing server.

    State transitions:
    - CLOSED -> OPEN: consecutive failures reach ``failure_threshold``
    - OPEN -> HALF_OPEN: ``recovery_timeout`` seconds passed since the last failure
    - HALF_OPEN -> CLOSED: a probe request succeeds
    - HALF_OPEN -> OPEN: a probe request fails

    ``can_execute`` does not consume half-open quota itself; the caller
    increments ``half_open_calls`` when it actually sends a probe.
    """

    # Configuration
    failure_threshold: int = 5  # consecutive failures before opening
    recovery_timeout: float = 60.0  # seconds to stay open before probing
    half_open_max_calls: int = 1  # probe calls allowed while half-open

    # State
    state: CircuitState = field(default=CircuitState.CLOSED)
    failure_count: int = 0
    success_count: int = 0
    last_failure_time: float = 0.0  # time.time() of the latest failure; 0.0 means none yet
    last_state_change: float = field(default_factory=time.time)
    half_open_calls: int = 0

    def can_execute(self) -> Tuple[bool, Optional[str]]:
        """Decide whether a request may be sent now.

        An OPEN breaker whose recovery timeout has elapsed moves to HALF_OPEN
        here and allows the request.

        Returns:
            (allowed, reason) - ``reason`` is a user-facing message when rejected, else None.
        """
        current_time = time.time()

        if self.state == CircuitState.CLOSED:
            return True, None

        if self.state == CircuitState.OPEN:
            time_since_failure = current_time - self.last_failure_time
            if time_since_failure >= self.recovery_timeout:
                self._transition_to(CircuitState.HALF_OPEN)
                return True, None
            remaining = self.recovery_timeout - time_since_failure
            return False, f"断路器熔断中,{remaining:.0f}秒后重试"

        if self.state == CircuitState.HALF_OPEN:
            # Only allow as many probes as the half-open quota permits.
            if self.half_open_calls < self.half_open_max_calls:
                return True, None
            return False, "断路器半开状态,等待试探结果"

        return True, None

    def record_success(self) -> None:
        """Record a successful call: closes a half-open breaker, resets the failure streak otherwise."""
        self.success_count += 1

        if self.state == CircuitState.HALF_OPEN:
            self._transition_to(CircuitState.CLOSED)
            logger.info("断路器恢复正常(试探成功)")
        elif self.state == CircuitState.CLOSED:
            self.failure_count = 0

    def record_failure(self) -> None:
        """Record a failed call: re-opens a half-open breaker, or opens a closed one at the threshold."""
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == CircuitState.HALF_OPEN:
            self._transition_to(CircuitState.OPEN)
            logger.warning("断路器重新熔断(试探失败)")
        elif self.state == CircuitState.CLOSED:
            if self.failure_count >= self.failure_threshold:
                self._transition_to(CircuitState.OPEN)
                logger.warning(f"断路器熔断(连续失败 {self.failure_count} 次)")

    def _transition_to(self, new_state: CircuitState) -> None:
        """Switch state, stamp the change time and reset the counters the new state relies on."""
        old_state = self.state
        self.state = new_state
        self.last_state_change = time.time()

        if new_state == CircuitState.CLOSED:
            self.failure_count = 0
            self.half_open_calls = 0
        elif new_state == CircuitState.HALF_OPEN:
            self.half_open_calls = 0

        logger.debug(f"断路器状态: {old_state.value} -> {new_state.value}")

    def reset(self) -> None:
        """Force the breaker back to CLOSED and clear counters (``last_failure_time`` is kept)."""
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.half_open_calls = 0
        self.last_state_change = time.time()

    def get_status(self) -> Dict[str, Any]:
        """Snapshot of the breaker for status displays.

        ``time_since_last_failure`` is None when no failure has been recorded.
        """
        return {
            "state": self.state.value,
            "failure_count": self.failure_count,
            "success_count": self.success_count,
            "failure_threshold": self.failure_threshold,
            "recovery_timeout": self.recovery_timeout,
            "time_since_last_failure": time.time() - self.last_failure_time if self.last_failure_time > 0 else None,
        }
|
||
|
||
|
||
@dataclass
class ToolCallStats:
    """Call statistics for one tool."""

    tool_key: str
    total_calls: int = 0
    success_calls: int = 0
    failed_calls: int = 0
    total_duration_ms: float = 0.0  # summed over successful calls only
    last_call_time: Optional[float] = None  # time.time() of the latest call
    last_error: Optional[str] = None  # error of the latest failed call

    @property
    def success_rate(self) -> float:
        """Success rate as a percentage (0-100); 0.0 when there were no calls."""
        if self.total_calls == 0:
            return 0.0
        return (self.success_calls / self.total_calls) * 100

    @property
    def avg_duration_ms(self) -> float:
        """Mean duration of successful calls in ms; 0.0 when none succeeded."""
        if self.success_calls == 0:
            return 0.0
        return self.total_duration_ms / self.success_calls

    def record_call(self, success: bool, duration_ms: float, error: Optional[str] = None) -> None:
        """Record one call.

        Args:
            success: Whether the call succeeded.
            duration_ms: Call duration; only accumulated for successful calls.
            error: Error message, stored as ``last_error`` for failed calls.
        """
        self.total_calls += 1
        self.last_call_time = time.time()
        if success:
            self.success_calls += 1
            self.total_duration_ms += duration_ms
        else:
            self.failed_calls += 1
            self.last_error = error

    def to_dict(self) -> Dict[str, Any]:
        """Serialisable snapshot; rate and average are rounded to 2 decimals."""
        return {
            "tool_key": self.tool_key,
            "total_calls": self.total_calls,
            "success_calls": self.success_calls,
            "failed_calls": self.failed_calls,
            "success_rate": round(self.success_rate, 2),
            "avg_duration_ms": round(self.avg_duration_ms, 2),
            "last_call_time": self.last_call_time,
            "last_error": self.last_error,
        }
|
||
|
||
|
||
@dataclass
class ServerStats:
    """Connection statistics for one MCP server."""

    server_name: str
    connect_count: int = 0  # number of recorded connections
    disconnect_count: int = 0  # number of recorded disconnects
    reconnect_count: int = 0  # number of recorded reconnects
    last_connect_time: Optional[float] = None  # time.time() timestamps
    last_disconnect_time: Optional[float] = None
    last_heartbeat_time: Optional[float] = None
    consecutive_failures: int = 0  # failures since the last connect/reconnect

    def record_connect(self) -> None:
        """Count a connection, stamp its time and clear the failure streak."""
        self.connect_count += 1
        self.last_connect_time = time.time()
        self.consecutive_failures = 0

    def record_disconnect(self) -> None:
        """Count a disconnect and stamp its time."""
        self.disconnect_count += 1
        self.last_disconnect_time = time.time()

    def record_reconnect(self) -> None:
        """Count a reconnect and clear the failure streak."""
        self.reconnect_count += 1
        self.consecutive_failures = 0

    def record_failure(self) -> None:
        """Extend the consecutive-failure streak."""
        self.consecutive_failures += 1

    def record_heartbeat(self) -> None:
        """Stamp the time of the latest successful heartbeat."""
        self.last_heartbeat_time = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Serialisable snapshot of all counters and timestamps."""
        return {
            "server_name": self.server_name,
            "connect_count": self.connect_count,
            "disconnect_count": self.disconnect_count,
            "reconnect_count": self.reconnect_count,
            "last_connect_time": self.last_connect_time,
            "last_disconnect_time": self.last_disconnect_time,
            "last_heartbeat_time": self.last_heartbeat_time,
            "consecutive_failures": self.consecutive_failures,
        }
|
||
|
||
|
||
class MCPClientSession:
    """Client session managing the connection to a single MCP server.

    Connects over stdio / SSE / Streamable HTTP and caches the server's tools
    (plus resources and prompts when fetched). It executes tool calls, performs
    heartbeat checks, and keeps per-server and per-tool statistics. Tool calls
    pass through a CircuitBreaker (v1.7.0), so a failing server is rejected
    fast instead of waiting for timeouts.
    """

    def __init__(self, config: MCPServerConfig, call_timeout: float = 60.0):
        """Create an unconnected session.

        Args:
            config: Server configuration (name, transport, command/url, ...).
            call_timeout: Timeout in seconds for tool calls and for listing or
                reading tools, resources and prompts.
        """
        self.config = config
        self.call_timeout = call_timeout
        self._session = None
        self._read_stream = None
        self._write_stream = None
        self._process: Optional[asyncio.subprocess.Process] = None
        # Async context managers entered while connecting; _close_transport exits them.
        self._session_context = None
        self._stdio_context = None
        self._sse_context = None
        self._http_context = None
        self._get_session_id = None  # third value returned by the Streamable HTTP client
        self._tools: List[MCPToolInfo] = []
        self._resources: List[MCPResourceInfo] = []  # v1.2.0: Resources support
        self._prompts: List[MCPPromptInfo] = []  # v1.2.0: Prompts support
        self._connected = False
        self._lock = asyncio.Lock()

        # Capability flags: a server may not implement every feature.
        self._supports_resources: bool = False
        self._supports_prompts: bool = False

        # Statistics
        self.stats = ServerStats(server_name=config.name)
        self._tool_stats: Dict[str, ToolCallStats] = {}

        # v1.7.0: circuit breaker guarding call_tool
        self._circuit_breaker = CircuitBreaker()

    @property
    def is_connected(self) -> bool:
        """Whether the session is currently marked as connected."""
        return self._connected

    @property
    def tools(self) -> List[MCPToolInfo]:
        """Copy of the cached tool list."""
        return self._tools.copy()

    @property
    def resources(self) -> List[MCPResourceInfo]:
        """v1.2.0: Copy of the cached resource list."""
        return self._resources.copy()

    @property
    def prompts(self) -> List[MCPPromptInfo]:
        """v1.2.0: Copy of the cached prompt template list."""
        return self._prompts.copy()

    @property
    def supports_resources(self) -> bool:
        """v1.2.0: Whether the last resource fetch succeeded."""
        return self._supports_resources

    @property
    def supports_prompts(self) -> bool:
        """v1.2.0: Whether the last prompt fetch succeeded."""
        return self._supports_prompts

    @property
    def server_name(self) -> str:
        """Name of the configured server."""
        return self.config.name

    def get_tool_stats(self, tool_name: str) -> Optional[ToolCallStats]:
        """Statistics for one tool, or None if the tool is unknown."""
        return self._tool_stats.get(tool_name)

    def get_circuit_breaker_status(self) -> Dict[str, Any]:
        """v1.7.0: Snapshot of the circuit breaker state."""
        return self._circuit_breaker.get_status()

    def reset_circuit_breaker(self) -> None:
        """v1.7.0: Manually reset the circuit breaker to CLOSED."""
        self._circuit_breaker.reset()
        logger.info(f"[{self.server_name}] 断路器已重置")

    def get_all_tool_stats(self) -> Dict[str, ToolCallStats]:
        """Shallow copy of all per-tool statistics, keyed by tool name."""
        return self._tool_stats.copy()

    # ------------------------------------------------------------------ connect

    async def connect(self) -> bool:
        """Connect to the server using the configured transport.

        Before reconnecting, contexts left over from a previous connection
        (e.g. after a failed heartbeat) are closed. Otherwise they would be
        overwritten and leak, for example leaving an old stdio subprocess running.

        Returns:
            True when already connected or the connection succeeds, else False.
        """
        async with self._lock:
            if self._connected:
                return True

            if self._has_open_transport():
                await self._close_transport()

            try:
                if self.config.transport == TransportType.STDIO:
                    success = await self._connect_stdio()
                elif self.config.transport == TransportType.SSE:
                    success = await self._connect_sse()
                elif self.config.transport in (TransportType.HTTP, TransportType.STREAMABLE_HTTP):
                    success = await self._connect_http()
                else:
                    logger.error(f"[{self.server_name}] 不支持的传输类型: {self.config.transport}")
                    return False

                if success:
                    self.stats.record_connect()
                    # v1.7.0: a fresh connection starts with a clean breaker
                    self._circuit_breaker.reset()
                else:
                    self.stats.record_failure()
                return success

            except Exception as e:
                logger.error(f"[{self.server_name}] 连接失败: {e}")
                self._connected = False
                self.stats.record_failure()
                return False

    def _has_open_transport(self) -> bool:
        """True if any session/transport context from an earlier connect is still held."""
        return any(
            ctx is not None
            for ctx in (self._session_context, self._stdio_context, self._sse_context, self._http_context)
        )

    async def _connect_stdio(self) -> bool:
        """Launch the configured command and speak MCP over its stdin/stdout."""
        try:
            try:
                from mcp import ClientSession, StdioServerParameters
                from mcp.client.stdio import stdio_client
            except ImportError:
                logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp")
                return False

            server_params = StdioServerParameters(
                command=self.config.command,
                args=self.config.args,
                # An empty env dict is passed as None (no explicit environment).
                env=self.config.env if self.config.env else None,
            )
            self._stdio_context = stdio_client(server_params)
            self._read_stream, self._write_stream = await self._stdio_context.__aenter__()
            return await self._start_session(ClientSession, "stdio")

        except Exception as e:
            return await self._handle_connect_failure("stdio", e)

    async def _connect_sse(self) -> bool:
        """Connect to ``config.url`` over Server-Sent Events (legacy HTTP transport)."""
        try:
            try:
                from mcp import ClientSession
                from mcp.client.sse import sse_client
            except ImportError:
                logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp")
                return False

            if not self.config.url:
                logger.error(f"[{self.server_name}] SSE 传输需要配置 url")
                return False

            logger.debug(f"[{self.server_name}] 正在连接 SSE MCP 服务器: {self.config.url}")

            self._sse_context = sse_client(**self._http_client_kwargs())
            self._read_stream, self._write_stream = await self._sse_context.__aenter__()
            return await self._start_session(ClientSession, "SSE")

        except Exception as e:
            return await self._handle_connect_failure("SSE", e)

    async def _connect_http(self) -> bool:
        """Connect to ``config.url`` over Streamable HTTP."""
        try:
            try:
                from mcp import ClientSession
                from mcp.client.streamable_http import streamablehttp_client
            except ImportError:
                logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp")
                return False

            if not self.config.url:
                logger.error(f"[{self.server_name}] HTTP 传输需要配置 url")
                return False

            logger.debug(f"[{self.server_name}] 正在连接 HTTP MCP 服务器: {self.config.url}")

            self._http_context = streamablehttp_client(**self._http_client_kwargs())
            self._read_stream, self._write_stream, self._get_session_id = await self._http_context.__aenter__()
            return await self._start_session(ClientSession, "HTTP")

        except Exception as e:
            return await self._handle_connect_failure("HTTP", e)

    def _http_client_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments shared by the SSE and Streamable HTTP clients (v1.4.2: optional auth headers)."""
        kwargs: Dict[str, Any] = {"url": self.config.url, "timeout": 60.0, "sse_read_timeout": 300.0}
        if self.config.headers:
            kwargs["headers"] = self.config.headers
        return kwargs

    async def _start_session(self, session_cls: Any, label: str) -> bool:
        """Open a ClientSession over the current streams, initialize it and load the tools.

        Args:
            session_cls: The ``mcp.ClientSession`` class (imported lazily by the caller).
            label: Transport name used in the success log line.
        """
        self._session_context = session_cls(self._read_stream, self._write_stream)
        self._session = await self._session_context.__aenter__()
        await self._session.initialize()
        await self._fetch_tools()
        self._connected = True
        logger.info(f"[{self.server_name}] {label} 连接成功,发现 {len(self._tools)} 个工具")
        return True

    async def _handle_connect_failure(self, label: str, error: Exception) -> bool:
        """Log a failed connection attempt (traceback at debug level), release resources, return False."""
        import traceback

        logger.error(f"[{self.server_name}] {label} 连接失败: {error}")
        details = "".join(traceback.format_exception(type(error), error, error.__traceback__))
        logger.debug(f"[{self.server_name}] 详细错误: {details}")
        await self._cleanup()
        return False

    # -------------------------------------------------------------- discovery

    async def _fetch_tools(self) -> None:
        """Fetch the tool list and rebuild the tool cache.

        Best-effort: on any failure, including a timeout after ``call_timeout``
        seconds, the cache is emptied and the error is logged instead of raised.
        """
        if not self._session:
            return

        try:
            result = await asyncio.wait_for(self._session.list_tools(), timeout=self.call_timeout)
            self._tools = []

            for tool in result.tools:
                tool_info = MCPToolInfo(
                    name=tool.name,
                    description=tool.description or f"MCP tool: {tool.name}",
                    input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
                    server_name=self.server_name,
                )
                self._tools.append(tool_info)
                # Keep existing stats across reconnects; only add entries for new tools.
                if tool.name not in self._tool_stats:
                    self._tool_stats[tool.name] = ToolCallStats(tool_key=tool.name)
                logger.debug(f"[{self.server_name}] 发现工具: {tool.name}")

        except Exception as e:
            logger.error(f"[{self.server_name}] 获取工具列表失败: {self._describe_error(e)}")
            self._tools = []

    @staticmethod
    def _is_unsupported_error(error: Exception) -> bool:
        """Heuristic: does the error message say the server lacks the requested feature?"""
        text = str(error).lower()
        return any(marker in text for marker in ("not supported", "not implemented", "method not found"))

    @staticmethod
    def _describe_error(error: BaseException) -> str:
        """``str(error)``, falling back to the exception type name when the message is empty."""
        return str(error) or type(error).__name__

    async def fetch_resources(self) -> bool:
        """v1.2.0: Fetch and cache the server's resource list.

        Returns:
            True on success. False when there is no session or the request fails;
            servers without Resources support are logged at debug level only.
        """
        if not self._session:
            return False

        try:
            result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
            self._resources = []

            for resource in result.resources:
                resource_info = MCPResourceInfo(
                    uri=str(resource.uri),
                    name=resource.name or str(resource.uri),
                    description=resource.description or "",
                    mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
                    server_name=self.server_name,
                )
                self._resources.append(resource_info)
                logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")

            self._supports_resources = True
            logger.info(f"[{self.server_name}] 获取到 {len(self._resources)} 个资源")
            return True

        except Exception as e:
            if self._is_unsupported_error(e):
                logger.debug(f"[{self.server_name}] 服务器不支持 Resources 功能")
            else:
                logger.warning(f"[{self.server_name}] 获取资源列表失败: {self._describe_error(e)}")
            self._supports_resources = False
            self._resources = []
            return False

    async def fetch_prompts(self) -> bool:
        """v1.2.0: Fetch and cache the server's prompt template list.

        Returns:
            True on success. False when there is no session or the request fails;
            servers without Prompts support are logged at debug level only.
        """
        if not self._session:
            return False

        try:
            result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
            self._prompts = []

            for prompt in result.prompts:
                arguments = []
                if hasattr(prompt, "arguments") and prompt.arguments:
                    for arg in prompt.arguments:
                        arguments.append(
                            {
                                "name": arg.name,
                                "description": arg.description or "",
                                "required": arg.required if hasattr(arg, "required") else False,
                            }
                        )

                prompt_info = MCPPromptInfo(
                    name=prompt.name,
                    description=prompt.description or f"MCP prompt: {prompt.name}",
                    arguments=arguments,
                    server_name=self.server_name,
                )
                self._prompts.append(prompt_info)
                logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")

            self._supports_prompts = True
            logger.info(f"[{self.server_name}] 获取到 {len(self._prompts)} 个提示模板")
            return True

        except Exception as e:
            if self._is_unsupported_error(e):
                logger.debug(f"[{self.server_name}] 服务器不支持 Prompts 功能")
            else:
                logger.warning(f"[{self.server_name}] 获取提示模板列表失败: {self._describe_error(e)}")
            self._supports_prompts = False
            self._prompts = []
            return False

    # ------------------------------------------------------ resources/prompts

    @staticmethod
    def _format_blob(blob: Any) -> str:
        """Render binary resource content: inline base64 when small, otherwise a size note.

        Per the MCP spec, ``blob`` arrives as base64 text. Calling
        ``b64encode`` on that text raised TypeError, so text is used as-is and
        only raw bytes are encoded. Payloads under 10000 bytes (decoded size,
        approximated for base64 text) are returned as ``[base64]<data>``.
        """
        import base64

        if isinstance(blob, (bytes, bytearray)):
            size = len(blob)
            encoded = base64.b64encode(bytes(blob)).decode()
        else:
            encoded = str(blob)
            padding = len(encoded) - len(encoded.rstrip("="))
            size = len(encoded) * 3 // 4 - padding
        if size < 10000:
            return f"[base64]{encoded}"
        return f"[二进制数据: {size} bytes]"

    async def read_resource(self, uri: str) -> MCPCallResult:
        """v1.2.0: Read a resource.

        Args:
            uri: Resource URI.

        Returns:
            MCPCallResult whose content is the newline-joined text of all parts;
            binary parts are rendered by ``_format_blob``.
        """
        start_time = time.time()

        if not self._connected or not self._session:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")

        if not self._supports_resources:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")

        try:
            result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
            duration_ms = (time.time() - start_time) * 1000

            content_parts = []
            for content in result.contents:
                if hasattr(content, "text"):
                    content_parts.append(content.text)
                elif hasattr(content, "blob"):
                    content_parts.append(self._format_blob(content.blob))
                else:
                    content_parts.append(str(content))

            return MCPCallResult(success=True, content="\n".join(content_parts), duration_ms=duration_ms)

        except asyncio.TimeoutError:
            duration_ms = (time.time() - start_time) * 1000
            return MCPCallResult(
                success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
            )
        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
            return MCPCallResult(success=False, content=None, error=self._describe_error(e), duration_ms=duration_ms)

    async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
        """v1.2.0: Render a prompt template on the server.

        Args:
            name: Prompt template name.
            arguments: Template arguments; an empty dict is sent when None.

        Returns:
            MCPCallResult whose content is each message rendered as
            ``[role]: text``, separated by blank lines.
        """
        start_time = time.time()

        if not self._connected or not self._session:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")

        if not self._supports_prompts:
            return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")

        try:
            result = await asyncio.wait_for(
                self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
            )
            duration_ms = (time.time() - start_time) * 1000

            messages = []
            for msg in result.messages:
                role = msg.role if hasattr(msg, "role") else "unknown"
                content_text = ""
                if hasattr(msg, "content"):
                    if hasattr(msg.content, "text"):
                        content_text = msg.content.text
                    elif isinstance(msg.content, str):
                        content_text = msg.content
                    else:
                        content_text = str(msg.content)
                messages.append(f"[{role}]: {content_text}")

            return MCPCallResult(success=True, content="\n\n".join(messages), duration_ms=duration_ms)

        except asyncio.TimeoutError:
            duration_ms = (time.time() - start_time) * 1000
            return MCPCallResult(
                success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
            )
        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
            return MCPCallResult(success=False, content=None, error=self._describe_error(e), duration_ms=duration_ms)

    # ------------------------------------------------------------- heartbeat

    async def check_health(self) -> bool:
        """Heartbeat: verify the connection by calling list_tools (10 s timeout).

        On failure the session is marked disconnected and a disconnect is
        recorded. The stale contexts are released by the next ``connect()``
        (or by ``disconnect()``).
        """
        if not self._connected or not self._session:
            return False

        try:
            await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
            self.stats.record_heartbeat()
            return True
        except Exception as e:
            logger.warning(f"[{self.server_name}] 心跳检测失败: {self._describe_error(e)}")
            self._connected = False
            self.stats.record_disconnect()
            return False

    # ------------------------------------------------------------- tool calls

    def _record_tool_call(self, tool_name: str, success: bool, duration_ms: float, error: Optional[str] = None) -> None:
        """Update stats for a tool reported by the server; calls to unknown tools are not tracked."""
        stats = self._tool_stats.get(tool_name)
        if stats is not None:
            stats.record_call(success, duration_ms, error)

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPCallResult:
        """Call a tool on the server.

        Args:
            tool_name: Tool name as reported by the server.
            arguments: Tool arguments.

        Returns:
            MCPCallResult. ``success`` is False when:
            - the circuit breaker rejects the call (``circuit_broken=True``);
            - the session is not connected;
            - the call times out or raises;
            - the server marks the result with ``isError``. In this case
              ``content`` still carries the server's text.
        """
        start_time = time.time()

        # v1.7.0: fail fast while the breaker is open.
        allowed, reject_reason = self._circuit_breaker.can_execute()
        if not allowed:
            return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True)

        # Consume a half-open probe slot before awaiting, so concurrent callers see it used.
        if self._circuit_breaker.state == CircuitState.HALF_OPEN:
            self._circuit_breaker.half_open_calls += 1

        if not self._connected or not self._session:
            error_msg = f"服务器 {self.server_name} 未连接"
            self._record_tool_call(tool_name, False, 0, error_msg)
            self._circuit_breaker.record_failure()
            return MCPCallResult(success=False, content=None, error=error_msg)

        try:
            result = await asyncio.wait_for(
                self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
            )
            duration_ms = (time.time() - start_time) * 1000

            content_parts = []
            for content in result.content:
                if hasattr(content, "text"):
                    content_parts.append(content.text)
                elif hasattr(content, "data"):
                    content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
                else:
                    content_parts.append(str(content))
            content_text = "\n".join(content_parts)

            # The server answered, so the transport is healthy even if the tool itself
            # reported an error; the breaker only tracks server-level failures.
            self._circuit_breaker.record_success()

            if getattr(result, "isError", False):
                error_msg = content_text or "工具返回错误(无详细信息)"
                self._record_tool_call(tool_name, False, duration_ms, error_msg)
                return MCPCallResult(success=False, content=content_text, error=error_msg, duration_ms=duration_ms)

            self._record_tool_call(tool_name, True, duration_ms)
            return MCPCallResult(
                success=True,
                content=content_text if content_parts else "执行成功(无返回内容)",
                duration_ms=duration_ms,
            )

        except asyncio.TimeoutError:
            duration_ms = (time.time() - start_time) * 1000
            error_msg = f"工具调用超时({self.call_timeout}秒)"
            self._record_tool_call(tool_name, False, duration_ms, error_msg)
            self._circuit_breaker.record_failure()
            return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms)

        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            error_msg = self._describe_error(e)
            logger.error(f"[{self.server_name}] 调用工具 {tool_name} 失败: {error_msg}")
            self._record_tool_call(tool_name, False, duration_ms, error_msg)
            self._circuit_breaker.record_failure()
            # Heuristic connection-loss check. The exception type name is included
            # because some stream errors (e.g. ClosedResourceError) have empty messages.
            haystack = f"{type(e).__name__} {error_msg}".lower()
            if "connection" in haystack or "closed" in haystack:
                self._connected = False
                self.stats.record_disconnect()
            return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms)

    # ----------------------------------------------------------- disconnect

    async def disconnect(self) -> None:
        """Disconnect and release all transport resources; safe to call when not connected."""
        async with self._lock:
            if self._connected:
                self.stats.record_disconnect()
            await self._cleanup()

    async def _cleanup(self) -> None:
        """Mark disconnected, clear cached tools/resources/prompts and close the transport."""
        self._connected = False
        self._tools = []
        self._resources = []  # v1.2.0
        self._prompts = []  # v1.2.0
        self._supports_resources = False  # v1.2.0
        self._supports_prompts = False  # v1.2.0

        await self._close_transport()
        logger.debug(f"[{self.server_name}] 连接已关闭")

    async def _close_transport(self) -> None:
        """Exit the session context, then the transport contexts, and drop all references.

        Errors while exiting are logged at debug level and ignored, so one
        broken context does not prevent the others from being closed.
        """
        for attr, failure_msg in (
            ("_session_context", "关闭会话时出错"),
            ("_stdio_context", "关闭 stdio 连接时出错"),
            ("_http_context", "关闭 HTTP 连接时出错"),
            ("_sse_context", "关闭 SSE 连接时出错"),
        ):
            ctx = getattr(self, attr, None)
            if ctx is None:
                continue
            try:
                await ctx.__aexit__(None, None, None)
            except Exception as e:
                logger.debug(f"[{self.server_name}] {failure_msg}: {e}")

        self._session = None
        self._session_context = None
        self._stdio_context = None
        self._http_context = None
        self._sse_context = None
        self._read_stream = None
        self._write_stream = None
        self._get_session_id = None
|
||
|
||
|
||
class MCPClientManager:
|
||
"""MCP 客户端管理器,管理多个 MCP 服务器连接
|
||
|
||
功能:
|
||
- 管理多个 MCP 服务器连接
|
||
- 心跳检测和自动重连
|
||
- 调用统计
|
||
"""
|
||
|
||
_instance: Optional["MCPClientManager"] = None
|
||
|
||
def __new__(cls):
    """Return the process-wide singleton, creating it on first use.

    The fresh instance gets ``_initialized = False`` so that ``__init__``,
    which Python runs on every construction, performs its setup only once.
    """
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        cls._instance._initialized = False
    return cls._instance
|
||
|
||
def __init__(self):
    """Set up manager state once; later constructions of the singleton are no-ops."""
    if self._initialized:
        return
    self._initialized = True
    self._clients: Dict[str, MCPClientSession] = {}
    self._all_tools: Dict[str, Tuple[MCPToolInfo, MCPClientSession]] = {}
    self._all_resources: Dict[str, Tuple[MCPResourceInfo, MCPClientSession]] = {}  # v1.2.0
    self._all_prompts: Dict[str, Tuple[MCPPromptInfo, MCPClientSession]] = {}  # v1.2.0
    self._settings: Dict[str, Any] = {}
    self._lock = asyncio.Lock()

    # Heartbeat task
    self._heartbeat_task: Optional[asyncio.Task] = None
    self._heartbeat_running = False

    # Zero-argument callback invoked by _notify_status_change
    self._on_status_change: Optional[Callable[[], None]] = None

    # Global statistics
    self._global_stats = {
        "total_tool_calls": 0,
        "successful_calls": 0,
        "failed_calls": 0,
        "start_time": time.time(),
    }
|
||
|
||
def configure(self, settings: Dict[str, Any]) -> None:
    """Store the manager settings (keys such as ``call_timeout``, ``tool_prefix``).

    The dict is stored by reference. ``None`` is treated as an empty dict so that
    later ``settings.get(...)`` lookups do not fail.
    """
    self._settings = settings if settings is not None else {}
|
||
|
||
def set_status_change_callback(self, callback: Callable[[], None]) -> None:
    """Register the zero-argument callable that _notify_status_change invokes."""
    self._on_status_change = callback
|
||
|
||
def _notify_status_change(self) -> None:
    """Invoke the status-change callback, if set.

    Deliberately best-effort: exceptions from the callback are logged at debug
    level and swallowed so they cannot break the caller.
    """
    if self._on_status_change:
        try:
            self._on_status_change()
        except Exception as e:
            logger.debug(f"状态变化回调出错: {e}")
|
||
|
||
@property
def all_tools(self) -> Dict[str, Tuple[MCPToolInfo, MCPClientSession]]:
    """Shallow copy of all registered tools: tool key -> (tool info, owning client)."""
    return self._all_tools.copy()
|
||
|
||
@property
def all_resources(self) -> Dict[str, Tuple[MCPResourceInfo, MCPClientSession]]:
    """v1.2.0: Shallow copy of all registered resources: key -> (resource info, owning client)."""
    return self._all_resources.copy()
|
||
|
||
@property
def all_prompts(self) -> Dict[str, Tuple[MCPPromptInfo, MCPClientSession]]:
    """v1.2.0: Shallow copy of all registered prompt templates: key -> (prompt info, owning client)."""
    return self._all_prompts.copy()
|
||
|
||
@property
def connected_servers(self) -> List[str]:
    """Names of servers whose client is currently connected, in registration order."""
    return [name for name, client in self._clients.items() if client.is_connected]
|
||
|
||
@property
def disconnected_servers(self) -> List[str]:
    """Names of enabled servers whose client is not connected (disabled servers are excluded)."""
    return [name for name, client in self._clients.items() if not client.is_connected and client.config.enabled]
|
||
|
||
async def add_server(self, config: MCPServerConfig) -> bool:
    """Register a server and, if it is enabled, connect to it with retries.

    Settings used:
    - ``call_timeout`` (default 60.0): passed to the client session.
    - ``retry_attempts`` (default 3): at least one attempt is always made.
    - ``retry_interval`` (default 5.0 s): wait between attempts.

    The manager lock is held for the whole retry loop, so concurrent additions
    are serialised.

    Returns:
        True when the server was added disabled or connected successfully.
        False when a server with the same name already exists, or when every
        attempt failed. In the failure case the client is kept in ``_clients``
        so it can be reconnected later.
    """
    async with self._lock:
        if config.name in self._clients:
            logger.warning(f"服务器 {config.name} 已存在")
            return False

        call_timeout = self._settings.get("call_timeout", 60.0)
        client = MCPClientSession(config, call_timeout)
        self._clients[config.name] = client

        if not config.enabled:
            logger.info(f"服务器 {config.name} 已添加但未启用")
            return True

        # Clamp to >= 1: a zero/negative setting previously skipped connecting entirely,
        # and int() also accepts numeric values such as 3.0 that range() rejects.
        retry_attempts = max(1, int(self._settings.get("retry_attempts", 3)))
        retry_interval = self._settings.get("retry_interval", 5.0)

        for attempt in range(1, retry_attempts + 1):
            if await client.connect():
                self._register_tools(client)
                return True

            if attempt < retry_attempts:
                logger.warning(
                    f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
                )
                await asyncio.sleep(retry_interval)

        logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
        # Keep the client registered so a later reconnect can pick it up.
        return False
|
||
|
||
def _register_tools(self, client: MCPClientSession) -> None:
    """Add every tool exposed by ``client`` to the global tool registry.

    Keys have the form ``{tool_prefix}_{server_name}_{tool_name}``. A tool
    whose name already starts with that prefix is stored under its own name,
    so it is not prefixed twice. An existing entry with the same key is
    overwritten.
    """
    tool_prefix = self._settings.get("tool_prefix", "mcp")

    for tool in client.tools:
        key_prefix = f"{tool_prefix}_{client.server_name}_"
        already_prefixed = tool.name.startswith(key_prefix)
        tool_key = tool.name if already_prefixed else f"{key_prefix}{tool.name}"
        self._all_tools[tool_key] = (tool, client)
        logger.debug(f"注册 MCP 工具: {tool_key}")
|
||
|
||
def _unregister_tools(self, server_name: str) -> List[str]:
    """Remove every tool registered for ``server_name`` and return the removed keys.

    Ownership is decided by the client session stored with each registry entry,
    not by key prefix. A prefix test on ``f"{tool_prefix}_{server_name}_"`` also
    matches servers whose names merely start with ``server_name + "_"``. For
    example, removing server ``foo`` would also drop every tool of ``foo_bar``.
    """
    keys_to_remove = [
        key for key, (_, owner) in self._all_tools.items() if owner.server_name == server_name
    ]
    for key in keys_to_remove:
        del self._all_tools[key]
        logger.debug(f"注销 MCP 工具: {key}")
    return keys_to_remove
|
||
|
||
def _register_resources(self, client: MCPClientSession) -> None:
    """v1.2.0: Add every resource exposed by ``client`` to the resource registry.

    Keys have the form ``{tool_prefix}_{server_name}_res_{safe_uri}``, where
    ``safe_uri`` is the URI with ``://``, ``/`` and ``.`` each replaced by ``_``.
    NOTE(review): distinct URIs can flatten to the same key (e.g. ``a/b`` and
    ``a.b``); the later one silently overwrites the earlier.
    """
    tool_prefix = self._settings.get("tool_prefix", "mcp")

    for resource in client.resources:
        # Flatten the URI into an identifier-friendly suffix; the replacement
        # order matters ("://" must go before "/").
        safe_uri = resource.uri
        for token in ("://", "/", "."):
            safe_uri = safe_uri.replace(token, "_")
        resource_key = f"{tool_prefix}_{client.server_name}_res_{safe_uri}"
        self._all_resources[resource_key] = (resource, client)
        logger.debug(f"注册 MCP 资源: {resource_key}")
|
||
|
||
def _unregister_resources(self, server_name: str) -> List[str]:
    """v1.2.0: Remove every resource registered for ``server_name``; return the removed keys.

    Ownership is decided by the client session stored with each registry entry.
    The previous prefix test on ``f"{tool_prefix}_{server_name}_res_"`` also
    matched other servers, e.g. server ``foo`` swept up the resources of
    server ``foo_res``.
    """
    keys_to_remove = [
        key for key, (_, owner) in self._all_resources.items() if owner.server_name == server_name
    ]
    for key in keys_to_remove:
        del self._all_resources[key]
        logger.debug(f"注销 MCP 资源: {key}")
    return keys_to_remove
|
||
|
||
def _register_prompts(self, client: MCPClientSession) -> None:
    """v1.2.0: Add every prompt template exposed by ``client`` to the prompt registry.

    Keys have the form ``{tool_prefix}_{server_name}_prompt_{prompt_name}``;
    an existing entry with the same key is overwritten.
    """
    tool_prefix = self._settings.get("tool_prefix", "mcp")

    for prompt in client.prompts:
        key_prefix = f"{tool_prefix}_{client.server_name}_prompt_"
        prompt_key = f"{key_prefix}{prompt.name}"
        self._all_prompts[prompt_key] = (prompt, client)
        logger.debug(f"注册 MCP 提示模板: {prompt_key}")
|
||
|
||
def _unregister_prompts(self, server_name: str) -> List[str]:
    """v1.2.0: Remove every prompt template registered for ``server_name``; return the removed keys.

    Ownership is decided by the client session stored with each registry entry.
    The previous prefix test on ``f"{tool_prefix}_{server_name}_prompt_"`` also
    matched other servers, e.g. server ``foo`` swept up the prompts of server
    ``foo_prompt``.
    """
    keys_to_remove = [
        key for key, (_, owner) in self._all_prompts.items() if owner.server_name == server_name
    ]
    for key in keys_to_remove:
        del self._all_prompts[key]
        logger.debug(f"注销 MCP 提示模板: {key}")
    return keys_to_remove
|
||
|
||
async def remove_server(self, server_name: str) -> bool:
    """Disconnect a server and forget it, along with its tools, resources and prompts.

    Returns:
        False if no server with that name is registered, True once it has been
        disconnected and removed.
    """
    async with self._lock:
        client = self._clients.get(server_name)
        if client is None:
            return False

        await client.disconnect()
        # v1.2.0: resources and prompts are unregistered alongside tools.
        for unregister in (self._unregister_tools, self._unregister_resources, self._unregister_prompts):
            unregister(server_name)
        del self._clients[server_name]

    logger.info(f"服务器 {server_name} 已移除")
    return True
|
||
|
||
async def reconnect_server(self, server_name: str) -> bool:
    """Tear down a server's session and registrations, then try to reconnect it.

    Tools, resources and prompts are unregistered and the session disconnected
    under ``self._lock``. Connection attempts then run without the lock, with
    ``retry_interval`` seconds between them. On success, tools are
    re-registered. Resources and prompts are re-fetched and re-registered when
    ``enable_resources`` / ``enable_prompts`` are set.

    If the server is removed (``remove_server``) while the attempts are in
    progress, the freshly opened connection is closed instead of registering
    tools for a session that is no longer tracked.

    Returns:
        True on successful reconnection; False if the server is unknown, was
        removed meanwhile, or every attempt failed.
    """
    client = self._clients.get(server_name)
    if client is None:
        return False

    async with self._lock:
        self._unregister_tools(server_name)
        self._unregister_resources(server_name)  # v1.2.0
        self._unregister_prompts(server_name)  # v1.2.0
        await client.disconnect()

    # Try to reconnect
    retry_attempts = self._settings.get("retry_attempts", 3)
    retry_interval = self._settings.get("retry_interval", 5.0)

    for attempt in range(1, retry_attempts + 1):
        if await client.connect():
            async with self._lock:
                # The lock was released while connecting; make sure this session
                # is still the one registered under server_name.
                still_registered = self._clients.get(server_name) is client
                if still_registered:
                    self._register_tools(client)
                    # v1.2.0: also refresh resources and prompts after reconnecting
                    if self._settings.get("enable_resources", False):
                        await client.fetch_resources()
                        self._register_resources(client)
                    if self._settings.get("enable_prompts", False):
                        await client.fetch_prompts()
                        self._register_prompts(client)

            if not still_registered:
                # Do not leak a live connection for a server nobody tracks any more.
                await client.disconnect()
                logger.warning(f"服务器 {server_name} 在重连过程中已被移除,放弃重连")
                return False

            client.stats.record_reconnect()
            logger.info(f"服务器 {server_name} 重连成功")
            return True

        if attempt < retry_attempts:
            logger.warning(f"服务器 {server_name} 重连失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
            await asyncio.sleep(retry_interval)

    logger.error(f"服务器 {server_name} 重连失败")
    return False
|
||
|
||
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
    """Invoke a registered MCP tool by its registry key and update global counters.

    An unknown key yields a failed MCPCallResult and leaves the counters alone.
    Otherwise ``total_tool_calls`` is incremented before the call. After it,
    ``successful_calls`` or ``failed_calls`` is incremented, depending on
    ``result.success``.
    """
    entry = self._all_tools.get(tool_key)
    if entry is None:
        return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")

    tool_info, client = entry
    self._global_stats["total_tool_calls"] += 1

    result = await client.call_tool(tool_info.name, arguments)

    outcome_counter = "successful_calls" if result.success else "failed_calls"
    self._global_stats[outcome_counter] += 1
    return result
|
||
|
||
async def fetch_resources_for_server(self, server_name: str) -> bool:
    """v1.2.0: Refresh the resource list of one connected server and re-register it.

    On a successful fetch, the server's previous registry entries are removed
    before the current list is registered. Otherwise resources the server no
    longer advertises would linger, and ``read_resource`` could route a URI to
    a server that dropped it.

    Returns:
        False if the server is unknown or not connected; otherwise the success
        flag reported by ``client.fetch_resources()``.
    """
    client = self._clients.get(server_name)
    if client is None or not client.is_connected:
        return False

    success = await client.fetch_resources()
    if success:
        async with self._lock:
            self._unregister_resources(server_name)
            self._register_resources(client)
    return success
|
||
|
||
async def fetch_prompts_for_server(self, server_name: str) -> bool:
    """v1.2.0: Refresh the prompt-template list of one connected server and re-register it.

    On a successful fetch, the server's previous registry entries are removed
    before the current list is registered, so prompts the server no longer
    advertises do not linger in the registry.

    Returns:
        False if the server is unknown or not connected; otherwise the success
        flag reported by ``client.fetch_prompts()``.
    """
    client = self._clients.get(server_name)
    if client is None or not client.is_connected:
        return False

    success = await client.fetch_prompts()
    if success:
        async with self._lock:
            self._unregister_prompts(server_name)
            self._register_prompts(client)
    return success
|
||
|
||
async def read_resource(self, uri: str, server_name: Optional[str] = None) -> MCPCallResult:
    """v1.2.0: Read the content of a resource.

    Args:
        uri: URI of the resource to read.
        server_name: Optional server to query directly. When omitted, the
            resource registry is searched for a matching URI first. If none is
            found, every connected server that supports resources is tried in
            turn until one succeeds.

    Returns:
        The MCPCallResult from the server that handled the read. A failed
        result is returned when the named server is unknown or no server
        produced the resource. When a registry match is found, its server's
        result is returned as-is, even if it failed.
    """
    # An explicit server was requested
    if server_name:
        target = self._clients.get(server_name)
        if target is None:
            return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
        return await target.read_resource(uri)

    # Look for the server that registered this URI
    owner = next((session for info, session in self._all_resources.values() if info.uri == uri), None)
    if owner is not None:
        return await owner.read_resource(uri)

    # Fall back to asking every connected server that supports resources
    for session in self._clients.values():
        if not (session.is_connected and session.supports_resources):
            continue
        attempt = await session.read_resource(uri)
        if attempt.success:
            return attempt

    return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
|
||
|
||
async def get_prompt(
    self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
) -> MCPCallResult:
    """v1.2.0: Render a prompt template.

    Args:
        name: Prompt template name.
        arguments: Template arguments, passed through to the server.
        server_name: Optional server to query directly. When omitted, the
            first registered prompt with a matching name is used.

    Returns:
        The server's MCPCallResult, or a failed result when the named server
        is unknown or no registered prompt has that name.
    """
    # An explicit server was requested
    if server_name:
        target = self._clients.get(server_name)
        if target is None:
            return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
        return await target.get_prompt(name, arguments)

    # Look for the server that registered a prompt with this name
    owner = next((session for info, session in self._all_prompts.values() if info.name == name), None)
    if owner is not None:
        return await owner.get_prompt(name, arguments)

    return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
|
||
|
||
# ==================== 心跳检测 ====================
|
||
|
||
async def start_heartbeat(self) -> None:
    """Start the background heartbeat task unless it is running or disabled.

    Reads ``heartbeat_enabled`` from settings (default True). When it starts,
    ``_heartbeat_running`` is set and ``_heartbeat_loop`` is scheduled with
    ``asyncio.create_task``.
    """
    already_running = self._heartbeat_running
    if already_running:
        logger.warning("心跳检测任务已在运行")
        return

    if not self._settings.get("heartbeat_enabled", True):
        logger.info("心跳检测已禁用")
        return

    self._heartbeat_running = True
    loop_coroutine = self._heartbeat_loop()
    self._heartbeat_task = asyncio.create_task(loop_coroutine)
    logger.info("心跳检测任务已启动")
|
||
|
||
async def stop_heartbeat(self) -> None:
    """Stop the heartbeat task, if any, and wait for it to finish.

    Clears ``_heartbeat_running`` first so the loop exits at its next check.
    It then cancels the task, awaits it, swallows the resulting CancelledError,
    and forgets the task handle.
    """
    self._heartbeat_running = False

    task = self._heartbeat_task
    if task:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        self._heartbeat_task = None

    logger.info("心跳检测任务已停止")
|
||
|
||
async def _heartbeat_loop(self) -> None:
    """Periodically health-check every enabled server and reconnect broken ones.

    v1.5.2 adaptive mode (``heartbeat_adaptive``, default True) keeps a
    per-server interval:

    - a passed check with no consecutive failures grows it by 20%, up to
      ``heartbeat_interval * heartbeat_max_multiplier``;
    - a failed check resets it to ``heartbeat_interval``;
    - a disconnected server is polled at the minimum interval until it reaches
      ``max_reconnect_attempts``, after which the maximum interval is used.

    Each iteration sleeps for the smallest per-server interval, but never less
    than ``max(heartbeat_interval * 0.5, 30)`` seconds. Intervals of servers
    that were removed or disabled are discarded so they no longer influence
    the sleep.

    Runs until ``self._heartbeat_running`` is cleared or the task is cancelled.
    Other exceptions are logged and the loop resumes after a 5 second pause.
    """
    base_interval = self._settings.get("heartbeat_interval", 60.0)
    auto_reconnect = self._settings.get("auto_reconnect", True)
    max_reconnect_attempts = self._settings.get("max_reconnect_attempts", 3)

    # v1.5.2: adaptive heartbeat settings
    adaptive_enabled = self._settings.get("heartbeat_adaptive", True)
    max_multiplier = self._settings.get("heartbeat_max_multiplier", 3.0)

    # Per-server heartbeat interval, adjusted dynamically according to stability
    server_intervals: Dict[str, float] = {}
    min_interval = max(base_interval * 0.5, 30.0)  # floor for any interval
    max_interval = base_interval * max_multiplier  # ceiling for adaptive back-off

    mode_str = "智能" if adaptive_enabled else "固定"
    logger.info(f"心跳检测循环启动,{mode_str}模式,基准间隔: {base_interval}秒")

    while self._heartbeat_running:
        try:
            # Forget intervals of servers that were removed or disabled; a stale
            # short interval would otherwise keep waking the loop needlessly.
            for stale_name in [
                name
                for name in server_intervals
                if name not in self._clients or not self._clients[name].config.enabled
            ]:
                del server_intervals[stale_name]

            # Use the smallest per-server interval as the loop interval
            current_interval = min(server_intervals.values()) if server_intervals else base_interval
            current_interval = max(current_interval, min_interval)

            await asyncio.sleep(current_interval)

            if not self._heartbeat_running:
                break

            current_time = time.time()

            # Check every enabled server
            for server_name, client in list(self._clients.items()):
                if not client.config.enabled:
                    continue

                # Initialize the server's interval
                if server_name not in server_intervals:
                    server_intervals[server_name] = base_interval

                # Skip servers whose own interval has not (almost) elapsed; the
                # 0.9 factor lets a server be checked slightly before its full
                # interval has passed.
                last_heartbeat = client.stats.last_heartbeat_time or 0
                if current_time - last_heartbeat < server_intervals[server_name] * 0.9:
                    continue

                if client.is_connected:
                    # Check health
                    healthy = await client.check_health()
                    if healthy:
                        # v1.5.2: adaptive - stable servers get gradually longer intervals
                        if adaptive_enabled and client.stats.consecutive_failures == 0:
                            new_interval = min(server_intervals[server_name] * 1.2, max_interval)
                            if new_interval != server_intervals[server_name]:
                                server_intervals[server_name] = new_interval
                                logger.debug(f"[{server_name}] 稳定,心跳间隔调整为 {new_interval:.0f}s")
                    else:
                        logger.warning(f"[{server_name}] 心跳检测失败,连接可能已断开")
                        # Reset to the base interval after a failure
                        if adaptive_enabled:
                            server_intervals[server_name] = base_interval
                        self._notify_status_change()
                        if auto_reconnect:
                            await self._try_reconnect(server_name, max_reconnect_attempts)
                else:
                    # Server is not connected; try to reconnect
                    if adaptive_enabled:
                        # Adaptive: disconnected servers are polled at the shortest interval
                        server_intervals[server_name] = min_interval
                    if auto_reconnect and client.stats.consecutive_failures < max_reconnect_attempts:
                        logger.info(f"[{server_name}] 检测到断开,尝试重连...")
                        await self._try_reconnect(server_name, max_reconnect_attempts)
                    elif client.stats.consecutive_failures >= max_reconnect_attempts:
                        if adaptive_enabled:
                            # Reconnect budget exhausted: poll less often
                            server_intervals[server_name] = max_interval
                            logger.debug(f"[{server_name}] 已达最大重连次数,降低检测频率")

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"心跳检测循环出错: {e}")
            await asyncio.sleep(5)
|
||
|
||
async def _try_reconnect(self, server_name: str, max_attempts: int) -> bool:
    """Attempt to reconnect one server, subject to a consecutive-failure budget.

    The attempt is refused (returns False) if the server is unknown or its
    ``consecutive_failures`` has already reached ``max_attempts``. A failed
    reconnect records one more failure. Status listeners are notified after
    every attempt.

    Returns:
        The result of ``reconnect_server``, or False if the attempt was refused.
    """
    client = self._clients.get(server_name)
    if not client:
        return False

    over_budget = client.stats.consecutive_failures >= max_attempts
    if over_budget:
        logger.warning(f"[{server_name}] 连续失败次数已达上限 ({max_attempts}),暂停重连")
        return False

    logger.info(f"[{server_name}] 尝试重连 (失败次数: {client.stats.consecutive_failures}/{max_attempts})")

    reconnected = await self.reconnect_server(server_name)
    if not reconnected:
        client.stats.record_failure()

    # Refresh the published status now that the attempt is over
    self._notify_status_change()
    return reconnected
|
||
|
||
# ==================== 统计和状态 ====================
|
||
|
||
def get_tool_stats(self, tool_key: str) -> Optional[Dict[str, Any]]:
    """Return the call statistics of one registered tool as a dict.

    The lookup goes through the owning session using the tool's original
    (unprefixed) name. Returns None if the key is not registered or the
    session has no statistics for the tool.
    """
    entry = self._all_tools.get(tool_key)
    if entry is None:
        return None

    tool_info, client = entry
    stats = client.get_tool_stats(tool_info.name)
    if not stats:
        return None
    return stats.to_dict()
|
||
|
||
def get_all_stats(self) -> Dict[str, Any]:
    """Collect global, per-server and per-tool statistics.

    Per-tool entries are keyed the same way ``_register_tools`` keys the tool
    registry: ``{tool_prefix}_{server_name}_{tool_name}``. A tool name that
    already carries that prefix is used as-is, so stats keys match the keys
    of ``all_tools``.

    Returns:
        A dict with ``"global"``, ``"servers"`` and ``"tools"`` sections.
        ``"global"`` copies ``_global_stats`` and adds ``uptime_seconds`` and
        ``calls_per_minute`` (0 when uptime is not positive).
    """
    tool_prefix = self._settings.get("tool_prefix", "mcp")
    server_stats = {}
    tool_stats = {}

    for server_name, client in self._clients.items():
        server_stats[server_name] = client.stats.to_dict()
        key_prefix = f"{tool_prefix}_{server_name}_"
        for tool_name, stats in client.get_all_tool_stats().items():
            full_key = tool_name if tool_name.startswith(key_prefix) else f"{key_prefix}{tool_name}"
            tool_stats[full_key] = stats.to_dict()

    uptime = time.time() - self._global_stats["start_time"]

    return {
        "global": {
            **self._global_stats,
            "uptime_seconds": round(uptime, 2),
            "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
            if uptime > 0
            else 0,
        },
        "servers": server_stats,
        "tools": tool_stats,
    }
|
||
|
||
async def shutdown(self) -> None:
    """Stop the heartbeat, disconnect every server and clear all registries.

    Disconnects are best-effort: an exception from one server is logged and
    the remaining servers are still disconnected, so one misbehaving session
    cannot leave the others connected or the registries populated.
    CancelledError is not caught and still propagates.
    """
    # Stop the heartbeat first so it does not try to reconnect servers being torn down
    await self.stop_heartbeat()

    async with self._lock:
        for server_name, client in self._clients.items():
            try:
                await client.disconnect()
            except Exception as e:
                logger.warning(f"关闭服务器 {server_name} 连接时出错: {e}")
        self._clients.clear()
        self._all_tools.clear()
        self._all_resources.clear()  # v1.2.0
        self._all_prompts.clear()  # v1.2.0

    logger.info("MCP 客户端管理器已关闭")
|
||
|
||
def get_status(self) -> Dict[str, Any]:
    """Summarize the manager's state: counts, heartbeat flag and per-server details.

    Per-server entries report connection/enable flags, the sizes of the
    session's tool/resource/prompt lists, capability flags (v1.2.0), the
    transport name, consecutive failures and the circuit-breaker status
    (v1.7.0). ``global_stats`` is the live ``_global_stats`` dict, not a copy.
    """

    def _describe(session) -> Dict[str, Any]:
        # One server's status entry
        return {
            "connected": session.is_connected,
            "enabled": session.config.enabled,
            "tools_count": len(session.tools),
            "resources_count": len(session.resources),
            "prompts_count": len(session.prompts),
            "supports_resources": session.supports_resources,
            "supports_prompts": session.supports_prompts,
            "transport": session.config.transport.value,
            "consecutive_failures": session.stats.consecutive_failures,
            "circuit_breaker": session.get_circuit_breaker_status(),
        }

    return {
        "total_servers": len(self._clients),
        "connected_servers": len(self.connected_servers),
        "disconnected_servers": len(self.disconnected_servers),
        "total_tools": len(self._all_tools),
        "total_resources": len(self._all_resources),
        "total_prompts": len(self._all_prompts),
        "heartbeat_running": self._heartbeat_running,
        "servers": {server: _describe(session) for server, session in self._clients.items()},
        "global_stats": self._global_stats,
    }
|
||
|
||
|
||
# Module-level singleton instance of MCPClientManager.
mcp_manager = MCPClientManager()
|