Ruff Format

pull/1496/head
DrSmoothl 2026-02-21 16:24:24 +08:00
parent 2cb512120b
commit eaef7f0e98
82 changed files with 1881 additions and 1900 deletions

1
bot.py
View File

@ -50,6 +50,7 @@ print("警告Dev进入不稳定开发状态任何插件与WebUI均可能
print("\n\n\n\n\n") print("\n\n\n\n\n")
print("-----------------------------------------") print("-----------------------------------------")
def run_runner_process(): def run_runner_process():
""" """
Runner 进程逻辑作为守护进程运行负责启动和监控 Worker 进程 Runner 进程逻辑作为守护进程运行负责启动和监控 Worker 进程

View File

@ -1,2 +1 @@
"""Core helpers for MCP Bridge Plugin.""" """Core helpers for MCP Bridge Plugin."""

View File

@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
if not mcp_servers: if not mcp_servers:
return "" return ""
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2) return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)

View File

@ -34,19 +34,21 @@ from enum import Enum
# 尝试导入 MaiBot 的 logger如果失败则使用标准 logging # 尝试导入 MaiBot 的 logger如果失败则使用标准 logging
try: try:
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("mcp_client") logger = get_logger("mcp_client")
except ImportError: except ImportError:
# Fallback: 使用标准 logging # Fallback: 使用标准 logging
logger = logging.getLogger("mcp_client") logger = logging.getLogger("mcp_client")
if not logger.handlers: if not logger.handlers:
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s')) handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
logger.addHandler(handler) logger.addHandler(handler)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class TransportType(Enum): class TransportType(Enum):
"""MCP 传输类型""" """MCP 传输类型"""
STDIO = "stdio" # 本地进程通信 STDIO = "stdio" # 本地进程通信
SSE = "sse" # Server-Sent Events (旧版 HTTP) SSE = "sse" # Server-Sent Events (旧版 HTTP)
HTTP = "http" # HTTP Streamable (新版,推荐) HTTP = "http" # HTTP Streamable (新版,推荐)
@ -56,6 +58,7 @@ class TransportType(Enum):
@dataclass @dataclass
class MCPToolInfo: class MCPToolInfo:
"""MCP 工具信息""" """MCP 工具信息"""
name: str name: str
description: str description: str
input_schema: Dict[str, Any] input_schema: Dict[str, Any]
@ -65,6 +68,7 @@ class MCPToolInfo:
@dataclass @dataclass
class MCPResourceInfo: class MCPResourceInfo:
"""MCP 资源信息""" """MCP 资源信息"""
uri: str uri: str
name: str name: str
description: str description: str
@ -75,6 +79,7 @@ class MCPResourceInfo:
@dataclass @dataclass
class MCPPromptInfo: class MCPPromptInfo:
"""MCP 提示模板信息""" """MCP 提示模板信息"""
name: str name: str
description: str description: str
arguments: List[Dict[str, Any]] # [{name, description, required}] arguments: List[Dict[str, Any]] # [{name, description, required}]
@ -84,6 +89,7 @@ class MCPPromptInfo:
@dataclass @dataclass
class MCPServerConfig: class MCPServerConfig:
"""MCP 服务器配置""" """MCP 服务器配置"""
name: str name: str
enabled: bool = True enabled: bool = True
transport: TransportType = TransportType.STDIO transport: TransportType = TransportType.STDIO
@ -99,6 +105,7 @@ class MCPServerConfig:
@dataclass @dataclass
class MCPCallResult: class MCPCallResult:
"""MCP 工具调用结果""" """MCP 工具调用结果"""
success: bool success: bool
content: Any content: Any
error: Optional[str] = None error: Optional[str] = None
@ -108,6 +115,7 @@ class MCPCallResult:
class CircuitState(Enum): class CircuitState(Enum):
"""断路器状态""" """断路器状态"""
CLOSED = "closed" # 正常状态,允许请求 CLOSED = "closed" # 正常状态,允许请求
OPEN = "open" # 熔断状态,拒绝请求 OPEN = "open" # 熔断状态,拒绝请求
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求 HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
@ -232,6 +240,7 @@ class CircuitBreaker:
@dataclass @dataclass
class ToolCallStats: class ToolCallStats:
"""工具调用统计""" """工具调用统计"""
tool_key: str tool_key: str
total_calls: int = 0 total_calls: int = 0
success_calls: int = 0 success_calls: int = 0
@ -282,6 +291,7 @@ class ToolCallStats:
@dataclass @dataclass
class ServerStats: class ServerStats:
"""服务器统计""" """服务器统计"""
server_name: str server_name: str
connect_count: int = 0 # 连接次数 connect_count: int = 0 # 连接次数
disconnect_count: int = 0 # 断开次数 disconnect_count: int = 0 # 断开次数
@ -442,9 +452,7 @@ class MCPClientSession:
return False return False
server_params = StdioServerParameters( server_params = StdioServerParameters(
command=self.config.command, command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
args=self.config.args,
env=self.config.env if self.config.env else None
) )
self._stdio_context = stdio_client(server_params) self._stdio_context = stdio_client(server_params)
@ -506,6 +514,7 @@ class MCPClientSession:
except Exception as e: except Exception as e:
logger.error(f"[{self.server_name}] SSE 连接失败: {e}") logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
import traceback import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup() await self._cleanup()
return False return False
@ -551,6 +560,7 @@ class MCPClientSession:
except Exception as e: except Exception as e:
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}") logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
import traceback import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup() await self._cleanup()
return False return False
@ -568,8 +578,8 @@ class MCPClientSession:
tool_info = MCPToolInfo( tool_info = MCPToolInfo(
name=tool.name, name=tool.name,
description=tool.description or f"MCP tool: {tool.name}", description=tool.description or f"MCP tool: {tool.name}",
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {}, input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
server_name=self.server_name server_name=self.server_name,
) )
self._tools.append(tool_info) self._tools.append(tool_info)
# 初始化工具统计 # 初始化工具统计
@ -591,10 +601,7 @@ class MCPClientSession:
return False return False
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
self._session.list_resources(),
timeout=self.call_timeout
)
self._resources = [] self._resources = []
for resource in result.resources: for resource in result.resources:
@ -602,8 +609,8 @@ class MCPClientSession:
uri=str(resource.uri), uri=str(resource.uri),
name=resource.name or str(resource.uri), name=resource.name or str(resource.uri),
description=resource.description or "", description=resource.description or "",
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None, mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
server_name=self.server_name server_name=self.server_name,
) )
self._resources.append(resource_info) self._resources.append(resource_info)
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}") logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
@ -633,28 +640,27 @@ class MCPClientSession:
return False return False
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
self._session.list_prompts(),
timeout=self.call_timeout
)
self._prompts = [] self._prompts = []
for prompt in result.prompts: for prompt in result.prompts:
# 解析参数 # 解析参数
arguments = [] arguments = []
if hasattr(prompt, 'arguments') and prompt.arguments: if hasattr(prompt, "arguments") and prompt.arguments:
for arg in prompt.arguments: for arg in prompt.arguments:
arguments.append({ arguments.append(
{
"name": arg.name, "name": arg.name,
"description": arg.description or "", "description": arg.description or "",
"required": arg.required if hasattr(arg, 'required') else False, "required": arg.required if hasattr(arg, "required") else False,
}) }
)
prompt_info = MCPPromptInfo( prompt_info = MCPPromptInfo(
name=prompt.name, name=prompt.name,
description=prompt.description or f"MCP prompt: {prompt.name}", description=prompt.description or f"MCP prompt: {prompt.name}",
arguments=arguments, arguments=arguments,
server_name=self.server_name server_name=self.server_name,
) )
self._prompts.append(prompt_info) self._prompts.append(prompt_info)
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}") logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
@ -686,35 +692,25 @@ class MCPClientSession:
start_time = time.time() start_time = time.time()
if not self._connected or not self._session: if not self._connected or not self._session:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
if not self._supports_resources: if not self._supports_resources:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Resources 功能"
)
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
self._session.read_resource(uri),
timeout=self.call_timeout
)
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
# 处理返回内容 # 处理返回内容
content_parts = [] content_parts = []
for content in result.contents: for content in result.contents:
if hasattr(content, 'text'): if hasattr(content, "text"):
content_parts.append(content.text) content_parts.append(content.text)
elif hasattr(content, 'blob'): elif hasattr(content, "blob"):
# 二进制数据,返回 base64 或提示 # 二进制数据,返回 base64 或提示
import base64 import base64
blob_data = content.blob blob_data = content.blob
if len(blob_data) < 10000: # 小于 10KB 返回 base64 if len(blob_data) < 10000: # 小于 10KB 返回 base64
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}") content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
@ -724,28 +720,18 @@ class MCPClientSession:
content_parts.append(str(content)) content_parts.append(str(content))
return MCPCallResult( return MCPCallResult(
success=True, success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
content="\n".join(content_parts) if content_parts else "",
duration_ms=duration_ms
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
return MCPCallResult( return MCPCallResult(
success=False, success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
content=None,
error=f"读取资源超时({self.call_timeout}秒)",
duration_ms=duration_ms
) )
except Exception as e: except Exception as e:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}") logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
return MCPCallResult( return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult: async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
"""v1.2.0: 获取提示模板的内容 """v1.2.0: 获取提示模板的内容
@ -760,23 +746,14 @@ class MCPClientSession:
start_time = time.time() start_time = time.time()
if not self._connected or not self._session: if not self._connected or not self._session:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
if not self._supports_prompts: if not self._supports_prompts:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
)
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.get_prompt(name, arguments=arguments or {}), self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
timeout=self.call_timeout
) )
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
@ -784,10 +761,10 @@ class MCPClientSession:
# 处理返回的消息 # 处理返回的消息
messages = [] messages = []
for msg in result.messages: for msg in result.messages:
role = msg.role if hasattr(msg, 'role') else "unknown" role = msg.role if hasattr(msg, "role") else "unknown"
content_text = "" content_text = ""
if hasattr(msg, 'content'): if hasattr(msg, "content"):
if hasattr(msg.content, 'text'): if hasattr(msg.content, "text"):
content_text = msg.content.text content_text = msg.content.text
elif isinstance(msg.content, str): elif isinstance(msg.content, str):
content_text = msg.content content_text = msg.content
@ -796,28 +773,18 @@ class MCPClientSession:
messages.append(f"[{role}]: {content_text}") messages.append(f"[{role}]: {content_text}")
return MCPCallResult( return MCPCallResult(
success=True, success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
content="\n\n".join(messages) if messages else "",
duration_ms=duration_ms
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
return MCPCallResult( return MCPCallResult(
success=False, success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
content=None,
error=f"获取提示模板超时({self.call_timeout}秒)",
duration_ms=duration_ms
) )
except Exception as e: except Exception as e:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}") logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
return MCPCallResult( return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
async def check_health(self) -> bool: async def check_health(self) -> bool:
"""检查连接健康状态(心跳检测) """检查连接健康状态(心跳检测)
@ -829,10 +796,7 @@ class MCPClientSession:
try: try:
# 使用 list_tools 作为心跳检测 # 使用 list_tools 作为心跳检测
await asyncio.wait_for( await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
self._session.list_tools(),
timeout=10.0
)
self.stats.record_heartbeat() self.stats.record_heartbeat()
return True return True
except Exception as e: except Exception as e:
@ -849,12 +813,7 @@ class MCPClientSession:
# v1.7.0: 断路器检查 # v1.7.0: 断路器检查
can_execute, reject_reason = self._circuit_breaker.can_execute() can_execute, reject_reason = self._circuit_breaker.can_execute()
if not can_execute: if not can_execute:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"{reject_reason}", circuit_broken=True)
success=False,
content=None,
error=f"{reject_reason}",
circuit_broken=True
)
# 半开状态下增加试探计数 # 半开状态下增加试探计数
if self._circuit_breaker.state == CircuitState.HALF_OPEN: if self._circuit_breaker.state == CircuitState.HALF_OPEN:
@ -870,8 +829,7 @@ class MCPClientSession:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.call_tool(tool_name, arguments=arguments), self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
timeout=self.call_timeout
) )
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
@ -879,9 +837,9 @@ class MCPClientSession:
# 处理返回内容 # 处理返回内容
content_parts = [] content_parts = []
for content in result.content: for content in result.content:
if hasattr(content, 'text'): if hasattr(content, "text"):
content_parts.append(content.text) content_parts.append(content.text)
elif hasattr(content, 'data'): elif hasattr(content, "data"):
content_parts.append(f"[二进制数据: {len(content.data)} bytes]") content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
else: else:
content_parts.append(str(content)) content_parts.append(str(content))
@ -896,7 +854,7 @@ class MCPClientSession:
return MCPCallResult( return MCPCallResult(
success=True, success=True,
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)", content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
duration_ms=duration_ms duration_ms=duration_ms,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -939,25 +897,25 @@ class MCPClientSession:
self._supports_prompts = False # v1.2.0 self._supports_prompts = False # v1.2.0
try: try:
if hasattr(self, '_session_context') and self._session_context: if hasattr(self, "_session_context") and self._session_context:
await self._session_context.__aexit__(None, None, None) await self._session_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}") logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
try: try:
if hasattr(self, '_stdio_context') and self._stdio_context: if hasattr(self, "_stdio_context") and self._stdio_context:
await self._stdio_context.__aexit__(None, None, None) await self._stdio_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}") logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
try: try:
if hasattr(self, '_http_context') and self._http_context: if hasattr(self, "_http_context") and self._http_context:
await self._http_context.__aexit__(None, None, None) await self._http_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}") logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
try: try:
if hasattr(self, '_sse_context') and self._sse_context: if hasattr(self, "_sse_context") and self._sse_context:
await self._sse_context.__aexit__(None, None, None) await self._sse_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}") logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
@ -1082,7 +1040,9 @@ class MCPClientManager:
return True return True
if attempt < retry_attempts: if attempt < retry_attempts:
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") logger.warning(
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
)
await asyncio.sleep(retry_interval) await asyncio.sleep(retry_interval)
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})") logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
@ -1213,11 +1173,7 @@ class MCPClientManager:
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult: async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
"""调用 MCP 工具""" """调用 MCP 工具"""
if tool_key not in self._all_tools: if tool_key not in self._all_tools:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
success=False,
content=None,
error=f"工具 {tool_key} 不存在"
)
tool_info, client = self._all_tools[tool_key] tool_info, client = self._all_tools[tool_key]
@ -1273,11 +1229,7 @@ class MCPClientManager:
# 如果指定了服务器 # 如果指定了服务器
if server_name: if server_name:
if server_name not in self._clients: if server_name not in self._clients:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
client = self._clients[server_name] client = self._clients[server_name]
return await client.read_resource(uri) return await client.read_resource(uri)
@ -1293,14 +1245,11 @@ class MCPClientManager:
if result.success: if result.success:
return result return result
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
success=False,
content=None,
error=f"未找到资源: {uri}"
)
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None, async def get_prompt(
server_name: Optional[str] = None) -> MCPCallResult: self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
) -> MCPCallResult:
"""v1.2.0: 获取提示模板内容 """v1.2.0: 获取提示模板内容
Args: Args:
@ -1311,11 +1260,7 @@ class MCPClientManager:
# 如果指定了服务器 # 如果指定了服务器
if server_name: if server_name:
if server_name not in self._clients: if server_name not in self._clients:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
client = self._clients[server_name] client = self._clients[server_name]
return await client.get_prompt(name, arguments) return await client.get_prompt(name, arguments)
@ -1324,11 +1269,7 @@ class MCPClientManager:
if prompt_info.name == name: if prompt_info.name == name:
return await client.get_prompt(name, arguments) return await client.get_prompt(name, arguments)
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
success=False,
content=None,
error=f"未找到提示模板: {name}"
)
# ==================== 心跳检测 ==================== # ==================== 心跳检测 ====================
@ -1489,7 +1430,9 @@ class MCPClientManager:
"global": { "global": {
**self._global_stats, **self._global_stats,
"uptime_seconds": round(uptime, 2), "uptime_seconds": round(uptime, 2),
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0, "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
if uptime > 0
else 0,
}, },
"servers": server_stats, "servers": server_stats,
"tools": tool_stats, "tools": tool_stats,

View File

@ -123,9 +123,11 @@ logger = get_logger("mcp_bridge_plugin")
# v1.4.0: 调用链路追踪 # v1.4.0: 调用链路追踪
# ============================================================================ # ============================================================================
@dataclass @dataclass
class ToolCallRecord: class ToolCallRecord:
"""工具调用记录""" """工具调用记录"""
call_id: str call_id: str
timestamp: float timestamp: float
tool_name: str tool_name: str
@ -208,9 +210,11 @@ tool_call_tracer = ToolCallTracer()
# v1.4.0: 工具调用缓存 # v1.4.0: 工具调用缓存
# ============================================================================ # ============================================================================
@dataclass @dataclass
class CacheEntry: class CacheEntry:
"""缓存条目""" """缓存条目"""
tool_name: str tool_name: str
args_hash: str args_hash: str
result: str result: str
@ -347,6 +351,7 @@ tool_call_cache = ToolCallCache()
# v1.4.0: 工具权限控制 # v1.4.0: 工具权限控制
# ============================================================================ # ============================================================================
class PermissionChecker: class PermissionChecker:
"""工具权限检查器""" """工具权限检查器"""
@ -479,6 +484,7 @@ permission_checker = PermissionChecker()
# 工具类型转换 # 工具类型转换
# ============================================================================ # ============================================================================
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType""" """将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
type_mapping = { type_mapping = {
@ -492,7 +498,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
return type_mapping.get(json_type, ToolParamType.STRING) return type_mapping.get(json_type, ToolParamType.STRING)
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: def parse_mcp_parameters(
input_schema: Dict[str, Any],
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
"""解析 MCP 工具的参数 schema转换为 MaiBot 的参数格式""" """解析 MCP 工具的参数 schema转换为 MaiBot 的参数格式"""
parameters = [] parameters = []
@ -534,6 +542,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
# MCP 工具代理 # MCP 工具代理
# ============================================================================ # ============================================================================
class MCPToolProxy(BaseTool): class MCPToolProxy(BaseTool):
"""MCP 工具代理基类""" """MCP 工具代理基类"""
@ -576,10 +585,7 @@ class MCPToolProxy(BaseTool):
# v1.4.0: 权限检查 # v1.4.0: 权限检查
if not permission_checker.check(self.name, chat_id, user_id, is_group): if not permission_checker.check(self.name, chat_id, user_id, is_group):
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}") logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
return { return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
"name": self.name,
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
}
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}") logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
@ -749,11 +755,7 @@ class MCPToolProxy(BaseTool):
return None return None
async def _call_post_process_llm( async def _call_post_process_llm(
self, self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
prompt: str,
max_tokens: int,
settings: Dict[str, Any],
server_config: Optional[Dict[str, Any]]
) -> Optional[str]: ) -> Optional[str]:
"""调用 LLM 进行后处理""" """调用 LLM 进行后处理"""
from src.config.config import model_config from src.config.config import model_config
@ -811,10 +813,7 @@ class MCPToolProxy(BaseTool):
def create_mcp_tool_class( def create_mcp_tool_class(
tool_key: str, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
) -> Type[MCPToolProxy]: ) -> Type[MCPToolProxy]:
"""根据 MCP 工具信息动态创建 BaseTool 子类""" """根据 MCP 工具信息动态创建 BaseTool 子类"""
parameters = parse_mcp_parameters(tool_info.input_schema) parameters = parse_mcp_parameters(tool_info.input_schema)
@ -837,7 +836,7 @@ def create_mcp_tool_class(
"_mcp_tool_key": tool_key, "_mcp_tool_key": tool_key,
"_mcp_original_name": tool_info.name, "_mcp_original_name": tool_info.name,
"_mcp_server_name": tool_info.server_name, "_mcp_server_name": tool_info.server_name,
} },
) )
return tool_class return tool_class
@ -851,11 +850,7 @@ class MCPToolRegistry:
self._tool_infos: Dict[str, ToolInfo] = {} self._tool_infos: Dict[str, ToolInfo] = {}
def register_tool( def register_tool(
self, self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
tool_key: str,
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
) -> Tuple[ToolInfo, Type[MCPToolProxy]]: ) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
"""注册 MCP 工具""" """注册 MCP 工具"""
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled) tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
@ -902,6 +897,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
# 内置工具 # 内置工具
# ============================================================================ # ============================================================================
class MCPReadResourceTool(BaseTool): class MCPReadResourceTool(BaseTool):
"""v1.2.0: MCP 资源读取工具""" """v1.2.0: MCP 资源读取工具"""
@ -973,6 +969,7 @@ class MCPGetPromptTool(BaseTool):
# v1.8.0: 工具链代理工具 # v1.8.0: 工具链代理工具
# ============================================================================ # ============================================================================
class ToolChainProxyBase(BaseTool): class ToolChainProxyBase(BaseTool):
"""工具链代理基类""" """工具链代理基类"""
@ -1037,7 +1034,7 @@ def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBa
"parameters": parameters, "parameters": parameters,
"available_for_llm": True, "available_for_llm": True,
"_chain_name": chain.name, "_chain_name": chain.name,
} },
) )
return tool_class return tool_class
@ -1095,7 +1092,13 @@ class MCPStatusTool(BaseTool):
name = "mcp_status" name = "mcp_status"
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息" description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
parameters = [ parameters = [
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"]), (
"query_type",
ToolParamType.STRING,
"查询类型",
False,
["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"],
),
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
] ]
available_for_llm = True available_for_llm = True
@ -1132,10 +1135,7 @@ class MCPStatusTool(BaseTool):
if query_type in ("cache",): if query_type in ("cache",):
result_parts.append(self._format_cache()) result_parts.append(self._format_cache())
return { return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
"name": self.name,
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
}
def _format_status(self, server_name: Optional[str] = None) -> str: def _format_status(self, server_name: Optional[str] = None) -> str:
status = mcp_manager.get_status() status = mcp_manager.get_status()
@ -1147,14 +1147,14 @@ class MCPStatusTool(BaseTool):
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}") lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
lines.append("\n🔌 服务器详情:") lines.append("\n🔌 服务器详情:")
for name, info in status['servers'].items(): for name, info in status["servers"].items():
if server_name and name != server_name: if server_name and name != server_name:
continue continue
status_icon = "" if info['connected'] else "" status_icon = "" if info["connected"] else ""
enabled_text = "" if info['enabled'] else " (已禁用)" enabled_text = "" if info["enabled"] else " (已禁用)"
lines.append(f" {status_icon} {name}{enabled_text}") lines.append(f" {status_icon} {name}{enabled_text}")
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}") lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
if info['consecutive_failures'] > 0: if info["consecutive_failures"] > 0:
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']}") lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']}")
return "\n".join(lines) return "\n".join(lines)
@ -1184,11 +1184,11 @@ class MCPStatusTool(BaseTool):
stats = mcp_manager.get_all_stats() stats = mcp_manager.get_all_stats()
lines = ["📈 调用统计"] lines = ["📈 调用统计"]
g = stats['global'] g = stats["global"]
lines.append(f" 总调用次数: {g['total_tool_calls']}") lines.append(f" 总调用次数: {g['total_tool_calls']}")
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}") lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
if g['total_tool_calls'] > 0: if g["total_tool_calls"] > 0:
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100 success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
lines.append(f" 成功率: {success_rate:.1f}%") lines.append(f" 成功率: {success_rate:.1f}%")
lines.append(f" 运行时间: {g['uptime_seconds']:.0f}") lines.append(f" 运行时间: {g['uptime_seconds']:.0f}")
@ -1277,7 +1277,7 @@ class MCPStatusTool(BaseTool):
lines.append(f" 描述: {chain.description[:50]}...") lines.append(f" 描述: {chain.description[:50]}...")
lines.append(f" 步骤: {len(chain.steps)}") lines.append(f" 步骤: {len(chain.steps)}")
for i, step in enumerate(chain.steps[:3]): for i, step in enumerate(chain.steps[:3]):
lines.append(f" {i+1}. {step.tool_name}") lines.append(f" {i + 1}. {step.tool_name}")
if len(chain.steps) > 3: if len(chain.steps) > 3:
lines.append(f" ... 还有 {len(chain.steps) - 3} 个步骤") lines.append(f" ... 还有 {len(chain.steps) - 3} 个步骤")
if chain.input_params: if chain.input_params:
@ -1294,6 +1294,7 @@ class MCPStatusTool(BaseTool):
# 命令处理 # 命令处理
# ============================================================================ # ============================================================================
class MCPStatusCommand(BaseCommand): class MCPStatusCommand(BaseCommand):
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态""" """MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
@ -1644,6 +1645,7 @@ class MCPStatusCommand(BaseCommand):
_plugin_instance._load_tool_chains() _plugin_instance._load_tool_chains()
chains = tool_chain_manager.get_all_chains() chains = tool_chain_manager.get_all_chains()
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
registered = 0 registered = 0
for name, chain in tool_chain_manager.get_enabled_chains().items(): for name, chain in tool_chain_manager.get_enabled_chains().items():
tool_name = f"chain_{name}".replace("-", "_").replace(".", "_") tool_name = f"chain_{name}".replace("-", "_").replace(".", "_")
@ -1735,7 +1737,7 @@ class MCPStatusCommand(BaseCommand):
lines.append(f"📋 执行步骤 ({len(chain.steps)} 个):") lines.append(f"📋 执行步骤 ({len(chain.steps)} 个):")
for i, step in enumerate(chain.steps): for i, step in enumerate(chain.steps):
optional_tag = " (可选)" if step.optional else "" optional_tag = " (可选)" if step.optional else ""
lines.append(f" {i+1}. {step.tool_name}{optional_tag}") lines.append(f" {i + 1}. {step.tool_name}{optional_tag}")
if step.description: if step.description:
lines.append(f" {step.description}") lines.append(f" {step.description}")
if step.output_key: if step.output_key:
@ -1983,6 +1985,7 @@ class MCPImportCommand(BaseCommand):
# 事件处理器 # 事件处理器
# ============================================================================ # ============================================================================
class MCPStartupHandler(BaseEventHandler): class MCPStartupHandler(BaseEventHandler):
"""MCP 启动事件处理器""" """MCP 启动事件处理器"""
@ -2037,6 +2040,7 @@ class MCPStopHandler(BaseEventHandler):
# 主插件类 # 主插件类
# ============================================================================ # ============================================================================
@register_plugin @register_plugin
class MCPBridgePlugin(BasePlugin): class MCPBridgePlugin(BasePlugin):
"""MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot""" """MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
@ -2505,7 +2509,7 @@ class MCPBridgePlugin(BasePlugin):
label="📋 工具链列表", label="📋 工具链列表",
input_type="textarea", input_type="textarea",
rows=20, rows=20,
placeholder='''[ placeholder="""[
{ {
"name": "search_and_detail", "name": "search_and_detail",
"description": "先搜索再获取详情", "description": "先搜索再获取详情",
@ -2515,7 +2519,7 @@ class MCPBridgePlugin(BasePlugin):
{"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}} {"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}}
] ]
} }
]''', ]""",
hint="每个工具链包含 name、description、input_params、steps", hint="每个工具链包含 name、description、input_params、steps",
order=30, order=30,
), ),
@ -2653,9 +2657,9 @@ mcp_bing_*""",
label="📜 高级权限规则(可选)", label="📜 高级权限规则(可选)",
input_type="textarea", input_type="textarea",
rows=10, rows=10,
placeholder='''[ placeholder="""[
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
]''', ]""",
hint="格式: qq:ID:group/private/user工具名支持通配符 *", hint="格式: qq:ID:group/private/user工具名支持通配符 *",
order=10, order=10,
), ),
@ -2754,7 +2758,9 @@ mcp_bing_*""",
value = match1.group(2) value = match1.group(2)
suffix = match1.group(3) suffix = match1.group(3)
# 将转义的换行符还原为实际换行符 # 将转义的换行符还原为实际换行符
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") unescaped = (
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
)
fixed_line = f'{prefix}"""{unescaped}"""{suffix}' fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
fixed_lines.append(fixed_line) fixed_lines.append(fixed_line)
modified = True modified = True
@ -2948,11 +2954,13 @@ mcp_bing_*""",
logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}") logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}")
args_template = {} args_template = {}
steps.append({ steps.append(
{
"tool_name": tool_name, "tool_name": tool_name,
"args_template": args_template, "args_template": args_template,
"output_key": output_key, "output_key": output_key,
}) }
)
if not steps: if not steps:
logger.warning("快速添加工具链: 没有有效的步骤") logger.warning("快速添加工具链: 没有有效的步骤")
@ -3056,7 +3064,9 @@ mcp_bing_*""",
if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict): if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict):
self.config["tool_chains"] = {} self.config["tool_chains"] = {}
self.config["tool_chains"]["chains_list"] = chains_json self.config["tool_chains"]["chains_list"] = chains_json
logger.info("检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list请在 WebUI 保存一次以固化)") logger.info(
"检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list请在 WebUI 保存一次以固化)"
)
chains_config = self.config.get("tool_chains", {}) chains_config = self.config.get("tool_chains", {})
if not isinstance(chains_config, dict): if not isinstance(chains_config, dict):
@ -3153,10 +3163,7 @@ mcp_bing_*""",
# 应用过滤器 # 应用过滤器
if filter_patterns: if filter_patterns:
matched = any( matched = any(fnmatch.fnmatch(tool_name, p) or tool_name == p for p in filter_patterns)
fnmatch.fnmatch(tool_name, p) or tool_name == p
for p in filter_patterns
)
if filter_mode == "whitelist": if filter_mode == "whitelist":
# 白名单模式:只注册匹配的 # 白名单模式:只注册匹配的
@ -3179,6 +3186,7 @@ mcp_bing_*""",
return result.content or "(无返回内容)" return result.content or "(无返回内容)"
else: else:
return f"工具调用失败: {result.error}" return f"工具调用失败: {result.error}"
return execute_func return execute_func
execute_func = make_execute_func(tool_key) execute_func = make_execute_func(tool_key)
@ -3207,7 +3215,9 @@ mcp_bing_*""",
return registered_count return registered_count
def _update_react_status_display(self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]) -> None: def _update_react_status_display(
self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]
) -> None:
"""更新 ReAct 工具状态显示""" """更新 ReAct 工具状态显示"""
if not registered_tools: if not registered_tools:
status_text = "(未注册任何工具)" status_text = "(未注册任何工具)"
@ -3243,12 +3253,14 @@ mcp_bing_*""",
description = param_info.get("description", f"参数 {param_name}") description = param_info.get("description", f"参数 {param_name}")
is_required = param_name in required is_required = param_name in required
parameters.append({ parameters.append(
{
"name": param_name, "name": param_name,
"type": param_type, "type": param_type,
"description": description, "description": description,
"required": is_required, "required": is_required,
}) }
)
return parameters return parameters
@ -3276,7 +3288,7 @@ mcp_bing_*""",
for i, step in enumerate(chain.steps): for i, step in enumerate(chain.steps):
opt = " (可选)" if step.optional else "" opt = " (可选)" if step.optional else ""
out = f"{step.output_key}" if step.output_key else "" out = f"{step.output_key}" if step.output_key else ""
lines.append(f" {i+1}. {step.tool_name}{out}{opt}") lines.append(f" {i + 1}. {step.tool_name}{out}{opt}")
lines.append("") lines.append("")
status_text = "\n".join(lines) status_text = "\n".join(lines)
@ -3295,6 +3307,7 @@ mcp_bing_*""",
async def _async_connect_servers(self) -> None: async def _async_connect_servers(self) -> None:
"""异步连接所有配置的 MCP 服务器v1.5.0: 并行连接优化)""" """异步连接所有配置的 MCP 服务器v1.5.0: 并行连接优化)"""
import asyncio import asyncio
settings = self.config.get("settings", {}) settings = self.config.get("settings", {})
servers_config = self._load_mcp_servers_config() servers_config = self._load_mcp_servers_config()
@ -3380,10 +3393,7 @@ mcp_bing_*""",
# 并行执行所有连接 # 并行执行所有连接
start_time = time.time() start_time = time.time()
results = await asyncio.gather( results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
*[connect_single_server(cfg) for cfg in enabled_configs],
return_exceptions=True
)
connect_duration = time.time() - start_time connect_duration = time.time() - start_time
# 统计连接结果 # 统计连接结果
@ -3404,15 +3414,14 @@ mcp_bing_*""",
# 注册所有工具 # 注册所有工具
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
registered_count = 0 registered_count = 0
for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
tool_name = tool_key.replace("-", "_").replace(".", "_") tool_name = tool_key.replace("-", "_").replace(".", "_")
is_disabled = tool_name in disabled_tools is_disabled = tool_name in disabled_tools
info, tool_class = mcp_tool_registry.register_tool( info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
tool_key, tool_info, tool_prefix, disabled=is_disabled
)
info.plugin_name = self.plugin_name info.plugin_name = self.plugin_name
if component_registry.register_component(info, tool_class): if component_registry.register_component(info, tool_class):
@ -3433,7 +3442,9 @@ mcp_bing_*""",
react_count = self._register_tools_to_react() react_count = self._register_tools_to_react()
self._initialized = True self._initialized = True
logger.info(f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具") logger.info(
f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具"
)
# 更新状态显示 # 更新状态显示
self._update_status_display() self._update_status_display()
@ -3508,7 +3519,9 @@ mcp_bing_*""",
logger.info("检测到旧版 servers.list已自动迁移为 Claude mcpServers请在 WebUI 保存一次以固化)") logger.info("检测到旧版 servers.list已自动迁移为 Claude mcpServers请在 WebUI 保存一次以固化)")
if not claude_json.strip(): if not claude_json.strip():
self._last_servers_config_error = "未配置任何 MCP 服务器(请在 WebUI 的「MCP ServersClaude」粘贴 mcpServers JSON" self._last_servers_config_error = (
"未配置任何 MCP 服务器(请在 WebUI 的「MCP ServersClaude」粘贴 mcpServers JSON"
)
return [] return []
try: try:

View File

@ -22,15 +22,18 @@ from typing import Any, Dict, List, Optional, Tuple
try: try:
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("mcp_tool_chain") logger = get_logger("mcp_tool_chain")
except ImportError: except ImportError:
import logging import logging
logger = logging.getLogger("mcp_tool_chain") logger = logging.getLogger("mcp_tool_chain")
@dataclass @dataclass
class ToolChainStep: class ToolChainStep:
"""工具链步骤""" """工具链步骤"""
tool_name: str # 要调用的工具名(如 mcp_server_tool tool_name: str # 要调用的工具名(如 mcp_server_tool
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换 args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
output_key: str = "" # 输出存储的键名,供后续步骤引用 output_key: str = "" # 输出存储的键名,供后续步骤引用
@ -60,6 +63,7 @@ class ToolChainStep:
@dataclass @dataclass
class ToolChainDefinition: class ToolChainDefinition:
"""工具链定义""" """工具链定义"""
name: str # 工具链名称(将作为组合工具的名称) name: str # 工具链名称(将作为组合工具的名称)
description: str # 工具链描述(供 LLM 理解) description: str # 工具链描述(供 LLM 理解)
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤 steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
@ -90,6 +94,7 @@ class ToolChainDefinition:
@dataclass @dataclass
class ChainExecutionResult: class ChainExecutionResult:
"""工具链执行结果""" """工具链执行结果"""
success: bool success: bool
final_output: str # 最终输出(最后一个步骤的结果) final_output: str # 最终输出(最后一个步骤的结果)
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果 step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
@ -103,7 +108,7 @@ class ChainExecutionResult:
status = "" if step.get("success") else "" status = "" if step.get("success") else ""
tool = step.get("tool_name", "unknown") tool = step.get("tool_name", "unknown")
duration = step.get("duration_ms", 0) duration = step.get("duration_ms", 0)
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)") lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
if not step.get("success") and step.get("error"): if not step.get("success") and step.get("error"):
lines.append(f" 错误: {step['error'][:50]}") lines.append(f" 错误: {step['error'][:50]}")
return "\n".join(lines) return "\n".join(lines)
@ -113,7 +118,7 @@ class ToolChainExecutor:
"""工具链执行器""" """工具链执行器"""
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev} # 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}') VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
def __init__(self, mcp_manager): def __init__(self, mcp_manager):
self._mcp_manager = mcp_manager self._mcp_manager = mcp_manager
@ -201,7 +206,7 @@ class ToolChainExecutor:
tool_key = self._resolve_tool_key(step.tool_name) tool_key = self._resolve_tool_key(step.tool_name)
if not tool_key: if not tool_key:
step_result["error"] = f"工具 {step.tool_name} 不存在" step_result["error"] = f"工具 {step.tool_name} 不存在"
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在") logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
if not step.optional: if not step.optional:
step_results.append(step_result) step_results.append(step_result)
@ -209,13 +214,13 @@ class ToolChainExecutor:
success=False, success=False,
final_output="", final_output="",
step_results=step_results, step_results=step_results,
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在", error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
total_duration_ms=(time.time() - start_time) * 1000, total_duration_ms=(time.time() - start_time) * 1000,
) )
step_results.append(step_result) step_results.append(step_result)
continue continue
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}") logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
# 调用工具 # 调用工具
result = await self._mcp_manager.call_tool(tool_key, resolved_args) result = await self._mcp_manager.call_tool(tool_key, resolved_args)
@ -236,10 +241,10 @@ class ToolChainExecutor:
final_output = content final_output = content
content_preview = content[:100] if content else "(空)" content_preview = content[:100] if content else "(空)"
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...") logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
else: else:
step_result["error"] = result.error or "未知错误" step_result["error"] = result.error or "未知错误"
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}") logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
if not step.optional: if not step.optional:
step_results.append(step_result) step_results.append(step_result)
@ -247,7 +252,7 @@ class ToolChainExecutor:
success=False, success=False,
final_output="", final_output="",
step_results=step_results, step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}", error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
total_duration_ms=(time.time() - start_time) * 1000, total_duration_ms=(time.time() - start_time) * 1000,
) )
@ -255,7 +260,7 @@ class ToolChainExecutor:
step_duration = (time.time() - step_start) * 1000 step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration step_result["duration_ms"] = step_duration
step_result["error"] = str(e) step_result["error"] = str(e)
logger.error(f"工具链步骤 {i+1} 异常: {e}") logger.error(f"工具链步骤 {i + 1} 异常: {e}")
if not step.optional: if not step.optional:
step_results.append(step_result) step_results.append(step_result)
@ -263,7 +268,7 @@ class ToolChainExecutor:
success=False, success=False,
final_output="", final_output="",
step_results=step_results, step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}", error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
total_duration_ms=(time.time() - start_time) * 1000, total_duration_ms=(time.time() - start_time) * 1000,
) )
@ -295,10 +300,7 @@ class ToolChainExecutor:
elif isinstance(value, dict): elif isinstance(value, dict):
resolved[key] = self._resolve_args(value, context) resolved[key] = self._resolve_args(value, context)
elif isinstance(value, list): elif isinstance(value, list):
resolved[key] = [ resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
self._substitute_vars(v, context) if isinstance(v, str) else v
for v in value
]
else: else:
resolved[key] = value resolved[key] = value
@ -306,6 +308,7 @@ class ToolChainExecutor:
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str: def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
"""替换字符串中的变量""" """替换字符串中的变量"""
def replacer(match): def replacer(match):
var_path = match.group(1) var_path = match.group(1)
return self._get_var_value(var_path, context) return self._get_var_value(var_path, context)
@ -552,7 +555,7 @@ class ToolChainManager:
try: try:
chain = ToolChainDefinition.from_dict(item) chain = ToolChainDefinition.from_dict(item)
if not chain.name: if not chain.name:
errors.append(f"{i+1} 个工具链缺少名称") errors.append(f"{i + 1} 个工具链缺少名称")
continue continue
if not chain.steps: if not chain.steps:
errors.append(f"工具链 {chain.name} 没有步骤") errors.append(f"工具链 {chain.name} 没有步骤")
@ -561,7 +564,7 @@ class ToolChainManager:
self.add_chain(chain) self.add_chain(chain)
loaded += 1 loaded += 1
except Exception as e: except Exception as e:
errors.append(f"{i+1} 个工具链解析失败: {e}") errors.append(f"{i + 1} 个工具链解析失败: {e}")
return loaded, errors return loaded, errors

View File

@ -238,7 +238,7 @@ class TestCommand(BaseCommand):
chat_stream=self.message.chat_stream, chat_stream=self.message.chat_stream,
reply_reason=reply_reason, reply_reason=reply_reason,
enable_chinese_typo=False, enable_chinese_typo=False,
extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"", extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
) )
if result_status: if result_status:
# 发送生成的回复 # 发送生成的回复

View File

@ -46,6 +46,7 @@ def patch_attrdoc_post_init():
config_base_module.logger = logging.getLogger("config_base_test_logger") config_base_module.logger = logging.getLogger("config_base_test_logger")
class SimpleClass(ConfigBase): class SimpleClass(ConfigBase):
a: int = 1 a: int = 1
b: str = "test" b: str = "test"
@ -282,7 +283,7 @@ class TestConfigBase:
True, True,
"ConfigBase is not Hashable", "ConfigBase is not Hashable",
id="listset-validation-set-configbase-element_reject", id="listset-validation-set-configbase-element_reject",
) ),
], ],
) )
def test_validate_list_set_type(self, annotation, expect_error, error_fragment): def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
@ -340,7 +341,7 @@ class TestConfigBase:
False, False,
None, None,
id="dict-validation-happy-configbase-value", id="dict-validation-happy-configbase-value",
) ),
], ],
) )
def test_validate_dict_type(self, annotation, expect_error, error_fragment): def test_validate_dict_type(self, annotation, expect_error, error_fragment):
@ -353,13 +354,11 @@ class TestConfigBase:
field_name = "mapping" field_name = "mapping"
if expect_error: if expect_error:
# Act / Assert # Act / Assert
with pytest.raises(TypeError) as exc_info: with pytest.raises(TypeError) as exc_info:
dummy._validate_dict_type(annotation, field_name) dummy._validate_dict_type(annotation, field_name)
assert error_fragment in str(exc_info.value) assert error_fragment in str(exc_info.value)
else: else:
# Act # Act
dummy._validate_dict_type(annotation, field_name) dummy._validate_dict_type(annotation, field_name)

View File

@ -4,7 +4,6 @@ import importlib
import pytest import pytest
from pathlib import Path from pathlib import Path
import importlib.util import importlib.util
import asyncio
class DummyLogger: class DummyLogger:
@ -71,6 +70,7 @@ class DummyLLMRequest:
async def generate_response_for_image(self, prompt, image_base64, image_format, temp): async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
return ("dummy description", {}) return ("dummy description", {})
class DummySelect: class DummySelect:
def __init__(self, *a, **k): def __init__(self, *a, **k):
pass pass
@ -81,6 +81,7 @@ class DummySelect:
def limit(self, n): def limit(self, n):
return self return self
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def patch_external_dependencies(monkeypatch): def patch_external_dependencies(monkeypatch):
# Provide dummy implementations as modules so that importing image_manager is safe # Provide dummy implementations as modules so that importing image_manager is safe
@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None):
if tmp_path is not None: if tmp_path is not None:
tmpdir = Path(tmp_path) tmpdir = Path(tmp_path)
tmpdir.mkdir(parents=True, exist_ok=True) tmpdir.mkdir(parents=True, exist_ok=True)
setattr(mod, "IMAGE_DIR", tmpdir) mod.IMAGE_DIR = tmpdir
except Exception: except Exception:
pass pass
return mod return mod
@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
# cleanup should run without error # cleanup should run without error
mgr.cleanup_invalid_descriptions_in_db() mgr.cleanup_invalid_descriptions_in_db()

View File

@ -1,5 +1,3 @@
import pytest
from src.config.official_configs import ChatConfig from src.config.official_configs import ChatConfig
from src.config.config import Config from src.config.config import Config
from src.webui.config_schema import ConfigSchemaGenerator from src.webui.config_schema import ConfigSchemaGenerator

View File

@ -1,6 +1,5 @@
"""Expression routes pytest tests""" """Expression routes pytest tests"""
from datetime import datetime
from typing import Generator from typing import Generator
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -12,7 +11,6 @@ from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, select from sqlmodel import Session, SQLModel, create_engine, select
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
def create_test_app() -> FastAPI: def create_test_app() -> FastAPI:

View File

@ -115,7 +115,7 @@ def analyze_single_file(file_path: str) -> Dict:
stats["date_range"] = { stats["date_range"] = {
"start": min_date.isoformat(), "start": min_date.isoformat(),
"end": max_date.isoformat(), "end": max_date.isoformat(),
"duration_days": (max_date - min_date).days + 1 "duration_days": (max_date - min_date).days + 1,
} }
# 检查字段存在性 # 检查字段存在性
@ -151,8 +151,8 @@ def print_file_stats(stats: Dict, index: int = None):
print(f" 文件中的 total_count: {stats['total_count']}") print(f" 文件中的 total_count: {stats['total_count']}")
print(f" 实际记录数: {stats['actual_count']}") print(f" 实际记录数: {stats['actual_count']}")
if stats['total_count'] != stats['actual_count']: if stats["total_count"] != stats["actual_count"]:
diff = stats['total_count'] - stats['actual_count'] diff = stats["total_count"] - stats["actual_count"]
print(f" ⚠️ 数量不一致,差值: {diff:+d}") print(f" ⚠️ 数量不一致,差值: {diff:+d}")
print("\n【评估结果统计】") print("\n【评估结果统计】")
@ -161,21 +161,21 @@ def print_file_stats(stats: Dict, index: int = None):
print("\n【唯一性统计】") print("\n【唯一性统计】")
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}") print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}")
if stats['actual_count'] > 0: if stats["actual_count"] > 0:
duplicate_count = stats['actual_count'] - stats['unique_pairs'] duplicate_count = stats["actual_count"] - stats["unique_pairs"]
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)") print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
print("\n【评估者统计】") print("\n【评估者统计】")
if stats['evaluators']: if stats["evaluators"]:
for evaluator, count in stats['evaluators'].most_common(): for evaluator, count in stats["evaluators"].most_common():
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" {evaluator}: {count} 条 ({rate:.2f}%)") print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
else: else:
print(" 无评估者信息") print(" 无评估者信息")
print("\n【时间统计】") print("\n【时间统计】")
if stats['date_range']: if stats["date_range"]:
print(f" 最早评估时间: {stats['date_range']['start']}") print(f" 最早评估时间: {stats['date_range']['start']}")
print(f" 最晚评估时间: {stats['date_range']['end']}") print(f" 最晚评估时间: {stats['date_range']['end']}")
print(f" 评估时间跨度: {stats['date_range']['duration_days']}") print(f" 评估时间跨度: {stats['date_range']['duration_days']}")
@ -185,8 +185,8 @@ def print_file_stats(stats: Dict, index: int = None):
print("\n【字段统计】") print("\n【字段统计】")
print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}") print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}")
print(f" 包含 reason: {'' if stats['has_reason'] else ''}") print(f" 包含 reason: {'' if stats['has_reason'] else ''}")
if stats['has_reason']: if stats["has_reason"]:
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)") print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
@ -215,15 +215,15 @@ def print_summary(all_stats: List[Dict]):
return return
# 汇总记录统计 # 汇总记录统计
total_records = sum(s['actual_count'] for s in valid_files) total_records = sum(s["actual_count"] for s in valid_files)
total_suitable = sum(s['suitable_count'] for s in valid_files) total_suitable = sum(s["suitable_count"] for s in valid_files)
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files) total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
total_unique_pairs = set() total_unique_pairs = set()
# 收集所有唯一的(situation, style)对 # 收集所有唯一的(situation, style)对
for stats in valid_files: for stats in valid_files:
try: try:
with open(stats['file_path'], "r", encoding="utf-8") as f: with open(stats["file_path"], "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
results = data.get("manual_results", []) results = data.get("manual_results", [])
for r in results: for r in results:
@ -234,8 +234,16 @@ def print_summary(all_stats: List[Dict]):
print("\n【记录汇总】") print("\n【记录汇总】")
print(f" 总记录数: {total_records:,}") print(f" 总记录数: {total_records:,}")
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条") print(
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条") f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
if total_records > 0
else " 通过: 0 条"
)
print(
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
if total_records > 0
else " 不通过: 0 条"
)
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}") print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}")
if total_records > 0: if total_records > 0:
@ -246,7 +254,7 @@ def print_summary(all_stats: List[Dict]):
# 汇总评估者统计 # 汇总评估者统计
all_evaluators = Counter() all_evaluators = Counter()
for stats in valid_files: for stats in valid_files:
all_evaluators.update(stats['evaluators']) all_evaluators.update(stats["evaluators"])
print("\n【评估者汇总】") print("\n【评估者汇总】")
if all_evaluators: if all_evaluators:
@ -259,7 +267,7 @@ def print_summary(all_stats: List[Dict]):
# 汇总时间范围 # 汇总时间范围
all_dates = [] all_dates = []
for stats in valid_files: for stats in valid_files:
all_dates.extend(stats['evaluation_dates']) all_dates.extend(stats["evaluation_dates"])
if all_dates: if all_dates:
min_date = min(all_dates) min_date = min(all_dates)
@ -270,7 +278,7 @@ def print_summary(all_stats: List[Dict]):
print(f" 总时间跨度: {(max_date - min_date).days + 1}") print(f" 总时间跨度: {(max_date - min_date).days + 1}")
# 文件大小汇总 # 文件大小汇总
total_size = sum(s['file_size'] for s in valid_files) total_size = sum(s["file_size"] for s in valid_files)
avg_size = total_size / len(valid_files) if valid_files else 0 avg_size = total_size / len(valid_files) if valid_files else 0
print("\n【文件大小汇总】") print("\n【文件大小汇总】")
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)") print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
@ -318,5 +326,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -171,7 +171,9 @@ def main():
sys.exit(1) sys.exit(1)
if not args.raw_index: if not args.raw_index:
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3") logger.info(
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
)
sys.exit(1) sys.exit(1)
# 解析索引列表1-based # 解析索引列表1-based

View File

@ -71,7 +71,7 @@ def save_results(evaluation_results: List[Dict]):
data = { data = {
"last_updated": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(),
"total_count": len(evaluation_results), "total_count": len(evaluation_results),
"evaluation_results": evaluation_results "evaluation_results": evaluation_results,
} }
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f: with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
@ -84,9 +84,7 @@ def save_results(evaluation_results: List[Dict]):
print(f"\n✗ 保存评估结果失败: {e}") print(f"\n✗ 保存评估结果失败: {e}")
def select_expressions_for_evaluation( def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
evaluated_pairs: Set[Tuple[str, str]] = None
) -> List[Expression]:
""" """
选择用于评估的表达方式 选择用于评估的表达方式
选择所有count>1的项目然后选择两倍数量的count=1的项目 选择所有count>1的项目然后选择两倍数量的count=1的项目
@ -109,10 +107,7 @@ def select_expressions_for_evaluation(
return [] return []
# 过滤出未评估的项目 # 过滤出未评估的项目
unevaluated = [ unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
if not unevaluated: if not unevaluated:
logger.warning("所有项目都已评估完成") logger.warning("所有项目都已评估完成")
@ -132,7 +127,9 @@ def select_expressions_for_evaluation(
count_eq1_needed = count_gt1_count * 2 count_eq1_needed = count_gt1_count * 2
if len(count_eq1) < count_eq1_needed: if len(count_eq1) < count_eq1_needed:
logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}") logger.warning(
f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}"
)
count_eq1_needed = len(count_eq1) count_eq1_needed = len(count_eq1)
# 随机选择count=1的项目 # 随机选择count=1的项目
@ -141,13 +138,16 @@ def select_expressions_for_evaluation(
selected = selected_count_gt1 + selected_count_eq1 selected = selected_count_gt1 + selected_count_eq1
random.shuffle(selected) # 打乱顺序 random.shuffle(selected) # 打乱顺序
logger.info(f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍") logger.info(
f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍"
)
return selected return selected
except Exception as e: except Exception as e:
logger.error(f"选择表达方式失败: {e}") logger.error(f"选择表达方式失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return [] return []
@ -202,9 +202,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
logger.debug(f"正在评估表达方式: situation={situation}, style={style}") logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async( response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt, prompt=prompt, temperature=0.6, max_tokens=1024
temperature=0.6,
max_tokens=1024
) )
logger.debug(f"LLM响应: {response}") logger.debug(f"LLM响应: {response}")
@ -241,7 +239,9 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
Returns: Returns:
评估结果字典 评估结果字典
""" """
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}") logger.info(
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
)
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm) suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
@ -258,7 +258,7 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
"reason": reason, "reason": reason,
"error": error, "error": error,
"evaluator": "llm", "evaluator": "llm",
"evaluated_at": datetime.now().isoformat() "evaluated_at": datetime.now().isoformat(),
} }
@ -302,7 +302,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
print(f"Count = {count}:") print(f"Count = {count}:")
print(f" 总数: {total}") print(f" 总数: {total}")
print(f" 通过: {suitable} ({pass_rate:.2f}%)") print(f" 通过: {suitable} ({pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)") print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
print() print()
# 比较count=1和count>1 # 比较count=1和count>1
@ -340,12 +340,12 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
print("Count = 1:") print("Count = 1:")
print(f" 总数: {eq1_total}") print(f" 总数: {eq1_total}")
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)") print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)") print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
print() print()
print("Count > 1:") print("Count > 1:")
print(f" 总数: {gt1_total}") print(f" 总数: {gt1_total}")
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)") print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)") print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
print() print()
# 进行卡方检验简化版使用2x2列联表 # 进行卡方检验简化版使用2x2列联表
@ -427,7 +427,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
"count_groups": {str(k): v for k, v in count_groups.items()}, "count_groups": {str(k): v for k, v in count_groups.items()},
"count_eq1": count_eq1_group, "count_eq1": count_eq1_group,
"count_gt1": count_gt1_group, "count_gt1": count_gt1_group,
"total_evaluated": len(evaluation_results) "total_evaluated": len(evaluation_results),
} }
try: try:
@ -466,8 +466,7 @@ async def main():
try: try:
all_expressions = list(Expression.select()) all_expressions = list(Expression.select())
unevaluated_count_gt1 = [ unevaluated_count_gt1 = [
expr for expr in all_expressions expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
] ]
has_unevaluated = len(unevaluated_count_gt1) > 0 has_unevaluated = len(unevaluated_count_gt1) > 0
except Exception as e: except Exception as e:
@ -485,21 +484,20 @@ async def main():
try: try:
llm = LLMRequest( llm = LLMRequest(
model_set=model_config.model_task_config.tool_use, model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_count_analysis_llm" request_type="expression_evaluator_count_analysis_llm",
) )
print("✓ LLM实例创建成功\n") print("✓ LLM实例创建成功\n")
except Exception as e: except Exception as e:
logger.error(f"创建LLM实例失败: {e}") logger.error(f"创建LLM实例失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
print(f"\n✗ 创建LLM实例失败: {e}") print(f"\n✗ 创建LLM实例失败: {e}")
db.close() db.close()
return return
# 选择需要评估的表达方式选择所有count>1的项目然后选择两倍数量的count=1的项目 # 选择需要评估的表达方式选择所有count>1的项目然后选择两倍数量的count=1的项目
expressions = select_expressions_for_evaluation( expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
evaluated_pairs=evaluated_pairs
)
if not expressions: if not expressions:
print("\n没有可评估的项目") print("\n没有可评估的项目")
@ -518,7 +516,7 @@ async def main():
llm_result = await llm_evaluate_expression(expression, llm) llm_result = await llm_evaluate_expression(expression, llm)
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}") print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
if llm_result.get('error'): if llm_result.get("error"):
print(f" 错误: {llm_result['error']}") print(f" 错误: {llm_result['error']}")
print() print()
@ -553,4 +551,3 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -140,9 +140,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
logger.debug(f"正在评估表达方式: situation={situation}, style={style}") logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async( response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt, prompt=prompt, temperature=0.6, max_tokens=1024
temperature=0.6,
max_tokens=1024
) )
logger.debug(f"LLM响应: {response}") logger.debug(f"LLM响应: {response}")
@ -152,6 +150,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
evaluation = json.loads(response) evaluation = json.loads(response)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
import re import re
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL) json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
if json_match: if json_match:
evaluation = json.loads(json_match.group()) evaluation = json.loads(json_match.group())
@ -196,7 +195,7 @@ async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -
"suitable": suitable, "suitable": suitable,
"reason": reason, "reason": reason,
"error": error, "error": error,
"evaluator": "llm" "evaluator": "llm",
} }
@ -244,10 +243,16 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
false_negatives += 1 false_negatives += 1
accuracy = (matched / total * 100) if total > 0 else 0 accuracy = (matched / total * 100) if total > 0 else 0
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0 precision = (
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0 (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
)
recall = (
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
)
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0 f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0 specificity = (
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
)
# 计算人工效标的不合适率 # 计算人工效标的不合适率
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数 manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
@ -283,7 +288,7 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
"specificity": specificity, "specificity": specificity,
"manual_unsuitable_rate": manual_unsuitable_rate, "manual_unsuitable_rate": manual_unsuitable_rate,
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate, "llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
"rate_difference": rate_difference "rate_difference": rate_difference,
} }
@ -334,13 +339,11 @@ async def main(count: int | None = None):
# 2. 创建LLM实例并评估 # 2. 创建LLM实例并评估
print("\n步骤2: 创建LLM实例") print("\n步骤2: 创建LLM实例")
try: try:
llm = LLMRequest( llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_llm"
)
except Exception as e: except Exception as e:
logger.error(f"创建LLM实例失败: {e}") logger.error(f"创建LLM实例失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return return
@ -348,11 +351,7 @@ async def main(count: int | None = None):
llm_results = [] llm_results = []
for i, manual_result in enumerate(valid_manual_results, 1): for i, manual_result in enumerate(valid_manual_results, 1):
print(f"LLM评估进度: {i}/{len(valid_manual_results)}") print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
llm_results.append(await evaluate_expression_llm( llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
manual_result["situation"],
manual_result["style"],
llm
))
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
# 5. 输出FP和FN项目在评估结果之前 # 5. 输出FP和FN项目在评估结果之前
@ -372,14 +371,16 @@ async def main(count: int | None = None):
# 人工评估不通过但LLM评估通过FP情况 # 人工评估不通过但LLM评估通过FP情况
if not manual_result["suitable"] and llm_result["suitable"]: if not manual_result["suitable"] and llm_result["suitable"]:
fp_items.append({ fp_items.append(
{
"situation": manual_result["situation"], "situation": manual_result["situation"],
"style": manual_result["style"], "style": manual_result["style"],
"manual_suitable": manual_result["suitable"], "manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"], "llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"), "llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error") "llm_error": llm_result.get("error"),
}) }
)
if fp_items: if fp_items:
print(f"\n共找到 {len(fp_items)} 条误判项目:\n") print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
@ -389,7 +390,7 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}") print(f"Style: {item['style']}")
print("人工评估: 不通过 ❌") print("人工评估: 不通过 ❌")
print("LLM评估: 通过 ✅ (误判)") print("LLM评估: 通过 ✅ (误判)")
if item.get('llm_error'): if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}") print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}") print(f"LLM理由: {item['llm_reason']}")
print() print()
@ -410,14 +411,16 @@ async def main(count: int | None = None):
# 人工评估通过但LLM评估不通过FN情况 # 人工评估通过但LLM评估不通过FN情况
if manual_result["suitable"] and not llm_result["suitable"]: if manual_result["suitable"] and not llm_result["suitable"]:
fn_items.append({ fn_items.append(
{
"situation": manual_result["situation"], "situation": manual_result["situation"],
"style": manual_result["style"], "style": manual_result["style"],
"manual_suitable": manual_result["suitable"], "manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"], "llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"), "llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error") "llm_error": llm_result.get("error"),
}) }
)
if fn_items: if fn_items:
print(f"\n共找到 {len(fn_items)} 条误删项目:\n") print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
@ -427,7 +430,7 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}") print(f"Style: {item['style']}")
print("人工评估: 通过 ✅") print("人工评估: 通过 ✅")
print("LLM评估: 不通过 ❌ (误删)") print("LLM评估: 不通过 ❌ (误删)")
if item.get('llm_error'): if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}") print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}") print(f"LLM理由: {item['llm_reason']}")
print() print()
@ -447,13 +450,21 @@ async def main(count: int | None = None):
print() print()
# print(" 【核心能力指标】") # print(" 【核心能力指标】")
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)") print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})") print(
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}") f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)") # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
print() print()
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)") print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})") print(
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}") f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
)
print(
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)") # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
print() print()
print(" 【其他指标】") print(" 【其他指标】")
@ -464,12 +475,18 @@ async def main(count: int | None = None):
print() print()
print(" 【不合适率分析】") print(" 【不合适率分析】")
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%") print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}") print(
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
)
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适") print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
print() print()
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%") print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})") print(
print(f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%") f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
)
print() print()
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%") # print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%") # print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
@ -486,11 +503,12 @@ async def main(count: int | None = None):
try: try:
os.makedirs(os.path.dirname(output_file), exist_ok=True) os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f: with open(output_file, "w", encoding="utf-8") as f:
json.dump({ json.dump(
"manual_results": valid_manual_results, {"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
"llm_results": llm_results, f,
"comparison": comparison ensure_ascii=False,
}, f, ensure_ascii=False, indent=2) indent=2,
)
logger.info(f"\n评估结果已保存到: {output_file}") logger.info(f"\n评估结果已保存到: {output_file}")
except Exception as e: except Exception as e:
logger.warning(f"保存结果到文件失败: {e}") logger.warning(f"保存结果到文件失败: {e}")
@ -509,15 +527,9 @@ if __name__ == "__main__":
python evaluate_expressions_llm_v6.py # 使用全部数据 python evaluate_expressions_llm_v6.py # 使用全部数据
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据 python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据 python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
""" """,
)
parser.add_argument(
"-n", "--count",
type=int,
default=None,
help="随机选取的数据条数(默认:使用全部数据)"
) )
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
args = parser.parse_args() args = parser.parse_args()
asyncio.run(main(count=args.count)) asyncio.run(main(count=args.count))

View File

@ -65,7 +65,7 @@ def save_results(manual_results: List[Dict]):
data = { data = {
"last_updated": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(),
"total_count": len(manual_results), "total_count": len(manual_results),
"manual_results": manual_results "manual_results": manual_results,
} }
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f: with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
@ -98,10 +98,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
return [] return []
# 过滤出未评估的项目:匹配 situation 和 style 均一致 # 过滤出未评估的项目:匹配 situation 和 style 均一致
unevaluated = [ unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
if not unevaluated: if not unevaluated:
logger.info("所有项目都已评估完成") logger.info("所有项目都已评估完成")
@ -120,6 +117,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
except Exception as e: except Exception as e:
logger.error(f"获取未评估表达方式失败: {e}") logger.error(f"获取未评估表达方式失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return [] return []
@ -150,18 +148,18 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
while True: while True:
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower() user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
if user_input in ['q', 'quit']: if user_input in ["q", "quit"]:
print("退出评估") print("退出评估")
return None return None
if user_input in ['s', 'skip']: if user_input in ["s", "skip"]:
print("跳过当前项目") print("跳过当前项目")
return "skip" return "skip"
if user_input in ['y', 'yes', '1', '', '通过']: if user_input in ["y", "yes", "1", "", "通过"]:
suitable = True suitable = True
break break
elif user_input in ['n', 'no', '0', '', '不通过']: elif user_input in ["n", "no", "0", "", "不通过"]:
suitable = False suitable = False
break break
else: else:
@ -173,7 +171,7 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
"suitable": suitable, "suitable": suitable,
"reason": None, "reason": None,
"evaluator": "manual", "evaluator": "manual",
"evaluated_at": datetime.now().isoformat() "evaluated_at": datetime.now().isoformat(),
} }
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}") print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
@ -257,9 +255,9 @@ def main():
# 询问是否继续 # 询问是否继续
while True: while True:
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower() continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
if continue_input in ['y', 'yes', '1', '', '继续']: if continue_input in ["y", "yes", "1", "", "继续"]:
break break
elif continue_input in ['n', 'no', '0', '', '退出']: elif continue_input in ["n", "no", "0", "", "退出"]:
print("\n评估结束") print("\n评估结束")
return return
else: else:
@ -275,4 +273,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -134,9 +134,7 @@ def handle_import_openie(
# 在非交互模式下,不再询问用户,而是直接报错终止 # 在非交互模式下,不再询问用户,而是直接报错终止
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
if non_interactive: if non_interactive:
logger.error( logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
)
sys.exit(1) sys.exit(1)
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
user_choice = input().strip().lower() user_choice = input().strip().lower()
@ -189,9 +187,7 @@ def handle_import_openie(
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
# 新增确认提示 # 新增确认提示
if non_interactive: if non_interactive:
logger.warning( logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
)
else: else:
print("=== 重要操作确认 ===") print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型") print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
def main(argv: Optional[list[str]] = None) -> None: def main(argv: Optional[list[str]] = None) -> None:
"""主函数 - 解析参数并运行异步主流程。""" """主函数 - 解析参数并运行异步主流程。"""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=( description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
"将其导入到 LPMM 的向量库与知识图中。"
)
) )
parser.add_argument( parser.add_argument(
"--non-interactive", "--non-interactive",

View File

@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
ensure_dirs() # 确保目录存在 ensure_dirs() # 确保目录存在
# 新增用户确认提示 # 新增用户确认提示
if non_interactive: if non_interactive:
logger.warning( logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
)
else: else:
print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。") print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

View File

@ -68,4 +68,3 @@ def main() -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -29,6 +29,7 @@ except ImportError as e:
logger = get_logger("lpmm_interactive_manager") logger = get_logger("lpmm_interactive_manager")
async def interactive_add(): async def interactive_add():
"""交互式导入知识""" """交互式导入知识"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -68,6 +69,7 @@ async def interactive_add():
print(f"\n[×] 发生异常: {e}") print(f"\n[×] 发生异常: {e}")
logger.error(f"add_content 异常: {e}", exc_info=True) logger.error(f"add_content 异常: {e}", exc_info=True)
async def interactive_delete(): async def interactive_delete():
"""交互式删除知识""" """交互式删除知识"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -108,8 +110,12 @@ async def interactive_delete():
return return
print("-" * 40) print("-" * 40)
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower() confirm = (
if confirm != 'y': input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
.strip()
.lower()
)
if confirm != "y":
print("\n[!] 已取消删除操作。") print("\n[!] 已取消删除操作。")
return return
@ -129,6 +135,7 @@ async def interactive_delete():
print(f"\n[×] 发生异常: {e}") print(f"\n[×] 发生异常: {e}")
logger.error(f"delete 异常: {e}", exc_info=True) logger.error(f"delete 异常: {e}", exc_info=True)
async def interactive_clear(): async def interactive_clear():
"""交互式清空知识库""" """交互式清空知识库"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -165,16 +172,21 @@ async def interactive_clear():
before = stats.get("before", {}) before = stats.get("before", {})
after = stats.get("after", {}) after = stats.get("after", {})
print("\n[统计信息]") print("\n[统计信息]")
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, " print(
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}") f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, " f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}") )
print(
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
)
else: else:
print(f"\n[×] 失败:{result['message']}") print(f"\n[×] 失败:{result['message']}")
except Exception as e: except Exception as e:
print(f"\n[×] 发生异常: {e}") print(f"\n[×] 发生异常: {e}")
logger.error(f"clear_all 异常: {e}", exc_info=True) logger.error(f"clear_all 异常: {e}", exc_info=True)
async def interactive_search(): async def interactive_search():
"""交互式查询知识""" """交互式查询知识"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -224,6 +236,7 @@ async def interactive_search():
print(f"\n[×] 查询失败: {e}") print(f"\n[×] 查询失败: {e}")
logger.error(f"查询异常: {e}", exc_info=True) logger.error(f"查询异常: {e}", exc_info=True)
async def main(): async def main():
"""主循环""" """主循环"""
while True: while True:
@ -253,6 +266,7 @@ async def main():
else: else:
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。") print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
if __name__ == "__main__": if __name__ == "__main__":
try: try:
# 运行主循环 # 运行主循环
@ -262,4 +276,3 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"\n[!] 程序运行出错: {e}") print(f"\n[!] 程序运行出错: {e}")
logger.error(f"Main loop 异常: {e}", exc_info=True) logger.error(f"Main loop 异常: {e}", exc_info=True)

View File

@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data" raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
txt_files = list(raw_dir.glob("*.txt")) txt_files = list(raw_dir.glob("*.txt"))
if not txt_files: if not txt_files:
msg = ( msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件info_extraction 可能立即退出或无数据可处理。"
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
"info_extraction 可能立即退出或无数据可处理。"
)
print(msg) print(msg)
if non_interactive: if non_interactive:
logger.error( logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
)
return False return False
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower() cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
return cont == "y" return cont == "y"
@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
openie_dir = Path(PROJECT_ROOT) / "data" / "openie" openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
json_files = list(openie_dir.glob("*.json")) json_files = list(openie_dir.glob("*.json"))
if not json_files: if not json_files:
msg = ( msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件import_openie 可能会因为找不到批次而失败。"
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
"import_openie 可能会因为找不到批次而失败。"
)
print(msg) print(msg)
if non_interactive: if non_interactive:
logger.error( logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
)
return False return False
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower() cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
return cont == "y" return cont == "y"
@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
"""在部分操作前提醒 lpmm_knowledge.enable 状态。""" """在部分操作前提醒 lpmm_knowledge.enable 状态。"""
try: try:
if not getattr(global_config.lpmm_knowledge, "enable", False): if not getattr(global_config.lpmm_knowledge, "enable", False):
print( print("[WARN] 当前配置 lpmm_knowledge.enable = false刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
"[WARN] 当前配置 lpmm_knowledge.enable = false"
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
)
except Exception: except Exception:
# 配置异常时不阻断主流程,仅忽略提示 # 配置异常时不阻断主流程,仅忽略提示
pass pass
@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
if action == "prepare_raw": if action == "prepare_raw":
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...") logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
sha_list, raw_data = load_raw_data() sha_list, raw_data = load_raw_data()
print( print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
elif action == "info_extract": elif action == "info_extract":
if not _check_before_info_extract("--non-interactive" in extra_args): if not _check_before_info_extract("--non-interactive" in extra_args):
print("已根据用户选择,取消执行信息提取。") print("已根据用户选择,取消执行信息提取。")
@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新 # 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
logger.info("开始 full_import预处理原始语料 -> 信息抽取 -> 导入 -> 刷新") logger.info("开始 full_import预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
sha_list, raw_data = load_raw_data() sha_list, raw_data = load_raw_data()
print( print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
non_interactive = "--non-interactive" in extra_args non_interactive = "--non-interactive" in extra_args
if not _check_before_info_extract(non_interactive): if not _check_before_info_extract(non_interactive):
print("已根据用户选择,取消 full_import信息提取阶段被取消") print("已根据用户选择,取消 full_import信息提取阶段被取消")
@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
) )
# 快速选项:按推荐方式清理所有相关实体/关系 # 快速选项:按推荐方式清理所有相关实体/关系
quick_all = input( quick_all = (
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): " input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
).strip().lower() )
if quick_all in ("", "y", "yes"): if quick_all in ("", "y", "yes"):
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"]) args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
else: else:
@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
def _interactive_build_batch_inspect_args() -> List[str]: def _interactive_build_batch_inspect_args() -> List[str]:
"""为 inspect_lpmm_batch 构造 --openie-file 参数。""" """为 inspect_lpmm_batch 构造 --openie-file 参数。"""
path = _interactive_choose_openie_file( path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
)
if not path: if not path:
return [] return []
return ["--openie-file", path] return ["--openie-file", path]
@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
def _interactive_build_test_args() -> List[str]: def _interactive_build_test_args() -> List[str]:
"""为 test_lpmm_retrieval 构造自定义测试用例参数。""" """为 test_lpmm_retrieval 构造自定义测试用例参数。"""
print( print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
"\n[TEST] 你可以:\n"
"- 直接回车使用内置的默认测试用例;\n"
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
)
query = input("请输入自定义测试问题(回车则使用默认用例):").strip() query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
if not query: if not query:
return [] return []
@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}") print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}") print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
new_dim = input( new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
).strip()
if new_dim and not new_dim.isdigit(): if new_dim and not new_dim.isdigit():
print("输入的维度不是纯数字,已取消操作。") print("输入的维度不是纯数字,已取消操作。")
return return
@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -28,6 +28,7 @@ from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval") logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题 # 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval(): def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入""" """使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
@ -44,7 +45,7 @@ def _import_memory_retrieval():
# 如果 prompt 已经初始化,尝试直接使用已加载的模块 # 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules: if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name] existing_module = sys.modules[module_name]
if hasattr(existing_module, 'init_memory_retrieval_prompt'): if hasattr(existing_module, "init_memory_retrieval_prompt"):
return ( return (
existing_module.init_memory_retrieval_prompt, existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question, existing_module._react_agent_solve_question,
@ -54,14 +55,14 @@ def _import_memory_retrieval():
# 如果模块已经在 sys.modules 中但部分初始化,先移除它 # 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules: if module_name in sys.modules:
existing_module = sys.modules[module_name] existing_module = sys.modules[module_name]
if not hasattr(existing_module, 'init_memory_retrieval_prompt'): if not hasattr(existing_module, "init_memory_retrieval_prompt"):
# 模块部分初始化,移除它 # 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入") logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name] del sys.modules[module_name]
# 清理可能相关的部分初始化模块 # 清理可能相关的部分初始化模块
keys_to_remove = [] keys_to_remove = []
for key in sys.modules.keys(): for key in sys.modules.keys():
if key.startswith('src.memory_system.') and key != 'src.memory_system': if key.startswith("src.memory_system.") and key != "src.memory_system":
keys_to_remove.append(key) keys_to_remove.append(key)
for key in keys_to_remove: for key in keys_to_remove:
try: try:
@ -75,6 +76,7 @@ def _import_memory_retrieval():
# 先导入可能触发循环导入的模块,让它们完成初始化 # 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config import src.config.config
import src.chat.utils.prompt_builder import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化 # 如果它们已经导入,就确保它们完全初始化
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
@ -253,7 +255,7 @@ async def test_memory_retrieval(
包含测试结果的字典 包含测试结果的字典
""" """
print("\n" + "=" * 80) print("\n" + "=" * 80)
print(f"[测试] 记忆检索测试") print("[测试] 记忆检索测试")
print(f"[问题] {question}") print(f"[问题] {question}")
print("=" * 80) print("=" * 80)
@ -267,6 +269,7 @@ async def test_memory_retrieval(
# 检查 prompt 是否已经初始化,避免重复初始化 # 检查 prompt 是否已经初始化,避免重复初始化
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts: if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt() init_memory_retrieval_prompt()
else: else:
@ -284,7 +287,7 @@ async def test_memory_retrieval(
timeout = global_config.memory.agent_timeout_seconds timeout = global_config.memory.agent_timeout_seconds
print(f"\n[配置]") print("\n[配置]")
print(f" 最大迭代次数: {max_iterations}") print(f" 最大迭代次数: {max_iterations}")
print(f" 超时时间: {timeout}") print(f" 超时时间: {timeout}")
print(f" 聊天ID: {chat_id}") print(f" 聊天ID: {chat_id}")
@ -321,26 +324,26 @@ async def test_memory_retrieval(
# 输出结果 # 输出结果
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}") print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
print(f"\n[结果]") print("\n[结果]")
print(f" 是否找到答案: {'' if found_answer else ''}") print(f" 是否找到答案: {'' if found_answer else ''}")
if found_answer and answer: if found_answer and answer:
print(f" 答案: {answer}") print(f" 答案: {answer}")
else: else:
print(f" 答案: (未找到答案)") print(" 答案: (未找到答案)")
print(f" 是否超时: {'' if is_timeout else ''}") print(f" 是否超时: {'' if is_timeout else ''}")
print(f" 迭代次数: {len(thinking_steps)}") print(f" 迭代次数: {len(thinking_steps)}")
print(f" 总耗时: {elapsed_time:.2f}") print(f" 总耗时: {elapsed_time:.2f}")
print(f"\n[Token使用情况]") print("\n[Token使用情况]")
print(f" 总请求数: {token_usage['request_count']}") print(f" 总请求数: {token_usage['request_count']}")
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}") print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}") print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
print(f" 总Tokens: {token_usage['total_tokens']:,}") print(f" 总Tokens: {token_usage['total_tokens']:,}")
print(f" 总成本: ${token_usage['total_cost']:.6f}") print(f" 总成本: ${token_usage['total_cost']:.6f}")
if token_usage['model_usage']: if token_usage["model_usage"]:
print(f"\n[按模型统计]") print("\n[按模型统计]")
for model_name, usage in token_usage['model_usage'].items(): for model_name, usage in token_usage["model_usage"].items():
print(f" {model_name}:") print(f" {model_name}:")
print(f" 请求数: {usage['request_count']}") print(f" 请求数: {usage['request_count']}")
print(f" Prompt Tokens: {usage['prompt_tokens']:,}") print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
@ -348,7 +351,7 @@ async def test_memory_retrieval(
print(f" 总Tokens: {usage['total_tokens']:,}") print(f" 总Tokens: {usage['total_tokens']:,}")
print(f" 成本: ${usage['cost']:.6f}") print(f" 成本: ${usage['cost']:.6f}")
print(f"\n[迭代详情]") print("\n[迭代详情]")
print(format_thinking_steps(thinking_steps)) print(format_thinking_steps(thinking_steps))
print("\n" + "=" * 80) print("\n" + "=" * 80)
@ -444,4 +447,3 @@ def main() -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -455,6 +455,7 @@ class ExpressionSelector:
expr_obj.save() expr_obj.save()
logger.debug("表达方式激活: 更新last_active_time in db") logger.debug("表达方式激活: 更新last_active_time in db")
try: try:
expression_selector = ExpressionSelector() expression_selector = ExpressionSelector()
except Exception as e: except Exception as e:

View File

@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
logger = get_logger("jargon") logger = get_logger("jargon")
class JargonExplainer: class JargonExplainer:
"""黑话解释器,用于在回复前识别和解释上下文中的黑话""" """黑话解释器,用于在回复前识别和解释上下文中的黑话"""

View File

@ -2,7 +2,6 @@ import time
import asyncio import asyncio
from typing import List, Any from typing import List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression from src.chat.utils.common_utils import TempMethodsExpression
@ -119,9 +118,7 @@ class MessageRecorder:
# 触发 expression_learner 和 jargon_miner 的处理 # 触发 expression_learner 和 jargon_miner 的处理
if self.enable_expression_learning: if self.enable_expression_learning:
asyncio.create_task( asyncio.create_task(self._trigger_expression_learning(messages))
self._trigger_expression_learning(messages)
)
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
@ -130,9 +127,7 @@ class MessageRecorder:
traceback.print_exc() traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试 # 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning( async def _trigger_expression_learning(self, messages: List[Any]) -> None:
self, messages: List[Any]
) -> None:
""" """
触发 expression 学习使用指定的消息列表 触发 expression 学习使用指定的消息列表

View File

@ -1,5 +1,5 @@
import time import time
from typing import Tuple, Optional, Dict, Any # 增加了 Optional from typing import Tuple, Optional # 增加了 Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
@ -170,13 +170,10 @@ class ActionPlanner:
) )
break break
else: else:
logger.debug( logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
)
except Exception as e: except Exception as e:
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}") logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
# --- 获取超时提示信息 --- # --- 获取超时提示信息 ---
# (这部分逻辑不变) # (这部分逻辑不变)
timeout_context = "" timeout_context = ""

View File

@ -112,7 +112,7 @@ class Conversation:
"user_nickname": msg.user_info.user_nickname if msg.user_info else "", "user_nickname": msg.user_info.user_nickname if msg.user_info else "",
"user_cardname": msg.user_info.user_cardname if msg.user_info else None, "user_cardname": msg.user_info.user_cardname if msg.user_info else None,
"platform": msg.user_info.platform if msg.user_info else "", "platform": msg.user_info.platform if msg.user_info else "",
} },
} }
initial_messages_dict.append(msg_dict) initial_messages_dict.append(msg_dict)

View File

@ -5,7 +5,7 @@ from src.common.logger import get_logger
from .chat_observer import ChatObserver from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType, Notification from .chat_states import NotificationHandler, NotificationType, Notification
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo from src.common.data_models.database_data_model import DatabaseMessages
import traceback # 导入 traceback 用于调试 import traceback # 导入 traceback 用于调试
logger = get_logger("observation_info") logger = get_logger("observation_info")

View File

@ -42,9 +42,7 @@ class GoalAnalyzer:
"""对话目标分析器""" """对话目标分析器"""
def __init__(self, stream_id: str, private_name: str): def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest( self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
)
self.personality_info = self._get_personality_prompt() self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname self.name = global_config.bot.nickname

View File

@ -1,10 +1,10 @@
from typing import List, Tuple, Dict, Any from typing import List, Tuple, Dict, Any
from src.common.logger import get_logger from src.common.logger import get_logger
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
# from src.plugins.memory_system.Hippocampus import HippocampusManager # from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import model_config
from src.chat.message_receive.message import Message
from src.chat.knowledge import qa_manager from src.chat.knowledge import qa_manager
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
@ -16,9 +16,7 @@ class KnowledgeFetcher:
"""知识调取器""" """知识调取器"""
def __init__(self, private_name: str): def __init__(self, private_name: str):
self.llm = LLMRequest( self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
model_set=model_config.model_task_config.utils
)
self.private_name = private_name self.private_name = private_name
def _lpmm_get_knowledge(self, query: str) -> str: def _lpmm_get_knowledge(self, query: str) -> str:

View File

@ -14,10 +14,7 @@ class ReplyChecker:
"""回复检查器""" """回复检查器"""
def __init__(self, stream_id: str, private_name: str): def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest( self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
model_set=model_config.model_task_config.utils,
request_type="reply_check"
)
self.personality_info = self._get_personality_prompt() self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname self.name = global_config.bot.nickname
self.private_name = private_name self.private_name = private_name

View File

@ -704,10 +704,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断 # 等待指定时间,但可被新消息打断
try: try:
await asyncio.wait_for( await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
self._new_message_event.wait(),
timeout=wait_seconds
)
# 如果事件被触发,说明有新消息到达 # 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待") logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -731,7 +728,9 @@ class BrainChatting:
# 使用默认等待时间 # 使用默认等待时间
wait_seconds = 3 wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)") logger.info(
f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)"
)
# 清除事件状态,准备等待新消息 # 清除事件状态,准备等待新消息
self._new_message_event.clear() self._new_message_event.clear()
@ -749,10 +748,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断 # 等待指定时间,但可被新消息打断
try: try:
await asyncio.wait_for( await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
self._new_message_event.wait(),
timeout=wait_seconds
)
# 如果事件被触发,说明有新消息到达 # 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待") logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError: except asyncio.TimeoutError:

View File

@ -431,7 +431,9 @@ class BrainPlanner:
except Exception as req_e: except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}" extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
return extracted_reasoning, [ return (
extracted_reasoning,
[
ActionPlannerInfo( ActionPlannerInfo(
action_type="complete_talk", action_type="complete_talk",
reasoning=extracted_reasoning, reasoning=extracted_reasoning,
@ -439,7 +441,11 @@ class BrainPlanner:
action_message=None, action_message=None,
available_actions=available_actions, available_actions=available_actions,
) )
], llm_content, llm_reasoning, llm_duration_ms ],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应 # 解析LLM响应
if llm_content: if llm_content:

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import json import json
import time import time
from typing import List, Union, Dict, Any from typing import List, Union
from .global_logger import logger from .global_logger import logger
from . import prompt_template from . import prompt_template
@ -200,9 +200,7 @@ class IEProcess:
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁 # 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行 # 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
try: try:
entities, triples = await asyncio.to_thread( entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
)
if entities is not None: if entities is not None:
results.append( results.append(

View File

@ -395,8 +395,7 @@ class KGManager:
appear_cnt = self.ent_appear_cnt.get(ent_hash) appear_cnt = self.ent_appear_cnt.get(ent_hash)
if not appear_cnt or appear_cnt <= 0: if not appear_cnt or appear_cnt <= 0:
logger.debug( logger.debug(
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0" f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0将使用 1.0 作为默认出现次数参与权重计算"
f"将使用 1.0 作为默认出现次数参与权重计算"
) )
appear_cnt = 1.0 appear_cnt = 1.0
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt) ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)

View File

@ -11,6 +11,7 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
logger = get_logger("LPMM-Plugin-API") logger = get_logger("LPMM-Plugin-API")
class LPMMOperations: class LPMMOperations:
""" """
LPMM 内部操作接口 LPMM 内部操作接口
@ -20,9 +21,7 @@ class LPMMOperations:
def __init__(self): def __init__(self):
self._initialized = False self._initialized = False
async def _run_cancellable_executor( async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
self, func: Callable, *args, **kwargs
) -> Any:
""" """
在线程池中执行可取消的同步操作 在线程池中执行可取消的同步操作
当任务被取消时 Ctrl+C会立即响应并抛出 CancelledError 当任务被取消时 Ctrl+C会立即响应并抛出 CancelledError
@ -79,7 +78,7 @@ class LPMMOperations:
# 1. 分段处理 # 1. 分段处理
if auto_split: if auto_split:
# 自动按双换行符分割 # 自动按双换行符分割
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()] paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
else: else:
# 不分割,作为完整一段 # 不分割,作为完整一段
text_stripped = text.strip() text_stripped = text.strip()
@ -95,7 +94,9 @@ class LPMMOperations:
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config from src.config.config import model_config
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract") llm_ner = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
ie_process = IEProcess(llm_ner, llm_rdf) ie_process = IEProcess(llm_ner, llm_rdf)
@ -128,25 +129,21 @@ class LPMMOperations:
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入 # 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
# store_new_data_set 会自动处理嵌入生成和存储 # store_new_data_set 会自动处理嵌入生成和存储
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
await self._run_cancellable_executor( await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
embed_mgr.store_new_data_set,
new_raw_paragraphs,
new_triple_list_data
)
# 3. 构建知识图谱只需要三元组数据和embedding_manager # 3. 构建知识图谱只需要三元组数据和embedding_manager
await self._run_cancellable_executor( await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
kg_mgr.build_kg,
new_triple_list_data,
embed_mgr
)
# 4. 持久化 # 4. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file) await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file)
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"} return {
"status": "success",
"count": len(new_raw_paragraphs),
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
}
except asyncio.CancelledError: except asyncio.CancelledError:
logger.warning("[Plugin API] 导入操作被用户中断") logger.warning("[Plugin API] 导入操作被用户中断")
@ -215,8 +212,7 @@ class LPMMOperations:
# a. 从向量库删除 # a. 从向量库删除
deleted_count, _ = await self._run_cancellable_executor( deleted_count, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items, embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
to_delete_keys
) )
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys()) embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
@ -224,10 +220,7 @@ class LPMMOperations:
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial( delete_func = partial(
kg_mgr.delete_paragraphs, kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
to_delete_hashes,
ent_hashes=None,
remove_orphan_entities=True
) )
await self._run_cancellable_executor(delete_func) await self._run_cancellable_executor(delete_func)
@ -237,7 +230,11 @@ class LPMMOperations:
await self._run_cancellable_executor(kg_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file)
match_type = "完整文段" if exact_match else "关键词" match_type = "完整文段" if exact_match else "关键词"
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"} return {
"status": "success",
"deleted_count": deleted_count,
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
}
except asyncio.CancelledError: except asyncio.CancelledError:
logger.warning("[Plugin API] 删除操作被用户中断") logger.warning("[Plugin API] 删除操作被用户中断")
@ -275,16 +272,14 @@ class LPMMOperations:
# 删除所有段落向量 # 删除所有段落向量
para_deleted, _ = await self._run_cancellable_executor( para_deleted, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items, embed_mgr.paragraphs_embedding_store.delete_items, para_keys
para_keys
) )
embed_mgr.stored_pg_hashes.clear() embed_mgr.stored_pg_hashes.clear()
# 删除所有实体向量 # 删除所有实体向量
if ent_keys: if ent_keys:
ent_deleted, _ = await self._run_cancellable_executor( ent_deleted, _ = await self._run_cancellable_executor(
embed_mgr.entities_embedding_store.delete_items, embed_mgr.entities_embedding_store.delete_items, ent_keys
ent_keys
) )
else: else:
ent_deleted = 0 ent_deleted = 0
@ -292,8 +287,7 @@ class LPMMOperations:
# 删除所有关系向量 # 删除所有关系向量
if rel_keys: if rel_keys:
rel_deleted, _ = await self._run_cancellable_executor( rel_deleted, _ = await self._run_cancellable_executor(
embed_mgr.relation_embedding_store.delete_items, embed_mgr.relation_embedding_store.delete_items, rel_keys
rel_keys
) )
else: else:
rel_deleted = 0 rel_deleted = 0
@ -341,15 +335,13 @@ class LPMMOperations:
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial( delete_func = partial(
kg_mgr.delete_paragraphs, kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
all_pg_hashes,
ent_hashes=None,
remove_orphan_entities=True
) )
await self._run_cancellable_executor(delete_func) await self._run_cancellable_executor(delete_func)
# 完全清空KG创建新的空图无论是否有段落hash都要执行 # 完全清空KG创建新的空图无论是否有段落hash都要执行
from quick_algo import di_graph from quick_algo import di_graph
kg_mgr.graph = di_graph.DiGraph() kg_mgr.graph = di_graph.DiGraph()
kg_mgr.stored_paragraph_hashes.clear() kg_mgr.stored_paragraph_hashes.clear()
kg_mgr.ent_appear_cnt.clear() kg_mgr.ent_appear_cnt.clear()
@ -373,7 +365,7 @@ class LPMMOperations:
"stats": { "stats": {
"before": before_stats, "before": before_stats,
"after": after_stats, "after": after_stats,
} },
} }
except asyncio.CancelledError: except asyncio.CancelledError:
@ -383,6 +375,6 @@ class LPMMOperations:
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True) logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
# 内部使用的单例 # 内部使用的单例
lpmm_ops = LPMMOperations() lpmm_ops = LPMMOperations()

View File

@ -136,4 +136,3 @@ class PlanReplyLogger:
return str(value) return str(value)
# Fallback to string for other complex types # Fallback to string for other complex types
return str(value) return str(value)

View File

@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 如果未开启 API Server直接跳过 Fallback # 如果未开启 API Server直接跳过 Fallback
if not global_config.maim_message.enable_api_server: if not global_config.maim_message.enable_api_server:
logger.debug(f"[API Server Fallback] API Server未开启跳过fallback") logger.debug("[API Server Fallback] API Server未开启跳过fallback")
if legacy_exception: if legacy_exception:
raise legacy_exception raise legacy_exception
return False return False
@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
extra_server = getattr(global_api, "extra_server", None) extra_server = getattr(global_api, "extra_server", None)
if not extra_server: if not extra_server:
logger.warning(f"[API Server Fallback] extra_server不存在") logger.warning("[API Server Fallback] extra_server不存在")
if legacy_exception: if legacy_exception:
raise legacy_exception raise legacy_exception
return False return False
if not extra_server.is_running(): if not extra_server.is_running():
logger.warning(f"[API Server Fallback] extra_server未运行") logger.warning("[API Server Fallback] extra_server未运行")
if legacy_exception: if legacy_exception:
raise legacy_exception raise legacy_exception
return False return False
@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
) )
# 直接调用 Server 的 send_message 接口,它会自动处理路由 # 直接调用 Server 的 send_message 接口,它会自动处理路由
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...") logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
results = await extra_server.send_message(api_message) results = await extra_server.send_message(api_message)
logger.debug(f"[API Server Fallback] 发送结果: {results}") logger.debug(f"[API Server Fallback] 发送结果: {results}")

View File

@ -35,6 +35,7 @@ logger = get_logger("planner")
install(extra_lines=3) install(extra_lines=3)
class ActionPlanner: class ActionPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager): def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id self.chat_id = chat_id
@ -111,20 +112,29 @@ class ActionPlanner:
# 替换 [picid:xxx] 为 [图片:描述] # 替换 [picid:xxx] 为 [图片:描述]
pic_pattern = r"\[picid:([^\]]+)\]" pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(pic_match: re.Match) -> str: def replace_pic_id(pic_match: re.Match) -> str:
pic_id = pic_match.group(1) pic_id = pic_match.group(1)
description = translate_pid_to_description(pic_id) description = translate_pid_to_description(pic_id)
return f"[图片:{description}]" return f"[图片:{description}]"
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text) msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb> # 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq" platform = (
getattr(message, "user_info", None)
and message.user_info.platform
or getattr(message, "chat_info", None)
and message.chat_info.platform
or "qq"
)
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True) msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
# 替换单独的 <用户名:用户ID> 格式replace_user_references 已处理回复<和@<格式) # 替换单独的 <用户名:用户ID> 格式replace_user_references 已处理回复<和@<格式)
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式, # 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
# 这里匹配到的应该都是单独的格式 # 这里匹配到的应该都是单独的格式
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>" user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
def replace_user_ref(user_match: re.Match) -> str: def replace_user_ref(user_match: re.Match) -> str:
user_name = user_match.group(1) user_name = user_match.group(1)
user_id = user_match.group(2) user_id = user_match.group(2)
@ -137,6 +147,7 @@ class ActionPlanner:
except Exception: except Exception:
# 如果解析失败,使用原始昵称 # 如果解析失败,使用原始昵称
return user_name return user_name
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text) msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..." preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
@ -306,9 +317,7 @@ class ActionPlanner:
return merged_words return merged_words
def _process_unknown_words_cache( def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
self, actions: List[ActionPlannerInfo]
) -> None:
""" """
处理黑话缓存逻辑 处理黑话缓存逻辑
1. 检查是否有 reply action 提取了 unknown_words 1. 检查是否有 reply action 提取了 unknown_words
@ -442,7 +451,11 @@ class ActionPlanner:
# 检查是否已经有回复该消息的 action # 检查是否已经有回复该消息的 action
has_reply_to_force_message = False has_reply_to_force_message = False
for action in actions: for action in actions:
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id: if (
action.action_type == "reply"
and action.action_message
and action.action_message.message_id == force_reply_message.message_id
):
has_reply_to_force_message = True has_reply_to_force_message = True
break break
@ -577,10 +590,11 @@ class ActionPlanner:
if global_config.chat.think_mode == "classic": if global_config.chat.think_mode == "classic":
reply_action_example = "" reply_action_example = ""
if global_config.chat.llm_quote: if global_config.chat.llm_quote:
reply_action_example += "5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += ( reply_action_example += (
'{{"action":"reply", "target_message_id":"消息id(m+数字)", ' "5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
'"unknown_words":["词语1","词语2"]' )
reply_action_example += (
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
) )
if global_config.chat.llm_quote: if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"' reply_action_example += ', "quote":"如果需要引用该message设置为true"'
@ -590,7 +604,9 @@ class ActionPlanner:
"5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n" "5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n"
) )
if global_config.chat.llm_quote: if global_config.chat.llm_quote:
reply_action_example += "6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n" reply_action_example += (
"6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
)
reply_action_example += ( reply_action_example += (
'{{"action":"reply", "think_level":数值等级(0或1), ' '{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", ' '"target_message_id":"消息id(m+数字)", '
@ -741,7 +757,9 @@ class ActionPlanner:
except Exception as req_e: except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return f"LLM 请求失败,模型出现问题: {req_e}", [ return (
f"LLM 请求失败,模型出现问题: {req_e}",
[
ActionPlannerInfo( ActionPlannerInfo(
action_type="no_reply", action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}", reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
@ -749,7 +767,11 @@ class ActionPlanner:
action_message=None, action_message=None,
available_actions=available_actions, available_actions=available_actions,
) )
], llm_content, llm_reasoning, llm_duration_ms ],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应 # 解析LLM响应
extracted_reasoning = "" extracted_reasoning = ""

View File

@ -1071,7 +1071,6 @@ class DefaultReplyer:
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2") chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt) chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换 # 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style reply_style = global_config.personality.reply_style
multi_styles = global_config.personality.multiple_reply_style multi_styles = global_config.personality.multiple_reply_style

View File

@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
) )
from src.bw_learner.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known from src.person_info.person_info import Person, is_person_known
from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.base.component_types import ActionInfo, EventType

View File

@ -5,6 +5,7 @@ from src.common.logger import get_logger
logger = get_logger("common_utils") logger = get_logger("common_utils")
class TempMethodsExpression: class TempMethodsExpression:
"""用于临时存放一些方法的类""" """用于临时存放一些方法的类"""

View File

@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
from . import BaseDatabaseDataModel from . import BaseDatabaseDataModel
class MaiChatSession(BaseDatabaseDataModel[ChatSession]): class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None): def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
self.session_id = session_id self.session_id = session_id

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple, Union from typing import Any, Iterable, List, Optional, Tuple, Union

View File

@ -96,8 +96,12 @@ class Images(SQLModel, table=True):
no_file_flag: bool = Field(default=False) # 文件不存在标记如果为True表示文件已经不存在仅保留描述字段 no_file_flag: bool = Field(default=False) # 文件不存在标记如果为True表示文件已经不存在仅保留描述字段
record_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间(数据库记录被创建的时间) record_time: datetime = Field(
register_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 注册时间(被注册为可用表情包的时间) default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 记录时间(数据库记录被创建的时间)
register_time: Optional[datetime] = Field(
default=None, sa_column=Column(DateTime, nullable=True)
) # 注册时间(被注册为可用表情包的时间)
last_used_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 上次使用时间 last_used_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 上次使用时间
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理 vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
@ -171,7 +175,9 @@ class Expression(SQLModel, table=True):
content_list: str # 内容列表JSON格式存储 content_list: str # 内容列表JSON格式存储
count: int = Field(default=0) # 使用次数 count: int = Field(default=0) # 使用次数
last_active_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 上次使用时间 last_active_time: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 上次使用时间
create_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime)) # 创建时间 create_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime)) # 创建时间
session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID区分是否为全局表达方式 session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID区分是否为全局表达方式
@ -232,8 +238,12 @@ class ThinkingQuestion(SQLModel, table=True):
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案 answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤JSON格式存储 thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤JSON格式存储
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间 created_timestamp: datetime = Field(
updated_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后更新时间 default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 创建时间
updated_timestamp: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 最后更新时间
class BinaryData(SQLModel, table=True): class BinaryData(SQLModel, table=True):
@ -272,7 +282,9 @@ class PersonInfo(SQLModel, table=True):
# 认识次数和时间 # 认识次数和时间
know_counts: int = Field(default=0) # 认识次数 know_counts: int = Field(default=0) # 认识次数
first_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 首次认识时间 first_known_time: Optional[datetime] = Field(
default=None, sa_column=Column(DateTime, nullable=True)
) # 首次认识时间
last_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 最后认识时间 last_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 最后认识时间
@ -285,8 +297,12 @@ class ChatSession(SQLModel, table=True):
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间 created_timestamp: datetime = Field(
last_active_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后活跃时间 default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 创建时间
last_active_timestamp: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 最后活跃时间
# 身份元数据 # 身份元数据
user_id: Optional[str] = Field(index=True, max_length=255, nullable=True) # 用户ID user_id: Optional[str] = Field(index=True, max_length=255, nullable=True) # 用户ID

View File

@ -221,5 +221,7 @@ if not supports_truecolor():
CONVERTED_MODULE_COLORS[name] = escape_str CONVERTED_MODULE_COLORS[name] = escape_str
else: else:
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items(): for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold) escape_str = rgb_pair_to_ansi_truecolor(
hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
)
CONVERTED_MODULE_COLORS[name] = escape_str CONVERTED_MODULE_COLORS[name] = escape_str

View File

@ -9,6 +9,7 @@ from .server import get_global_server
global_api = None global_api = None
def get_global_api() -> MessageServer: # sourcery skip: extract-method def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例""" """获取全局MessageServer实例"""
global global_api global global_api

View File

@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
logger = get_logger("file_utils") logger = get_logger("file_utils")
class FileUtils: class FileUtils:
@staticmethod @staticmethod
def save_binary_to_file(file_path: Path, data: bytes): def save_binary_to_file(file_path: Path, data: bytes):

View File

@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
reason = ",".join(reasons) reason = ",".join(reasons)
return MigrationResult(data=data, migrated=migrated_any, reason=reason) return MigrationResult(data=data, migrated=migrated_any, reason=reason)

View File

@ -54,8 +54,6 @@ async def generate_dream_summary(
) -> None: ) -> None:
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户""" """生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
try: try:
# 第一步:建立工具调用结果映射 (call_id -> result) # 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {} tool_results_map: dict[str, str] = {}
for msg in conversation_messages: for msg in conversation_messages:

View File

@ -4,4 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中通过 make_xxx(chat_id) 工厂函数 每个工具的具体实现放在独立文件中通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数 dream_agent.init_dream_tools 统一注册 生成绑定到特定 chat_id 的协程函数 dream_agent.init_dream_tools 统一注册
""" """

View File

@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}" return f"create_chat_history 执行失败: {e}"
return create_chat_history return create_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}" return f"delete_chat_history 执行失败: {e}"
return delete_chat_history return delete_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}" return f"delete_jargon 执行失败: {e}"
return delete_jargon return delete_jargon

View File

@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg return msg
return finish_maintenance return finish_maintenance

View File

@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
return f"get_chat_history_detail 执行失败: {e}" return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail return get_chat_history_detail

View File

@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}" return f"search_chat_history 执行失败: {e}"
return search_chat_history return search_chat_history

View File

@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}" return f"update_chat_history 执行失败: {e}"
return update_chat_history return update_chat_history

View File

@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}" return f"update_jargon 执行失败: {e}"
return update_jargon return update_jargon

View File

@ -459,7 +459,7 @@ def _default_normal_response_parser(
# 此时为了调试方便,建议打印出 arguments 的类型 # 此时为了调试方便,建议打印出 arguments 的类型
raise RespParseException( raise RespParseException(
resp, resp,
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}" f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
) )
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments)) api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
except json.JSONDecodeError as e: except json.JSONDecodeError as e:

View File

@ -2,7 +2,7 @@ import time
import json import json
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple, Callable, cast from typing import List, Dict, Any, Optional, Tuple, Callable
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager

View File

@ -113,6 +113,7 @@ async def search_chat_history(
if start_time: if start_time:
try: try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp from src.memory_system.memory_utils import parse_datetime_to_timestamp
start_timestamp = parse_datetime_to_timestamp(start_time) start_timestamp = parse_datetime_to_timestamp(start_time)
except ValueError as e: except ValueError as e:
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'" return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
@ -120,6 +121,7 @@ async def search_chat_history(
if end_time: if end_time:
try: try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp from src.memory_system.memory_utils import parse_datetime_to_timestamp
end_timestamp = parse_datetime_to_timestamp(end_time) end_timestamp = parse_datetime_to_timestamp(end_time)
except ValueError as e: except ValueError as e:
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'" return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
@ -165,16 +167,13 @@ async def search_chat_history(
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段 # 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
query = query.where( query = query.where(
( (
(ChatHistory.start_time >= start_timestamp) (ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
& (ChatHistory.start_time <= end_timestamp)
) # 记录开始时间在查询时间段内 ) # 记录开始时间在查询时间段内
| ( | (
(ChatHistory.end_time >= start_timestamp) (ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
& (ChatHistory.end_time <= end_timestamp)
) # 记录结束时间在查询时间段内 ) # 记录结束时间在查询时间段内
| ( | (
(ChatHistory.start_time <= start_timestamp) (ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
& (ChatHistory.end_time >= end_timestamp)
) # 记录完全包含查询时间段 ) # 记录完全包含查询时间段
) )
logger.debug( logger.debug(

View File

@ -76,4 +76,3 @@ def register_tool():
], ],
execute_func=query_words, execute_func=query_words,
) )

View File

@ -558,7 +558,9 @@ class PluginBase(ABC):
if version_spec: if version_spec:
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec) is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
if not is_ok: if not is_ok:
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})") logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
)
return False return False
if min_version or max_version: if min_version or max_version:

View File

@ -751,9 +751,7 @@ class ComponentRegistry:
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
"workflow_steps": workflow_step_count, "workflow_steps": workflow_step_count,
"enabled_workflow_steps": enabled_workflow_step_count, "enabled_workflow_steps": enabled_workflow_step_count,
"workflow_steps_by_stage": { "workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
},
} }

View File

@ -429,7 +429,9 @@ class PluginManager:
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]: def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
"""根据依赖图计算加载顺序,并检测循环依赖。""" """根据依赖图计算加载顺序,并检测循环依赖。"""
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()} indegree: Dict[str, int] = {
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph} reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
for plugin_name, dependencies in dependency_graph.items(): for plugin_name, dependencies in dependency_graph.items():

View File

@ -55,7 +55,9 @@ class PluginServiceRegistry:
full_name = self._resolve_full_name(service_name, plugin_name) full_name = self._resolve_full_name(service_name, plugin_name)
return self._service_handlers.get(full_name) if full_name else None return self._service_handlers.get(full_name) if full_name else None
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]: def list_services(
self, plugin_name: Optional[str] = None, enabled_only: bool = False
) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。""" """列出插件服务。"""
services = self._services.copy() services = self._services.copy()
if plugin_name: if plugin_name:
@ -120,7 +122,12 @@ class PluginServiceRegistry:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}") raise ValueError(f"插件服务未注册: {target_name}")
if "." not in service_name and plugin_name is None and caller_plugin and service_info.plugin_name != caller_plugin: if (
"." not in service_name
and plugin_name is None
and caller_plugin
and service_info.plugin_name != caller_plugin
):
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name") raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
if not self._is_call_authorized(service_info, caller_plugin): if not self._is_call_authorized(service_info, caller_plugin):
@ -153,7 +160,9 @@ class PluginServiceRegistry:
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()} allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
return "*" in allowed_callers or caller_plugin in allowed_callers return "*" in allowed_callers or caller_plugin in allowed_callers
def _validate_input_contract(self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: def _validate_input_contract(
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
"""校验服务入参契约。""" """校验服务入参契约。"""
schema = service_info.params_schema schema = service_info.params_schema
if not schema: if not schema:

View File

@ -96,7 +96,9 @@ class WorkflowEngine:
except Exception as e: except Exception as e:
workflow_context.timings[stage_key] = time.perf_counter() - stage_start workflow_context.timings[stage_key] = time.perf_counter() - stage_start
workflow_context.errors.append(f"{stage_key}: {e}") workflow_context.errors.append(f"{stage_key}: {e}")
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True) logger.error(
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
)
self._execution_history[workflow_context.trace_id]["status"] = "failed" self._execution_history[workflow_context.trace_id]["status"] = "failed"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return ( return (
@ -195,7 +197,9 @@ class WorkflowEngine:
except Exception as e: except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}") context.errors.append(f"{step_info.full_name}: {e}")
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True) logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
)
return WorkflowStepResult( return WorkflowStepResult(
status="failed", status="failed",
return_message=str(e), return_message=str(e),

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容 2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载 3. 详情按需加载
""" """
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional from typing import List, Dict, Optional
@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
class ChatSummary(BaseModel): class ChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容""" """聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str chat_id: str
plan_count: int plan_count: int
latest_timestamp: float latest_timestamp: float
@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
class PlanLogSummary(BaseModel): class PlanLogSummary(BaseModel):
"""规划日志摘要""" """规划日志摘要"""
chat_id: str chat_id: str
timestamp: float timestamp: float
filename: str filename: str
@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
class PlanLogDetail(BaseModel): class PlanLogDetail(BaseModel):
"""规划日志详情""" """规划日志详情"""
type: str type: str
chat_id: str chat_id: str
timestamp: float timestamp: float
@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
class PlannerOverview(BaseModel): class PlannerOverview(BaseModel):
"""规划器总览 - 轻量级统计""" """规划器总览 - 轻量级统计"""
total_chats: int total_chats: int
total_plans: int total_plans: int
chats: List[ChatSummary] chats: List[ChatSummary]
@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
class PaginatedChatLogs(BaseModel): class PaginatedChatLogs(BaseModel):
"""分页的聊天日志列表""" """分页的聊天日志列表"""
data: List[PlanLogSummary] data: List[PlanLogSummary]
total: int total: int
page: int page: int
@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float: def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220""" """从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try: try:
timestamp_str = filename.split('_')[0] timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒 # 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000 return float(timestamp_str) / 1000
except (ValueError, IndexError): except (ValueError, IndexError):
@ -106,21 +112,19 @@ async def get_planner_overview():
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name)) latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name) latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ChatSummary( chats.append(
ChatSummary(
chat_id=chat_dir.name, chat_id=chat_dir.name,
plan_count=plan_count, plan_count=plan_count,
latest_timestamp=latest_timestamp, latest_timestamp=latest_timestamp,
latest_filename=latest_file.name latest_filename=latest_file.name,
)) )
)
# 按最新时间戳排序 # 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True) chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return PlannerOverview( return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
total_chats=len(chats),
total_plans=total_plans,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs) @router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
@ -128,7 +132,7 @@ async def get_chat_plan_logs(
chat_id: str, chat_id: str,
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容") search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
): ):
""" """
获取指定聊天的规划日志列表分页 获取指定聊天的规划日志列表分页
@ -137,9 +141,7 @@ async def get_chat_plan_logs(
""" """
chat_dir = PLAN_LOG_DIR / chat_id chat_dir = PLAN_LOG_DIR / chat_id
if not chat_dir.exists(): if not chat_dir.exists():
return PaginatedChatLogs( return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序 # 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json")) json_files = list(chat_dir.glob("*.json"))
@ -151,9 +153,9 @@ async def get_chat_plan_logs(
filtered_files = [] filtered_files = []
for log_file in json_files: for log_file in json_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
prompt = data.get('prompt', '') prompt = data.get("prompt", "")
if search_lower in prompt.lower(): if search_lower in prompt.lower():
filtered_files.append(log_file) filtered_files.append(log_file)
except Exception: except Exception:
@ -164,29 +166,32 @@ async def get_chat_plan_logs(
# 分页 - 只读取当前页的文件 # 分页 - 只读取当前页的文件
offset = (page - 1) * page_size offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size] page_files = json_files[offset : offset + page_size]
logs = [] logs = []
for log_file in page_files: for log_file in page_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
reasoning = data.get('reasoning', '') reasoning = data.get("reasoning", "")
actions = data.get('actions', []) actions = data.get("actions", [])
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')] action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
logs.append(PlanLogSummary( logs.append(
chat_id=data.get('chat_id', chat_id), PlanLogSummary(
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)), chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
filename=log_file.name, filename=log_file.name,
action_count=len(actions), action_count=len(actions),
action_types=action_types, action_types=action_types,
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0), total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0), llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
reasoning_preview=reasoning[:100] if reasoning else '' reasoning_preview=reasoning[:100] if reasoning else "",
)) )
)
except Exception: except Exception:
# 文件读取失败时使用文件名信息 # 文件读取失败时使用文件名信息
logs.append(PlanLogSummary( logs.append(
PlanLogSummary(
chat_id=chat_id, chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name), timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name, filename=log_file.name,
@ -194,16 +199,11 @@ async def get_chat_plan_logs(
action_types=[], action_types=[],
total_plan_ms=0, total_plan_ms=0,
llm_duration_ms=0, llm_duration_ms=0,
reasoning_preview='[读取失败]' reasoning_preview="[读取失败]",
))
return PaginatedChatLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
) )
)
return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail) @router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
@ -214,7 +214,7 @@ async def get_log_detail(chat_id: str, filename: str):
raise HTTPException(status_code=404, detail="日志文件不存在") raise HTTPException(status_code=404, detail="日志文件不存在")
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
return PlanLogDetail(**data) return PlanLogDetail(**data)
except Exception as e: except Exception as e:
@ -223,6 +223,7 @@ async def get_log_detail(chat_id: str, filename: str):
# ========== 兼容旧接口 ========== # ========== 兼容旧接口 ==========
@router.get("/stats") @router.get("/stats")
async def get_planner_stats(): async def get_planner_stats():
"""获取规划器统计信息 - 兼容旧接口""" """获取规划器统计信息 - 兼容旧接口"""
@ -246,7 +247,7 @@ async def get_planner_stats():
"total_plans": overview.total_plans, "total_plans": overview.total_plans,
"avg_plan_time_ms": 0, "avg_plan_time_ms": 0,
"avg_llm_time_ms": 0, "avg_llm_time_ms": 0,
"recent_plans": recent_plans "recent_plans": recent_plans,
} }
@ -258,10 +259,7 @@ async def get_chat_list():
@router.get("/all-logs") @router.get("/all-logs")
async def get_all_logs( async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100)
):
"""获取所有规划日志 - 兼容旧接口""" """获取所有规划日志 - 兼容旧接口"""
if not PLAN_LOG_DIR.exists(): if not PLAN_LOG_DIR.exists():
return {"data": [], "total": 0, "page": page, "page_size": page_size} return {"data": [], "total": 0, "page": page, "page_size": page_size}
@ -278,23 +276,25 @@ async def get_all_logs(
total = len(all_files) total = len(all_files)
offset = (page - 1) * page_size offset = (page - 1) * page_size
page_files = all_files[offset:offset + page_size] page_files = all_files[offset : offset + page_size]
logs = [] logs = []
for chat_id, log_file in page_files: for chat_id, log_file in page_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
reasoning = data.get('reasoning', '') reasoning = data.get("reasoning", "")
logs.append({ logs.append(
"chat_id": data.get('chat_id', chat_id), {
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)), "chat_id": data.get("chat_id", chat_id),
"timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name, "filename": log_file.name,
"action_count": len(data.get('actions', [])), "action_count": len(data.get("actions", [])),
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0), "total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0), "llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
"reasoning_preview": reasoning[:100] if reasoning else '' "reasoning_preview": reasoning[:100] if reasoning else "",
}) }
)
except Exception: except Exception:
continue continue

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容 2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载 3. 详情按需加载
""" """
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional from typing import List, Dict, Optional
@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
class ReplierChatSummary(BaseModel): class ReplierChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容""" """聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str chat_id: str
reply_count: int reply_count: int
latest_timestamp: float latest_timestamp: float
@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
class ReplyLogSummary(BaseModel): class ReplyLogSummary(BaseModel):
"""回复日志摘要""" """回复日志摘要"""
chat_id: str chat_id: str
timestamp: float timestamp: float
filename: str filename: str
@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
class ReplyLogDetail(BaseModel): class ReplyLogDetail(BaseModel):
"""回复日志详情""" """回复日志详情"""
type: str type: str
chat_id: str chat_id: str
timestamp: float timestamp: float
@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
class ReplierOverview(BaseModel): class ReplierOverview(BaseModel):
"""回复器总览 - 轻量级统计""" """回复器总览 - 轻量级统计"""
total_chats: int total_chats: int
total_replies: int total_replies: int
chats: List[ReplierChatSummary] chats: List[ReplierChatSummary]
@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
class PaginatedReplyLogs(BaseModel): class PaginatedReplyLogs(BaseModel):
"""分页的回复日志列表""" """分页的回复日志列表"""
data: List[ReplyLogSummary] data: List[ReplyLogSummary]
total: int total: int
page: int page: int
@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float: def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220""" """从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try: try:
timestamp_str = filename.split('_')[0] timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒 # 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000 return float(timestamp_str) / 1000
except (ValueError, IndexError): except (ValueError, IndexError):
@ -109,21 +115,19 @@ async def get_replier_overview():
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name)) latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name) latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ReplierChatSummary( chats.append(
ReplierChatSummary(
chat_id=chat_dir.name, chat_id=chat_dir.name,
reply_count=reply_count, reply_count=reply_count,
latest_timestamp=latest_timestamp, latest_timestamp=latest_timestamp,
latest_filename=latest_file.name latest_filename=latest_file.name,
)) )
)
# 按最新时间戳排序 # 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True) chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return ReplierOverview( return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
total_chats=len(chats),
total_replies=total_replies,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs) @router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
@ -131,7 +135,7 @@ async def get_chat_reply_logs(
chat_id: str, chat_id: str,
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容") search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
): ):
""" """
获取指定聊天的回复日志列表分页 获取指定聊天的回复日志列表分页
@ -140,9 +144,7 @@ async def get_chat_reply_logs(
""" """
chat_dir = REPLY_LOG_DIR / chat_id chat_dir = REPLY_LOG_DIR / chat_id
if not chat_dir.exists(): if not chat_dir.exists():
return PaginatedReplyLogs( return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序 # 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json")) json_files = list(chat_dir.glob("*.json"))
@ -154,9 +156,9 @@ async def get_chat_reply_logs(
filtered_files = [] filtered_files = []
for log_file in json_files: for log_file in json_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
prompt = data.get('prompt', '') prompt = data.get("prompt", "")
if search_lower in prompt.lower(): if search_lower in prompt.lower():
filtered_files.append(log_file) filtered_files.append(log_file)
except Exception: except Exception:
@ -167,44 +169,42 @@ async def get_chat_reply_logs(
# 分页 - 只读取当前页的文件 # 分页 - 只读取当前页的文件
offset = (page - 1) * page_size offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size] page_files = json_files[offset : offset + page_size]
logs = [] logs = []
for log_file in page_files: for log_file in page_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
output = data.get('output', '') output = data.get("output", "")
logs.append(ReplyLogSummary( logs.append(
chat_id=data.get('chat_id', chat_id), ReplyLogSummary(
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)), chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
filename=log_file.name, filename=log_file.name,
model=data.get('model', ''), model=data.get("model", ""),
success=data.get('success', True), success=data.get("success", True),
llm_ms=data.get('timing', {}).get('llm_ms', 0), llm_ms=data.get("timing", {}).get("llm_ms", 0),
overall_ms=data.get('timing', {}).get('overall_ms', 0), overall_ms=data.get("timing", {}).get("overall_ms", 0),
output_preview=output[:100] if output else '' output_preview=output[:100] if output else "",
)) )
)
except Exception: except Exception:
# 文件读取失败时使用文件名信息 # 文件读取失败时使用文件名信息
logs.append(ReplyLogSummary( logs.append(
ReplyLogSummary(
chat_id=chat_id, chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name), timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name, filename=log_file.name,
model='', model="",
success=False, success=False,
llm_ms=0, llm_ms=0,
overall_ms=0, overall_ms=0,
output_preview='[读取失败]' output_preview="[读取失败]",
))
return PaginatedReplyLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
) )
)
return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail) @router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
@ -215,21 +215,21 @@ async def get_reply_log_detail(chat_id: str, filename: str):
raise HTTPException(status_code=404, detail="日志文件不存在") raise HTTPException(status_code=404, detail="日志文件不存在")
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
return ReplyLogDetail( return ReplyLogDetail(
type=data.get('type', 'reply'), type=data.get("type", "reply"),
chat_id=data.get('chat_id', chat_id), chat_id=data.get("chat_id", chat_id),
timestamp=data.get('timestamp', 0), timestamp=data.get("timestamp", 0),
prompt=data.get('prompt', ''), prompt=data.get("prompt", ""),
output=data.get('output', ''), output=data.get("output", ""),
processed_output=data.get('processed_output', []), processed_output=data.get("processed_output", []),
model=data.get('model', ''), model=data.get("model", ""),
reasoning=data.get('reasoning', ''), reasoning=data.get("reasoning", ""),
think_level=data.get('think_level', 0), think_level=data.get("think_level", 0),
timing=data.get('timing', {}), timing=data.get("timing", {}),
error=data.get('error'), error=data.get("error"),
success=data.get('success', True) success=data.get("success", True),
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
@ -237,6 +237,7 @@ async def get_reply_log_detail(chat_id: str, filename: str):
# ========== 兼容接口 ========== # ========== 兼容接口 ==========
@router.get("/stats") @router.get("/stats")
async def get_replier_stats(): async def get_replier_stats():
"""获取回复器统计信息""" """获取回复器统计信息"""
@ -258,7 +259,7 @@ async def get_replier_stats():
return { return {
"total_chats": overview.total_chats, "total_chats": overview.total_chats,
"total_replies": overview.total_replies, "total_replies": overview.total_replies,
"recent_replies": recent_replies "recent_replies": recent_replies,
} }

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from fastapi import Depends, Cookie, Header, Request, HTTPException from fastapi import Depends, Cookie, Header, Request
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit from .core import get_current_token, get_token_manager, check_auth_rate_limit
async def require_auth( async def require_auth(

View File

@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
# loose: 宽松模式(较宽松的检测,较高的频率限制) # loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP # basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
# IP白名单配置从配置文件读取逗号分隔 # IP白名单配置从配置文件读取逗号分隔
# 支持格式: # 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100 # - 精确IP127.0.0.1, 192.168.1.100
@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
def _get_anti_crawler_config(): def _get_anti_crawler_config():
"""获取防爬虫配置""" """获取防爬虫配置"""
from src.config.config import global_config from src.config.config import global_config
return { return {
'mode': global_config.webui.anti_crawler_mode, "mode": global_config.webui.anti_crawler_mode,
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips), "allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies), "trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
'trust_xff': global_config.webui.trust_xff "trust_xff": global_config.webui.trust_xff,
} }
# 初始化配置(将在模块加载时执行) # 初始化配置(将在模块加载时执行)
_config = _get_anti_crawler_config() _config = _get_anti_crawler_config()
ANTI_CRAWLER_MODE = _config['mode'] ANTI_CRAWLER_MODE = _config["mode"]
ALLOWED_IPS = _config['allowed_ips'] ALLOWED_IPS = _config["allowed_ips"]
TRUSTED_PROXIES = _config['trusted_proxies'] TRUSTED_PROXIES = _config["trusted_proxies"]
TRUST_XFF = _config['trust_xff'] TRUST_XFF = _config["trust_xff"]
def _get_mode_config(mode: str) -> dict: def _get_mode_config(mode: str) -> dict:

View File

@ -43,7 +43,7 @@ def _get_paragraph_store():
namespace="paragraph", namespace="paragraph",
dir_path=embedding_dir, dir_path=embedding_dir,
max_workers=1, # 只读不需要多线程 max_workers=1, # 只读不需要多线程
chunk_size=100 chunk_size=100,
) )
paragraph_store.load_from_file() paragraph_store.load_from_file()
@ -74,7 +74,7 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
paragraph_item = paragraph_store.store.get(node_id) paragraph_item = paragraph_store.store.get(node_id)
if paragraph_item is not None: if paragraph_item is not None:
# paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本 # paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本
content: str = getattr(paragraph_item, 'str', '') content: str = getattr(paragraph_item, "str", "")
if content: if content:
return content, True return content, True
return None, True return None, True
@ -160,7 +160,11 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
# 对于段落节点,尝试从 embedding store 获取完整内容 # 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph": if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id) full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else: else:
content = node_data["content"] if "content" in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
@ -249,7 +253,11 @@ async def get_knowledge_graph(
# 对于段落节点,尝试从 embedding store 获取完整内容 # 对于段落节点,尝试从 embedding store 获取完整内容
if node_type_val == "paragraph": if node_type_val == "paragraph":
full_content, _ = _get_paragraph_content(node_id) full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else: else:
content = node_data["content"] if "content" in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
@ -372,7 +380,11 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
# 对于段落节点,尝试从 embedding store 获取完整内容 # 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph": if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id) full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else: else:
content = node_data["content"] if "content" in node_data else node_id content = node_data["content"] if "content" in node_data else node_id