mirror of https://github.com/Mai-with-u/MaiBot.git
Ruff Format
parent
2cb512120b
commit
eaef7f0e98
1
bot.py
1
bot.py
|
|
@ -50,6 +50,7 @@ print("警告:Dev进入不稳定开发状态,任何插件与WebUI均可能
|
||||||
print("\n\n\n\n\n")
|
print("\n\n\n\n\n")
|
||||||
print("-----------------------------------------")
|
print("-----------------------------------------")
|
||||||
|
|
||||||
|
|
||||||
def run_runner_process():
|
def run_runner_process():
|
||||||
"""
|
"""
|
||||||
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1 @@
|
||||||
"""Core helpers for MCP Bridge Plugin."""
|
"""Core helpers for MCP Bridge Plugin."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
|
||||||
if not mcp_servers:
|
if not mcp_servers:
|
||||||
return ""
|
return ""
|
||||||
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)
|
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,19 +34,21 @@ from enum import Enum
|
||||||
# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging
|
# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging
|
||||||
try:
|
try:
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("mcp_client")
|
logger = get_logger("mcp_client")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback: 使用标准 logging
|
# Fallback: 使用标准 logging
|
||||||
logger = logging.getLogger("mcp_client")
|
logger = logging.getLogger("mcp_client")
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s'))
|
handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
class TransportType(Enum):
|
class TransportType(Enum):
|
||||||
"""MCP 传输类型"""
|
"""MCP 传输类型"""
|
||||||
|
|
||||||
STDIO = "stdio" # 本地进程通信
|
STDIO = "stdio" # 本地进程通信
|
||||||
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
||||||
HTTP = "http" # HTTP Streamable (新版,推荐)
|
HTTP = "http" # HTTP Streamable (新版,推荐)
|
||||||
|
|
@ -56,6 +58,7 @@ class TransportType(Enum):
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPToolInfo:
|
class MCPToolInfo:
|
||||||
"""MCP 工具信息"""
|
"""MCP 工具信息"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
input_schema: Dict[str, Any]
|
input_schema: Dict[str, Any]
|
||||||
|
|
@ -65,6 +68,7 @@ class MCPToolInfo:
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPResourceInfo:
|
class MCPResourceInfo:
|
||||||
"""MCP 资源信息"""
|
"""MCP 资源信息"""
|
||||||
|
|
||||||
uri: str
|
uri: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
@ -75,6 +79,7 @@ class MCPResourceInfo:
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPPromptInfo:
|
class MCPPromptInfo:
|
||||||
"""MCP 提示模板信息"""
|
"""MCP 提示模板信息"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
arguments: List[Dict[str, Any]] # [{name, description, required}]
|
arguments: List[Dict[str, Any]] # [{name, description, required}]
|
||||||
|
|
@ -84,6 +89,7 @@ class MCPPromptInfo:
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPServerConfig:
|
class MCPServerConfig:
|
||||||
"""MCP 服务器配置"""
|
"""MCP 服务器配置"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
transport: TransportType = TransportType.STDIO
|
transport: TransportType = TransportType.STDIO
|
||||||
|
|
@ -99,6 +105,7 @@ class MCPServerConfig:
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPCallResult:
|
class MCPCallResult:
|
||||||
"""MCP 工具调用结果"""
|
"""MCP 工具调用结果"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
content: Any
|
content: Any
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
@ -108,6 +115,7 @@ class MCPCallResult:
|
||||||
|
|
||||||
class CircuitState(Enum):
|
class CircuitState(Enum):
|
||||||
"""断路器状态"""
|
"""断路器状态"""
|
||||||
|
|
||||||
CLOSED = "closed" # 正常状态,允许请求
|
CLOSED = "closed" # 正常状态,允许请求
|
||||||
OPEN = "open" # 熔断状态,拒绝请求
|
OPEN = "open" # 熔断状态,拒绝请求
|
||||||
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
|
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
|
||||||
|
|
@ -232,6 +240,7 @@ class CircuitBreaker:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallStats:
|
class ToolCallStats:
|
||||||
"""工具调用统计"""
|
"""工具调用统计"""
|
||||||
|
|
||||||
tool_key: str
|
tool_key: str
|
||||||
total_calls: int = 0
|
total_calls: int = 0
|
||||||
success_calls: int = 0
|
success_calls: int = 0
|
||||||
|
|
@ -282,6 +291,7 @@ class ToolCallStats:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ServerStats:
|
class ServerStats:
|
||||||
"""服务器统计"""
|
"""服务器统计"""
|
||||||
|
|
||||||
server_name: str
|
server_name: str
|
||||||
connect_count: int = 0 # 连接次数
|
connect_count: int = 0 # 连接次数
|
||||||
disconnect_count: int = 0 # 断开次数
|
disconnect_count: int = 0 # 断开次数
|
||||||
|
|
@ -442,9 +452,7 @@ class MCPClientSession:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
server_params = StdioServerParameters(
|
server_params = StdioServerParameters(
|
||||||
command=self.config.command,
|
command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
|
||||||
args=self.config.args,
|
|
||||||
env=self.config.env if self.config.env else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._stdio_context = stdio_client(server_params)
|
self._stdio_context = stdio_client(server_params)
|
||||||
|
|
@ -506,6 +514,7 @@ class MCPClientSession:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
|
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
return False
|
return False
|
||||||
|
|
@ -551,6 +560,7 @@ class MCPClientSession:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
|
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
return False
|
return False
|
||||||
|
|
@ -568,8 +578,8 @@ class MCPClientSession:
|
||||||
tool_info = MCPToolInfo(
|
tool_info = MCPToolInfo(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description or f"MCP tool: {tool.name}",
|
description=tool.description or f"MCP tool: {tool.name}",
|
||||||
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {},
|
input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._tools.append(tool_info)
|
self._tools.append(tool_info)
|
||||||
# 初始化工具统计
|
# 初始化工具统计
|
||||||
|
|
@ -591,10 +601,7 @@ class MCPClientSession:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
|
||||||
self._session.list_resources(),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
self._resources = []
|
self._resources = []
|
||||||
|
|
||||||
for resource in result.resources:
|
for resource in result.resources:
|
||||||
|
|
@ -602,8 +609,8 @@ class MCPClientSession:
|
||||||
uri=str(resource.uri),
|
uri=str(resource.uri),
|
||||||
name=resource.name or str(resource.uri),
|
name=resource.name or str(resource.uri),
|
||||||
description=resource.description or "",
|
description=resource.description or "",
|
||||||
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None,
|
mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._resources.append(resource_info)
|
self._resources.append(resource_info)
|
||||||
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
|
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
|
||||||
|
|
@ -633,28 +640,27 @@ class MCPClientSession:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
|
||||||
self._session.list_prompts(),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
self._prompts = []
|
self._prompts = []
|
||||||
|
|
||||||
for prompt in result.prompts:
|
for prompt in result.prompts:
|
||||||
# 解析参数
|
# 解析参数
|
||||||
arguments = []
|
arguments = []
|
||||||
if hasattr(prompt, 'arguments') and prompt.arguments:
|
if hasattr(prompt, "arguments") and prompt.arguments:
|
||||||
for arg in prompt.arguments:
|
for arg in prompt.arguments:
|
||||||
arguments.append({
|
arguments.append(
|
||||||
|
{
|
||||||
"name": arg.name,
|
"name": arg.name,
|
||||||
"description": arg.description or "",
|
"description": arg.description or "",
|
||||||
"required": arg.required if hasattr(arg, 'required') else False,
|
"required": arg.required if hasattr(arg, "required") else False,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
prompt_info = MCPPromptInfo(
|
prompt_info = MCPPromptInfo(
|
||||||
name=prompt.name,
|
name=prompt.name,
|
||||||
description=prompt.description or f"MCP prompt: {prompt.name}",
|
description=prompt.description or f"MCP prompt: {prompt.name}",
|
||||||
arguments=arguments,
|
arguments=arguments,
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._prompts.append(prompt_info)
|
self._prompts.append(prompt_info)
|
||||||
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
|
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
|
||||||
|
|
@ -686,35 +692,25 @@ class MCPClientSession:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if not self._connected or not self._session:
|
if not self._connected or not self._session:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 未连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._supports_resources:
|
if not self._supports_resources:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 不支持 Resources 功能"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
|
||||||
self._session.read_resource(uri),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
# 处理返回内容
|
# 处理返回内容
|
||||||
content_parts = []
|
content_parts = []
|
||||||
for content in result.contents:
|
for content in result.contents:
|
||||||
if hasattr(content, 'text'):
|
if hasattr(content, "text"):
|
||||||
content_parts.append(content.text)
|
content_parts.append(content.text)
|
||||||
elif hasattr(content, 'blob'):
|
elif hasattr(content, "blob"):
|
||||||
# 二进制数据,返回 base64 或提示
|
# 二进制数据,返回 base64 或提示
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
blob_data = content.blob
|
blob_data = content.blob
|
||||||
if len(blob_data) < 10000: # 小于 10KB 返回 base64
|
if len(blob_data) < 10000: # 小于 10KB 返回 base64
|
||||||
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
|
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
|
||||||
|
|
@ -724,28 +720,18 @@ class MCPClientSession:
|
||||||
content_parts.append(str(content))
|
content_parts.append(str(content))
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
|
||||||
content="\n".join(content_parts) if content_parts else "",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=False,
|
success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||||
content=None,
|
|
||||||
error=f"读取资源超时({self.call_timeout}秒)",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
|
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=str(e),
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
|
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
|
||||||
"""v1.2.0: 获取提示模板的内容
|
"""v1.2.0: 获取提示模板的内容
|
||||||
|
|
@ -760,23 +746,14 @@ class MCPClientSession:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if not self._connected or not self._session:
|
if not self._connected or not self._session:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 未连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._supports_prompts:
|
if not self._supports_prompts:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.get_prompt(name, arguments=arguments or {}),
|
self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
@ -784,10 +761,10 @@ class MCPClientSession:
|
||||||
# 处理返回的消息
|
# 处理返回的消息
|
||||||
messages = []
|
messages = []
|
||||||
for msg in result.messages:
|
for msg in result.messages:
|
||||||
role = msg.role if hasattr(msg, 'role') else "unknown"
|
role = msg.role if hasattr(msg, "role") else "unknown"
|
||||||
content_text = ""
|
content_text = ""
|
||||||
if hasattr(msg, 'content'):
|
if hasattr(msg, "content"):
|
||||||
if hasattr(msg.content, 'text'):
|
if hasattr(msg.content, "text"):
|
||||||
content_text = msg.content.text
|
content_text = msg.content.text
|
||||||
elif isinstance(msg.content, str):
|
elif isinstance(msg.content, str):
|
||||||
content_text = msg.content
|
content_text = msg.content
|
||||||
|
|
@ -796,28 +773,18 @@ class MCPClientSession:
|
||||||
messages.append(f"[{role}]: {content_text}")
|
messages.append(f"[{role}]: {content_text}")
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
|
||||||
content="\n\n".join(messages) if messages else "",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=False,
|
success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||||
content=None,
|
|
||||||
error=f"获取提示模板超时({self.call_timeout}秒)",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
|
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=str(e),
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
async def check_health(self) -> bool:
|
async def check_health(self) -> bool:
|
||||||
"""检查连接健康状态(心跳检测)
|
"""检查连接健康状态(心跳检测)
|
||||||
|
|
@ -829,10 +796,7 @@ class MCPClientSession:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 list_tools 作为心跳检测
|
# 使用 list_tools 作为心跳检测
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
|
||||||
self._session.list_tools(),
|
|
||||||
timeout=10.0
|
|
||||||
)
|
|
||||||
self.stats.record_heartbeat()
|
self.stats.record_heartbeat()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -849,12 +813,7 @@ class MCPClientSession:
|
||||||
# v1.7.0: 断路器检查
|
# v1.7.0: 断路器检查
|
||||||
can_execute, reject_reason = self._circuit_breaker.can_execute()
|
can_execute, reject_reason = self._circuit_breaker.can_execute()
|
||||||
if not can_execute:
|
if not can_execute:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"⚡ {reject_reason}",
|
|
||||||
circuit_broken=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 半开状态下增加试探计数
|
# 半开状态下增加试探计数
|
||||||
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
|
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
|
||||||
|
|
@ -870,8 +829,7 @@ class MCPClientSession:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.call_tool(tool_name, arguments=arguments),
|
self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
@ -879,9 +837,9 @@ class MCPClientSession:
|
||||||
# 处理返回内容
|
# 处理返回内容
|
||||||
content_parts = []
|
content_parts = []
|
||||||
for content in result.content:
|
for content in result.content:
|
||||||
if hasattr(content, 'text'):
|
if hasattr(content, "text"):
|
||||||
content_parts.append(content.text)
|
content_parts.append(content.text)
|
||||||
elif hasattr(content, 'data'):
|
elif hasattr(content, "data"):
|
||||||
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
|
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
|
||||||
else:
|
else:
|
||||||
content_parts.append(str(content))
|
content_parts.append(str(content))
|
||||||
|
|
@ -896,7 +854,7 @@ class MCPClientSession:
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True,
|
||||||
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
|
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
|
||||||
duration_ms=duration_ms
|
duration_ms=duration_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|
@ -939,25 +897,25 @@ class MCPClientSession:
|
||||||
self._supports_prompts = False # v1.2.0
|
self._supports_prompts = False # v1.2.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_session_context') and self._session_context:
|
if hasattr(self, "_session_context") and self._session_context:
|
||||||
await self._session_context.__aexit__(None, None, None)
|
await self._session_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_stdio_context') and self._stdio_context:
|
if hasattr(self, "_stdio_context") and self._stdio_context:
|
||||||
await self._stdio_context.__aexit__(None, None, None)
|
await self._stdio_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_http_context') and self._http_context:
|
if hasattr(self, "_http_context") and self._http_context:
|
||||||
await self._http_context.__aexit__(None, None, None)
|
await self._http_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_sse_context') and self._sse_context:
|
if hasattr(self, "_sse_context") and self._sse_context:
|
||||||
await self._sse_context.__aexit__(None, None, None)
|
await self._sse_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
|
||||||
|
|
@ -1082,7 +1040,9 @@ class MCPClientManager:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if attempt < retry_attempts:
|
if attempt < retry_attempts:
|
||||||
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
|
logger.warning(
|
||||||
|
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
|
||||||
|
)
|
||||||
await asyncio.sleep(retry_interval)
|
await asyncio.sleep(retry_interval)
|
||||||
|
|
||||||
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
|
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
|
||||||
|
|
@ -1213,11 +1173,7 @@ class MCPClientManager:
|
||||||
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
|
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
|
||||||
"""调用 MCP 工具"""
|
"""调用 MCP 工具"""
|
||||||
if tool_key not in self._all_tools:
|
if tool_key not in self._all_tools:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"工具 {tool_key} 不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_info, client = self._all_tools[tool_key]
|
tool_info, client = self._all_tools[tool_key]
|
||||||
|
|
||||||
|
|
@ -1273,11 +1229,7 @@ class MCPClientManager:
|
||||||
# 如果指定了服务器
|
# 如果指定了服务器
|
||||||
if server_name:
|
if server_name:
|
||||||
if server_name not in self._clients:
|
if server_name not in self._clients:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {server_name} 不存在"
|
|
||||||
)
|
|
||||||
client = self._clients[server_name]
|
client = self._clients[server_name]
|
||||||
return await client.read_resource(uri)
|
return await client.read_resource(uri)
|
||||||
|
|
||||||
|
|
@ -1293,14 +1245,11 @@ class MCPClientManager:
|
||||||
if result.success:
|
if result.success:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"未找到资源: {uri}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None,
|
async def get_prompt(
|
||||||
server_name: Optional[str] = None) -> MCPCallResult:
|
self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
|
||||||
|
) -> MCPCallResult:
|
||||||
"""v1.2.0: 获取提示模板内容
|
"""v1.2.0: 获取提示模板内容
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -1311,11 +1260,7 @@ class MCPClientManager:
|
||||||
# 如果指定了服务器
|
# 如果指定了服务器
|
||||||
if server_name:
|
if server_name:
|
||||||
if server_name not in self._clients:
|
if server_name not in self._clients:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {server_name} 不存在"
|
|
||||||
)
|
|
||||||
client = self._clients[server_name]
|
client = self._clients[server_name]
|
||||||
return await client.get_prompt(name, arguments)
|
return await client.get_prompt(name, arguments)
|
||||||
|
|
||||||
|
|
@ -1324,11 +1269,7 @@ class MCPClientManager:
|
||||||
if prompt_info.name == name:
|
if prompt_info.name == name:
|
||||||
return await client.get_prompt(name, arguments)
|
return await client.get_prompt(name, arguments)
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"未找到提示模板: {name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==================== 心跳检测 ====================
|
# ==================== 心跳检测 ====================
|
||||||
|
|
||||||
|
|
@ -1489,7 +1430,9 @@ class MCPClientManager:
|
||||||
"global": {
|
"global": {
|
||||||
**self._global_stats,
|
**self._global_stats,
|
||||||
"uptime_seconds": round(uptime, 2),
|
"uptime_seconds": round(uptime, 2),
|
||||||
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0,
|
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
|
||||||
|
if uptime > 0
|
||||||
|
else 0,
|
||||||
},
|
},
|
||||||
"servers": server_stats,
|
"servers": server_stats,
|
||||||
"tools": tool_stats,
|
"tools": tool_stats,
|
||||||
|
|
|
||||||
|
|
@ -123,9 +123,11 @@ logger = get_logger("mcp_bridge_plugin")
|
||||||
# v1.4.0: 调用链路追踪
|
# v1.4.0: 调用链路追踪
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallRecord:
|
class ToolCallRecord:
|
||||||
"""工具调用记录"""
|
"""工具调用记录"""
|
||||||
|
|
||||||
call_id: str
|
call_id: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
tool_name: str
|
tool_name: str
|
||||||
|
|
@ -208,9 +210,11 @@ tool_call_tracer = ToolCallTracer()
|
||||||
# v1.4.0: 工具调用缓存
|
# v1.4.0: 工具调用缓存
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheEntry:
|
class CacheEntry:
|
||||||
"""缓存条目"""
|
"""缓存条目"""
|
||||||
|
|
||||||
tool_name: str
|
tool_name: str
|
||||||
args_hash: str
|
args_hash: str
|
||||||
result: str
|
result: str
|
||||||
|
|
@ -347,6 +351,7 @@ tool_call_cache = ToolCallCache()
|
||||||
# v1.4.0: 工具权限控制
|
# v1.4.0: 工具权限控制
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class PermissionChecker:
|
class PermissionChecker:
|
||||||
"""工具权限检查器"""
|
"""工具权限检查器"""
|
||||||
|
|
||||||
|
|
@ -479,6 +484,7 @@ permission_checker = PermissionChecker()
|
||||||
# 工具类型转换
|
# 工具类型转换
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
||||||
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
|
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
|
|
@ -492,7 +498,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
||||||
return type_mapping.get(json_type, ToolParamType.STRING)
|
return type_mapping.get(json_type, ToolParamType.STRING)
|
||||||
|
|
||||||
|
|
||||||
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
def parse_mcp_parameters(
|
||||||
|
input_schema: Dict[str, Any],
|
||||||
|
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
||||||
"""解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式"""
|
"""解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式"""
|
||||||
parameters = []
|
parameters = []
|
||||||
|
|
||||||
|
|
@ -534,6 +542,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
|
||||||
# MCP 工具代理
|
# MCP 工具代理
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPToolProxy(BaseTool):
|
class MCPToolProxy(BaseTool):
|
||||||
"""MCP 工具代理基类"""
|
"""MCP 工具代理基类"""
|
||||||
|
|
||||||
|
|
@ -576,10 +585,7 @@ class MCPToolProxy(BaseTool):
|
||||||
# v1.4.0: 权限检查
|
# v1.4.0: 权限检查
|
||||||
if not permission_checker.check(self.name, chat_id, user_id, is_group):
|
if not permission_checker.check(self.name, chat_id, user_id, is_group):
|
||||||
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
|
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
|
||||||
return {
|
return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
|
||||||
"name": self.name,
|
|
||||||
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
|
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
|
||||||
|
|
||||||
|
|
@ -749,11 +755,7 @@ class MCPToolProxy(BaseTool):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _call_post_process_llm(
|
async def _call_post_process_llm(
|
||||||
self,
|
self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
|
||||||
prompt: str,
|
|
||||||
max_tokens: int,
|
|
||||||
settings: Dict[str, Any],
|
|
||||||
server_config: Optional[Dict[str, Any]]
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""调用 LLM 进行后处理"""
|
"""调用 LLM 进行后处理"""
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
|
|
@ -811,10 +813,7 @@ class MCPToolProxy(BaseTool):
|
||||||
|
|
||||||
|
|
||||||
def create_mcp_tool_class(
|
def create_mcp_tool_class(
|
||||||
tool_key: str,
|
tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||||
tool_info: MCPToolInfo,
|
|
||||||
tool_prefix: str,
|
|
||||||
disabled: bool = False
|
|
||||||
) -> Type[MCPToolProxy]:
|
) -> Type[MCPToolProxy]:
|
||||||
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
|
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
|
||||||
parameters = parse_mcp_parameters(tool_info.input_schema)
|
parameters = parse_mcp_parameters(tool_info.input_schema)
|
||||||
|
|
@ -837,7 +836,7 @@ def create_mcp_tool_class(
|
||||||
"_mcp_tool_key": tool_key,
|
"_mcp_tool_key": tool_key,
|
||||||
"_mcp_original_name": tool_info.name,
|
"_mcp_original_name": tool_info.name,
|
||||||
"_mcp_server_name": tool_info.server_name,
|
"_mcp_server_name": tool_info.server_name,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_class
|
return tool_class
|
||||||
|
|
@ -851,11 +850,7 @@ class MCPToolRegistry:
|
||||||
self._tool_infos: Dict[str, ToolInfo] = {}
|
self._tool_infos: Dict[str, ToolInfo] = {}
|
||||||
|
|
||||||
def register_tool(
|
def register_tool(
|
||||||
self,
|
self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||||
tool_key: str,
|
|
||||||
tool_info: MCPToolInfo,
|
|
||||||
tool_prefix: str,
|
|
||||||
disabled: bool = False
|
|
||||||
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
|
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
|
||||||
"""注册 MCP 工具"""
|
"""注册 MCP 工具"""
|
||||||
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
|
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
|
||||||
|
|
@ -902,6 +897,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
|
||||||
# 内置工具
|
# 内置工具
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPReadResourceTool(BaseTool):
|
class MCPReadResourceTool(BaseTool):
|
||||||
"""v1.2.0: MCP 资源读取工具"""
|
"""v1.2.0: MCP 资源读取工具"""
|
||||||
|
|
||||||
|
|
@ -973,6 +969,7 @@ class MCPGetPromptTool(BaseTool):
|
||||||
# v1.8.0: 工具链代理工具
|
# v1.8.0: 工具链代理工具
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ToolChainProxyBase(BaseTool):
|
class ToolChainProxyBase(BaseTool):
|
||||||
"""工具链代理基类"""
|
"""工具链代理基类"""
|
||||||
|
|
||||||
|
|
@ -1037,7 +1034,7 @@ def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBa
|
||||||
"parameters": parameters,
|
"parameters": parameters,
|
||||||
"available_for_llm": True,
|
"available_for_llm": True,
|
||||||
"_chain_name": chain.name,
|
"_chain_name": chain.name,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_class
|
return tool_class
|
||||||
|
|
@ -1095,7 +1092,13 @@ class MCPStatusTool(BaseTool):
|
||||||
name = "mcp_status"
|
name = "mcp_status"
|
||||||
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
|
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
|
||||||
parameters = [
|
parameters = [
|
||||||
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"]),
|
(
|
||||||
|
"query_type",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"查询类型",
|
||||||
|
False,
|
||||||
|
["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"],
|
||||||
|
),
|
||||||
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
|
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
|
||||||
]
|
]
|
||||||
available_for_llm = True
|
available_for_llm = True
|
||||||
|
|
@ -1132,10 +1135,7 @@ class MCPStatusTool(BaseTool):
|
||||||
if query_type in ("cache",):
|
if query_type in ("cache",):
|
||||||
result_parts.append(self._format_cache())
|
result_parts.append(self._format_cache())
|
||||||
|
|
||||||
return {
|
return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
|
||||||
"name": self.name,
|
|
||||||
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
|
|
||||||
}
|
|
||||||
|
|
||||||
def _format_status(self, server_name: Optional[str] = None) -> str:
|
def _format_status(self, server_name: Optional[str] = None) -> str:
|
||||||
status = mcp_manager.get_status()
|
status = mcp_manager.get_status()
|
||||||
|
|
@ -1147,14 +1147,14 @@ class MCPStatusTool(BaseTool):
|
||||||
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
|
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
|
||||||
|
|
||||||
lines.append("\n🔌 服务器详情:")
|
lines.append("\n🔌 服务器详情:")
|
||||||
for name, info in status['servers'].items():
|
for name, info in status["servers"].items():
|
||||||
if server_name and name != server_name:
|
if server_name and name != server_name:
|
||||||
continue
|
continue
|
||||||
status_icon = "✅" if info['connected'] else "❌"
|
status_icon = "✅" if info["connected"] else "❌"
|
||||||
enabled_text = "" if info['enabled'] else " (已禁用)"
|
enabled_text = "" if info["enabled"] else " (已禁用)"
|
||||||
lines.append(f" {status_icon} {name}{enabled_text}")
|
lines.append(f" {status_icon} {name}{enabled_text}")
|
||||||
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
|
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
|
||||||
if info['consecutive_failures'] > 0:
|
if info["consecutive_failures"] > 0:
|
||||||
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次")
|
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次")
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
@ -1184,11 +1184,11 @@ class MCPStatusTool(BaseTool):
|
||||||
stats = mcp_manager.get_all_stats()
|
stats = mcp_manager.get_all_stats()
|
||||||
lines = ["📈 调用统计"]
|
lines = ["📈 调用统计"]
|
||||||
|
|
||||||
g = stats['global']
|
g = stats["global"]
|
||||||
lines.append(f" 总调用次数: {g['total_tool_calls']}")
|
lines.append(f" 总调用次数: {g['total_tool_calls']}")
|
||||||
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
|
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
|
||||||
if g['total_tool_calls'] > 0:
|
if g["total_tool_calls"] > 0:
|
||||||
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100
|
success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
|
||||||
lines.append(f" 成功率: {success_rate:.1f}%")
|
lines.append(f" 成功率: {success_rate:.1f}%")
|
||||||
lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒")
|
lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒")
|
||||||
|
|
||||||
|
|
@ -1294,6 +1294,7 @@ class MCPStatusTool(BaseTool):
|
||||||
# 命令处理
|
# 命令处理
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPStatusCommand(BaseCommand):
|
class MCPStatusCommand(BaseCommand):
|
||||||
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
|
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
|
||||||
|
|
||||||
|
|
@ -1644,6 +1645,7 @@ class MCPStatusCommand(BaseCommand):
|
||||||
_plugin_instance._load_tool_chains()
|
_plugin_instance._load_tool_chains()
|
||||||
chains = tool_chain_manager.get_all_chains()
|
chains = tool_chain_manager.get_all_chains()
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
registered = 0
|
registered = 0
|
||||||
for name, chain in tool_chain_manager.get_enabled_chains().items():
|
for name, chain in tool_chain_manager.get_enabled_chains().items():
|
||||||
tool_name = f"chain_{name}".replace("-", "_").replace(".", "_")
|
tool_name = f"chain_{name}".replace("-", "_").replace(".", "_")
|
||||||
|
|
@ -1983,6 +1985,7 @@ class MCPImportCommand(BaseCommand):
|
||||||
# 事件处理器
|
# 事件处理器
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPStartupHandler(BaseEventHandler):
|
class MCPStartupHandler(BaseEventHandler):
|
||||||
"""MCP 启动事件处理器"""
|
"""MCP 启动事件处理器"""
|
||||||
|
|
||||||
|
|
@ -2037,6 +2040,7 @@ class MCPStopHandler(BaseEventHandler):
|
||||||
# 主插件类
|
# 主插件类
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
@register_plugin
|
||||||
class MCPBridgePlugin(BasePlugin):
|
class MCPBridgePlugin(BasePlugin):
|
||||||
"""MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
|
"""MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
|
||||||
|
|
@ -2505,7 +2509,7 @@ class MCPBridgePlugin(BasePlugin):
|
||||||
label="📋 工具链列表",
|
label="📋 工具链列表",
|
||||||
input_type="textarea",
|
input_type="textarea",
|
||||||
rows=20,
|
rows=20,
|
||||||
placeholder='''[
|
placeholder="""[
|
||||||
{
|
{
|
||||||
"name": "search_and_detail",
|
"name": "search_and_detail",
|
||||||
"description": "先搜索再获取详情",
|
"description": "先搜索再获取详情",
|
||||||
|
|
@ -2515,7 +2519,7 @@ class MCPBridgePlugin(BasePlugin):
|
||||||
{"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}}
|
{"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]''',
|
]""",
|
||||||
hint="每个工具链包含 name、description、input_params、steps",
|
hint="每个工具链包含 name、description、input_params、steps",
|
||||||
order=30,
|
order=30,
|
||||||
),
|
),
|
||||||
|
|
@ -2653,9 +2657,9 @@ mcp_bing_*""",
|
||||||
label="📜 高级权限规则(可选)",
|
label="📜 高级权限规则(可选)",
|
||||||
input_type="textarea",
|
input_type="textarea",
|
||||||
rows=10,
|
rows=10,
|
||||||
placeholder='''[
|
placeholder="""[
|
||||||
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
||||||
]''',
|
]""",
|
||||||
hint="格式: qq:ID:group/private/user,工具名支持通配符 *",
|
hint="格式: qq:ID:group/private/user,工具名支持通配符 *",
|
||||||
order=10,
|
order=10,
|
||||||
),
|
),
|
||||||
|
|
@ -2754,7 +2758,9 @@ mcp_bing_*""",
|
||||||
value = match1.group(2)
|
value = match1.group(2)
|
||||||
suffix = match1.group(3)
|
suffix = match1.group(3)
|
||||||
# 将转义的换行符还原为实际换行符
|
# 将转义的换行符还原为实际换行符
|
||||||
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
unescaped = (
|
||||||
|
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
||||||
|
)
|
||||||
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
|
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
|
||||||
fixed_lines.append(fixed_line)
|
fixed_lines.append(fixed_line)
|
||||||
modified = True
|
modified = True
|
||||||
|
|
@ -2948,11 +2954,13 @@ mcp_bing_*""",
|
||||||
logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}")
|
logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}")
|
||||||
args_template = {}
|
args_template = {}
|
||||||
|
|
||||||
steps.append({
|
steps.append(
|
||||||
|
{
|
||||||
"tool_name": tool_name,
|
"tool_name": tool_name,
|
||||||
"args_template": args_template,
|
"args_template": args_template,
|
||||||
"output_key": output_key,
|
"output_key": output_key,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if not steps:
|
if not steps:
|
||||||
logger.warning("快速添加工具链: 没有有效的步骤")
|
logger.warning("快速添加工具链: 没有有效的步骤")
|
||||||
|
|
@ -3056,7 +3064,9 @@ mcp_bing_*""",
|
||||||
if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict):
|
if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict):
|
||||||
self.config["tool_chains"] = {}
|
self.config["tool_chains"] = {}
|
||||||
self.config["tool_chains"]["chains_list"] = chains_json
|
self.config["tool_chains"]["chains_list"] = chains_json
|
||||||
logger.info("检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)")
|
logger.info(
|
||||||
|
"检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)"
|
||||||
|
)
|
||||||
|
|
||||||
chains_config = self.config.get("tool_chains", {})
|
chains_config = self.config.get("tool_chains", {})
|
||||||
if not isinstance(chains_config, dict):
|
if not isinstance(chains_config, dict):
|
||||||
|
|
@ -3153,10 +3163,7 @@ mcp_bing_*""",
|
||||||
|
|
||||||
# 应用过滤器
|
# 应用过滤器
|
||||||
if filter_patterns:
|
if filter_patterns:
|
||||||
matched = any(
|
matched = any(fnmatch.fnmatch(tool_name, p) or tool_name == p for p in filter_patterns)
|
||||||
fnmatch.fnmatch(tool_name, p) or tool_name == p
|
|
||||||
for p in filter_patterns
|
|
||||||
)
|
|
||||||
|
|
||||||
if filter_mode == "whitelist":
|
if filter_mode == "whitelist":
|
||||||
# 白名单模式:只注册匹配的
|
# 白名单模式:只注册匹配的
|
||||||
|
|
@ -3179,6 +3186,7 @@ mcp_bing_*""",
|
||||||
return result.content or "(无返回内容)"
|
return result.content or "(无返回内容)"
|
||||||
else:
|
else:
|
||||||
return f"工具调用失败: {result.error}"
|
return f"工具调用失败: {result.error}"
|
||||||
|
|
||||||
return execute_func
|
return execute_func
|
||||||
|
|
||||||
execute_func = make_execute_func(tool_key)
|
execute_func = make_execute_func(tool_key)
|
||||||
|
|
@ -3207,7 +3215,9 @@ mcp_bing_*""",
|
||||||
|
|
||||||
return registered_count
|
return registered_count
|
||||||
|
|
||||||
def _update_react_status_display(self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]) -> None:
|
def _update_react_status_display(
|
||||||
|
self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]
|
||||||
|
) -> None:
|
||||||
"""更新 ReAct 工具状态显示"""
|
"""更新 ReAct 工具状态显示"""
|
||||||
if not registered_tools:
|
if not registered_tools:
|
||||||
status_text = "(未注册任何工具)"
|
status_text = "(未注册任何工具)"
|
||||||
|
|
@ -3243,12 +3253,14 @@ mcp_bing_*""",
|
||||||
description = param_info.get("description", f"参数 {param_name}")
|
description = param_info.get("description", f"参数 {param_name}")
|
||||||
is_required = param_name in required
|
is_required = param_name in required
|
||||||
|
|
||||||
parameters.append({
|
parameters.append(
|
||||||
|
{
|
||||||
"name": param_name,
|
"name": param_name,
|
||||||
"type": param_type,
|
"type": param_type,
|
||||||
"description": description,
|
"description": description,
|
||||||
"required": is_required,
|
"required": is_required,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
|
|
@ -3295,6 +3307,7 @@ mcp_bing_*""",
|
||||||
async def _async_connect_servers(self) -> None:
|
async def _async_connect_servers(self) -> None:
|
||||||
"""异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)"""
|
"""异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
settings = self.config.get("settings", {})
|
settings = self.config.get("settings", {})
|
||||||
|
|
||||||
servers_config = self._load_mcp_servers_config()
|
servers_config = self._load_mcp_servers_config()
|
||||||
|
|
@ -3380,10 +3393,7 @@ mcp_bing_*""",
|
||||||
|
|
||||||
# 并行执行所有连接
|
# 并行执行所有连接
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
|
||||||
*[connect_single_server(cfg) for cfg in enabled_configs],
|
|
||||||
return_exceptions=True
|
|
||||||
)
|
|
||||||
connect_duration = time.time() - start_time
|
connect_duration = time.time() - start_time
|
||||||
|
|
||||||
# 统计连接结果
|
# 统计连接结果
|
||||||
|
|
@ -3404,15 +3414,14 @@ mcp_bing_*""",
|
||||||
|
|
||||||
# 注册所有工具
|
# 注册所有工具
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
registered_count = 0
|
registered_count = 0
|
||||||
|
|
||||||
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
|
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
|
||||||
tool_name = tool_key.replace("-", "_").replace(".", "_")
|
tool_name = tool_key.replace("-", "_").replace(".", "_")
|
||||||
is_disabled = tool_name in disabled_tools
|
is_disabled = tool_name in disabled_tools
|
||||||
|
|
||||||
info, tool_class = mcp_tool_registry.register_tool(
|
info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
|
||||||
tool_key, tool_info, tool_prefix, disabled=is_disabled
|
|
||||||
)
|
|
||||||
info.plugin_name = self.plugin_name
|
info.plugin_name = self.plugin_name
|
||||||
|
|
||||||
if component_registry.register_component(info, tool_class):
|
if component_registry.register_component(info, tool_class):
|
||||||
|
|
@ -3433,7 +3442,9 @@ mcp_bing_*""",
|
||||||
react_count = self._register_tools_to_react()
|
react_count = self._register_tools_to_react()
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info(f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具")
|
logger.info(
|
||||||
|
f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具"
|
||||||
|
)
|
||||||
|
|
||||||
# 更新状态显示
|
# 更新状态显示
|
||||||
self._update_status_display()
|
self._update_status_display()
|
||||||
|
|
@ -3508,7 +3519,9 @@ mcp_bing_*""",
|
||||||
logger.info("检测到旧版 servers.list,已自动迁移为 Claude mcpServers(请在 WebUI 保存一次以固化)")
|
logger.info("检测到旧版 servers.list,已自动迁移为 Claude mcpServers(请在 WebUI 保存一次以固化)")
|
||||||
|
|
||||||
if not claude_json.strip():
|
if not claude_json.strip():
|
||||||
self._last_servers_config_error = "未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)"
|
self._last_servers_config_error = (
|
||||||
|
"未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -22,15 +22,18 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("mcp_tool_chain")
|
logger = get_logger("mcp_tool_chain")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger("mcp_tool_chain")
|
logger = logging.getLogger("mcp_tool_chain")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolChainStep:
|
class ToolChainStep:
|
||||||
"""工具链步骤"""
|
"""工具链步骤"""
|
||||||
|
|
||||||
tool_name: str # 要调用的工具名(如 mcp_server_tool)
|
tool_name: str # 要调用的工具名(如 mcp_server_tool)
|
||||||
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
|
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
|
||||||
output_key: str = "" # 输出存储的键名,供后续步骤引用
|
output_key: str = "" # 输出存储的键名,供后续步骤引用
|
||||||
|
|
@ -60,6 +63,7 @@ class ToolChainStep:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolChainDefinition:
|
class ToolChainDefinition:
|
||||||
"""工具链定义"""
|
"""工具链定义"""
|
||||||
|
|
||||||
name: str # 工具链名称(将作为组合工具的名称)
|
name: str # 工具链名称(将作为组合工具的名称)
|
||||||
description: str # 工具链描述(供 LLM 理解)
|
description: str # 工具链描述(供 LLM 理解)
|
||||||
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
|
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
|
||||||
|
|
@ -90,6 +94,7 @@ class ToolChainDefinition:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChainExecutionResult:
|
class ChainExecutionResult:
|
||||||
"""工具链执行结果"""
|
"""工具链执行结果"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
final_output: str # 最终输出(最后一个步骤的结果)
|
final_output: str # 最终输出(最后一个步骤的结果)
|
||||||
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
|
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
|
||||||
|
|
@ -113,7 +118,7 @@ class ToolChainExecutor:
|
||||||
"""工具链执行器"""
|
"""工具链执行器"""
|
||||||
|
|
||||||
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
|
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
|
||||||
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
|
VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
|
||||||
|
|
||||||
def __init__(self, mcp_manager):
|
def __init__(self, mcp_manager):
|
||||||
self._mcp_manager = mcp_manager
|
self._mcp_manager = mcp_manager
|
||||||
|
|
@ -295,10 +300,7 @@ class ToolChainExecutor:
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
resolved[key] = self._resolve_args(value, context)
|
resolved[key] = self._resolve_args(value, context)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
resolved[key] = [
|
resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
|
||||||
self._substitute_vars(v, context) if isinstance(v, str) else v
|
|
||||||
for v in value
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
resolved[key] = value
|
resolved[key] = value
|
||||||
|
|
||||||
|
|
@ -306,6 +308,7 @@ class ToolChainExecutor:
|
||||||
|
|
||||||
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
|
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
|
||||||
"""替换字符串中的变量"""
|
"""替换字符串中的变量"""
|
||||||
|
|
||||||
def replacer(match):
|
def replacer(match):
|
||||||
var_path = match.group(1)
|
var_path = match.group(1)
|
||||||
return self._get_var_value(var_path, context)
|
return self._get_var_value(var_path, context)
|
||||||
|
|
|
||||||
|
|
@ -238,7 +238,7 @@ class TestCommand(BaseCommand):
|
||||||
chat_stream=self.message.chat_stream,
|
chat_stream=self.message.chat_stream,
|
||||||
reply_reason=reply_reason,
|
reply_reason=reply_reason,
|
||||||
enable_chinese_typo=False,
|
enable_chinese_typo=False,
|
||||||
extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"",
|
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
|
||||||
)
|
)
|
||||||
if result_status:
|
if result_status:
|
||||||
# 发送生成的回复
|
# 发送生成的回复
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ def patch_attrdoc_post_init():
|
||||||
|
|
||||||
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
||||||
|
|
||||||
|
|
||||||
class SimpleClass(ConfigBase):
|
class SimpleClass(ConfigBase):
|
||||||
a: int = 1
|
a: int = 1
|
||||||
b: str = "test"
|
b: str = "test"
|
||||||
|
|
@ -282,7 +283,7 @@ class TestConfigBase:
|
||||||
True,
|
True,
|
||||||
"ConfigBase is not Hashable",
|
"ConfigBase is not Hashable",
|
||||||
id="listset-validation-set-configbase-element_reject",
|
id="listset-validation-set-configbase-element_reject",
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
||||||
|
|
@ -340,7 +341,7 @@ class TestConfigBase:
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
id="dict-validation-happy-configbase-value",
|
id="dict-validation-happy-configbase-value",
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
||||||
|
|
@ -353,13 +354,11 @@ class TestConfigBase:
|
||||||
field_name = "mapping"
|
field_name = "mapping"
|
||||||
|
|
||||||
if expect_error:
|
if expect_error:
|
||||||
|
|
||||||
# Act / Assert
|
# Act / Assert
|
||||||
with pytest.raises(TypeError) as exc_info:
|
with pytest.raises(TypeError) as exc_info:
|
||||||
dummy._validate_dict_type(annotation, field_name)
|
dummy._validate_dict_type(annotation, field_name)
|
||||||
assert error_fragment in str(exc_info.value)
|
assert error_fragment in str(exc_info.value)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
dummy._validate_dict_type(annotation, field_name)
|
dummy._validate_dict_type(annotation, field_name)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import importlib
|
||||||
import pytest
|
import pytest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
class DummyLogger:
|
class DummyLogger:
|
||||||
|
|
@ -71,6 +70,7 @@ class DummyLLMRequest:
|
||||||
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
|
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
|
||||||
return ("dummy description", {})
|
return ("dummy description", {})
|
||||||
|
|
||||||
|
|
||||||
class DummySelect:
|
class DummySelect:
|
||||||
def __init__(self, *a, **k):
|
def __init__(self, *a, **k):
|
||||||
pass
|
pass
|
||||||
|
|
@ -81,6 +81,7 @@ class DummySelect:
|
||||||
def limit(self, n):
|
def limit(self, n):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def patch_external_dependencies(monkeypatch):
|
def patch_external_dependencies(monkeypatch):
|
||||||
# Provide dummy implementations as modules so that importing image_manager is safe
|
# Provide dummy implementations as modules so that importing image_manager is safe
|
||||||
|
|
@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None):
|
||||||
if tmp_path is not None:
|
if tmp_path is not None:
|
||||||
tmpdir = Path(tmp_path)
|
tmpdir = Path(tmp_path)
|
||||||
tmpdir.mkdir(parents=True, exist_ok=True)
|
tmpdir.mkdir(parents=True, exist_ok=True)
|
||||||
setattr(mod, "IMAGE_DIR", tmpdir)
|
mod.IMAGE_DIR = tmpdir
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return mod
|
return mod
|
||||||
|
|
@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
|
||||||
|
|
||||||
# cleanup should run without error
|
# cleanup should run without error
|
||||||
mgr.cleanup_invalid_descriptions_in_db()
|
mgr.cleanup_invalid_descriptions_in_db()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.config.official_configs import ChatConfig
|
from src.config.official_configs import ChatConfig
|
||||||
from src.config.config import Config
|
from src.config.config import Config
|
||||||
from src.webui.config_schema import ConfigSchemaGenerator
|
from src.webui.config_schema import ConfigSchemaGenerator
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
"""Expression routes pytest tests"""
|
"""Expression routes pytest tests"""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
@ -12,7 +11,6 @@ from sqlalchemy import text
|
||||||
from sqlmodel import Session, SQLModel, create_engine, select
|
from sqlmodel import Session, SQLModel, create_engine, select
|
||||||
|
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.common.database.database import get_db_session
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_app() -> FastAPI:
|
def create_test_app() -> FastAPI:
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,7 @@ def analyze_single_file(file_path: str) -> Dict:
|
||||||
stats["date_range"] = {
|
stats["date_range"] = {
|
||||||
"start": min_date.isoformat(),
|
"start": min_date.isoformat(),
|
||||||
"end": max_date.isoformat(),
|
"end": max_date.isoformat(),
|
||||||
"duration_days": (max_date - min_date).days + 1
|
"duration_days": (max_date - min_date).days + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查字段存在性
|
# 检查字段存在性
|
||||||
|
|
@ -151,8 +151,8 @@ def print_file_stats(stats: Dict, index: int = None):
|
||||||
print(f" 文件中的 total_count: {stats['total_count']}")
|
print(f" 文件中的 total_count: {stats['total_count']}")
|
||||||
print(f" 实际记录数: {stats['actual_count']}")
|
print(f" 实际记录数: {stats['actual_count']}")
|
||||||
|
|
||||||
if stats['total_count'] != stats['actual_count']:
|
if stats["total_count"] != stats["actual_count"]:
|
||||||
diff = stats['total_count'] - stats['actual_count']
|
diff = stats["total_count"] - stats["actual_count"]
|
||||||
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
|
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
|
||||||
|
|
||||||
print("\n【评估结果统计】")
|
print("\n【评估结果统计】")
|
||||||
|
|
@ -161,21 +161,21 @@ def print_file_stats(stats: Dict, index: int = None):
|
||||||
|
|
||||||
print("\n【唯一性统计】")
|
print("\n【唯一性统计】")
|
||||||
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']} 条")
|
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']} 条")
|
||||||
if stats['actual_count'] > 0:
|
if stats["actual_count"] > 0:
|
||||||
duplicate_count = stats['actual_count'] - stats['unique_pairs']
|
duplicate_count = stats["actual_count"] - stats["unique_pairs"]
|
||||||
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||||
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
|
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
|
||||||
|
|
||||||
print("\n【评估者统计】")
|
print("\n【评估者统计】")
|
||||||
if stats['evaluators']:
|
if stats["evaluators"]:
|
||||||
for evaluator, count in stats['evaluators'].most_common():
|
for evaluator, count in stats["evaluators"].most_common():
|
||||||
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||||
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
|
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
|
||||||
else:
|
else:
|
||||||
print(" 无评估者信息")
|
print(" 无评估者信息")
|
||||||
|
|
||||||
print("\n【时间统计】")
|
print("\n【时间统计】")
|
||||||
if stats['date_range']:
|
if stats["date_range"]:
|
||||||
print(f" 最早评估时间: {stats['date_range']['start']}")
|
print(f" 最早评估时间: {stats['date_range']['start']}")
|
||||||
print(f" 最晚评估时间: {stats['date_range']['end']}")
|
print(f" 最晚评估时间: {stats['date_range']['end']}")
|
||||||
print(f" 评估时间跨度: {stats['date_range']['duration_days']} 天")
|
print(f" 评估时间跨度: {stats['date_range']['duration_days']} 天")
|
||||||
|
|
@ -185,8 +185,8 @@ def print_file_stats(stats: Dict, index: int = None):
|
||||||
print("\n【字段统计】")
|
print("\n【字段统计】")
|
||||||
print(f" 包含 expression_id: {'是' if stats['has_expression_id'] else '否'}")
|
print(f" 包含 expression_id: {'是' if stats['has_expression_id'] else '否'}")
|
||||||
print(f" 包含 reason: {'是' if stats['has_reason'] else '否'}")
|
print(f" 包含 reason: {'是' if stats['has_reason'] else '否'}")
|
||||||
if stats['has_reason']:
|
if stats["has_reason"]:
|
||||||
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||||
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
|
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -215,15 +215,15 @@ def print_summary(all_stats: List[Dict]):
|
||||||
return
|
return
|
||||||
|
|
||||||
# 汇总记录统计
|
# 汇总记录统计
|
||||||
total_records = sum(s['actual_count'] for s in valid_files)
|
total_records = sum(s["actual_count"] for s in valid_files)
|
||||||
total_suitable = sum(s['suitable_count'] for s in valid_files)
|
total_suitable = sum(s["suitable_count"] for s in valid_files)
|
||||||
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
|
total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
|
||||||
total_unique_pairs = set()
|
total_unique_pairs = set()
|
||||||
|
|
||||||
# 收集所有唯一的(situation, style)对
|
# 收集所有唯一的(situation, style)对
|
||||||
for stats in valid_files:
|
for stats in valid_files:
|
||||||
try:
|
try:
|
||||||
with open(stats['file_path'], "r", encoding="utf-8") as f:
|
with open(stats["file_path"], "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
results = data.get("manual_results", [])
|
results = data.get("manual_results", [])
|
||||||
for r in results:
|
for r in results:
|
||||||
|
|
@ -234,8 +234,16 @@ def print_summary(all_stats: List[Dict]):
|
||||||
|
|
||||||
print("\n【记录汇总】")
|
print("\n【记录汇总】")
|
||||||
print(f" 总记录数: {total_records:,} 条")
|
print(f" 总记录数: {total_records:,} 条")
|
||||||
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
|
print(
|
||||||
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
|
f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
|
||||||
|
if total_records > 0
|
||||||
|
else " 通过: 0 条"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
|
||||||
|
if total_records > 0
|
||||||
|
else " 不通过: 0 条"
|
||||||
|
)
|
||||||
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,} 条")
|
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,} 条")
|
||||||
|
|
||||||
if total_records > 0:
|
if total_records > 0:
|
||||||
|
|
@ -246,7 +254,7 @@ def print_summary(all_stats: List[Dict]):
|
||||||
# 汇总评估者统计
|
# 汇总评估者统计
|
||||||
all_evaluators = Counter()
|
all_evaluators = Counter()
|
||||||
for stats in valid_files:
|
for stats in valid_files:
|
||||||
all_evaluators.update(stats['evaluators'])
|
all_evaluators.update(stats["evaluators"])
|
||||||
|
|
||||||
print("\n【评估者汇总】")
|
print("\n【评估者汇总】")
|
||||||
if all_evaluators:
|
if all_evaluators:
|
||||||
|
|
@ -259,7 +267,7 @@ def print_summary(all_stats: List[Dict]):
|
||||||
# 汇总时间范围
|
# 汇总时间范围
|
||||||
all_dates = []
|
all_dates = []
|
||||||
for stats in valid_files:
|
for stats in valid_files:
|
||||||
all_dates.extend(stats['evaluation_dates'])
|
all_dates.extend(stats["evaluation_dates"])
|
||||||
|
|
||||||
if all_dates:
|
if all_dates:
|
||||||
min_date = min(all_dates)
|
min_date = min(all_dates)
|
||||||
|
|
@ -270,7 +278,7 @@ def print_summary(all_stats: List[Dict]):
|
||||||
print(f" 总时间跨度: {(max_date - min_date).days + 1} 天")
|
print(f" 总时间跨度: {(max_date - min_date).days + 1} 天")
|
||||||
|
|
||||||
# 文件大小汇总
|
# 文件大小汇总
|
||||||
total_size = sum(s['file_size'] for s in valid_files)
|
total_size = sum(s["file_size"] for s in valid_files)
|
||||||
avg_size = total_size / len(valid_files) if valid_files else 0
|
avg_size = total_size / len(valid_files) if valid_files else 0
|
||||||
print("\n【文件大小汇总】")
|
print("\n【文件大小汇总】")
|
||||||
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
|
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
|
||||||
|
|
@ -318,5 +326,3 @@ def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,9 @@ def main():
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if not args.raw_index:
|
if not args.raw_index:
|
||||||
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
|
logger.info(
|
||||||
|
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 解析索引列表(1-based)
|
# 解析索引列表(1-based)
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ def save_results(evaluation_results: List[Dict]):
|
||||||
data = {
|
data = {
|
||||||
"last_updated": datetime.now().isoformat(),
|
"last_updated": datetime.now().isoformat(),
|
||||||
"total_count": len(evaluation_results),
|
"total_count": len(evaluation_results),
|
||||||
"evaluation_results": evaluation_results
|
"evaluation_results": evaluation_results,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
|
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
|
||||||
|
|
@ -84,9 +84,7 @@ def save_results(evaluation_results: List[Dict]):
|
||||||
print(f"\n✗ 保存评估结果失败: {e}")
|
print(f"\n✗ 保存评估结果失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
def select_expressions_for_evaluation(
|
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
|
||||||
evaluated_pairs: Set[Tuple[str, str]] = None
|
|
||||||
) -> List[Expression]:
|
|
||||||
"""
|
"""
|
||||||
选择用于评估的表达方式
|
选择用于评估的表达方式
|
||||||
选择所有count>1的项目,然后选择两倍数量的count=1的项目
|
选择所有count>1的项目,然后选择两倍数量的count=1的项目
|
||||||
|
|
@ -109,10 +107,7 @@ def select_expressions_for_evaluation(
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 过滤出未评估的项目
|
# 过滤出未评估的项目
|
||||||
unevaluated = [
|
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
|
||||||
expr for expr in all_expressions
|
|
||||||
if (expr.situation, expr.style) not in evaluated_pairs
|
|
||||||
]
|
|
||||||
|
|
||||||
if not unevaluated:
|
if not unevaluated:
|
||||||
logger.warning("所有项目都已评估完成")
|
logger.warning("所有项目都已评估完成")
|
||||||
|
|
@ -132,7 +127,9 @@ def select_expressions_for_evaluation(
|
||||||
count_eq1_needed = count_gt1_count * 2
|
count_eq1_needed = count_gt1_count * 2
|
||||||
|
|
||||||
if len(count_eq1) < count_eq1_needed:
|
if len(count_eq1) < count_eq1_needed:
|
||||||
logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条")
|
logger.warning(
|
||||||
|
f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条"
|
||||||
|
)
|
||||||
count_eq1_needed = len(count_eq1)
|
count_eq1_needed = len(count_eq1)
|
||||||
|
|
||||||
# 随机选择count=1的项目
|
# 随机选择count=1的项目
|
||||||
|
|
@ -141,13 +138,16 @@ def select_expressions_for_evaluation(
|
||||||
selected = selected_count_gt1 + selected_count_eq1
|
selected = selected_count_gt1 + selected_count_eq1
|
||||||
random.shuffle(selected) # 打乱顺序
|
random.shuffle(selected) # 打乱顺序
|
||||||
|
|
||||||
logger.info(f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)")
|
logger.info(
|
||||||
|
f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)"
|
||||||
|
)
|
||||||
|
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"选择表达方式失败: {e}")
|
logger.error(f"选择表达方式失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -202,9 +202,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||||
|
|
||||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||||
prompt=prompt,
|
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=1024
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"LLM响应: {response}")
|
logger.debug(f"LLM响应: {response}")
|
||||||
|
|
@ -241,7 +239,9 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
|
||||||
Returns:
|
Returns:
|
||||||
评估结果字典
|
评估结果字典
|
||||||
"""
|
"""
|
||||||
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
|
logger.info(
|
||||||
|
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
|
||||||
|
)
|
||||||
|
|
||||||
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
|
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
|
||||||
|
|
||||||
|
|
@ -258,7 +258,7 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"error": error,
|
"error": error,
|
||||||
"evaluator": "llm",
|
"evaluator": "llm",
|
||||||
"evaluated_at": datetime.now().isoformat()
|
"evaluated_at": datetime.now().isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -427,7 +427,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||||
"count_groups": {str(k): v for k, v in count_groups.items()},
|
"count_groups": {str(k): v for k, v in count_groups.items()},
|
||||||
"count_eq1": count_eq1_group,
|
"count_eq1": count_eq1_group,
|
||||||
"count_gt1": count_gt1_group,
|
"count_gt1": count_gt1_group,
|
||||||
"total_evaluated": len(evaluation_results)
|
"total_evaluated": len(evaluation_results),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -466,8 +466,7 @@ async def main():
|
||||||
try:
|
try:
|
||||||
all_expressions = list(Expression.select())
|
all_expressions = list(Expression.select())
|
||||||
unevaluated_count_gt1 = [
|
unevaluated_count_gt1 = [
|
||||||
expr for expr in all_expressions
|
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||||
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
|
||||||
]
|
]
|
||||||
has_unevaluated = len(unevaluated_count_gt1) > 0
|
has_unevaluated = len(unevaluated_count_gt1) > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -485,21 +484,20 @@ async def main():
|
||||||
try:
|
try:
|
||||||
llm = LLMRequest(
|
llm = LLMRequest(
|
||||||
model_set=model_config.model_task_config.tool_use,
|
model_set=model_config.model_task_config.tool_use,
|
||||||
request_type="expression_evaluator_count_analysis_llm"
|
request_type="expression_evaluator_count_analysis_llm",
|
||||||
)
|
)
|
||||||
print("✓ LLM实例创建成功\n")
|
print("✓ LLM实例创建成功\n")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建LLM实例失败: {e}")
|
logger.error(f"创建LLM实例失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
print(f"\n✗ 创建LLM实例失败: {e}")
|
print(f"\n✗ 创建LLM实例失败: {e}")
|
||||||
db.close()
|
db.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
# 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目)
|
# 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目)
|
||||||
expressions = select_expressions_for_evaluation(
|
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
|
||||||
evaluated_pairs=evaluated_pairs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not expressions:
|
if not expressions:
|
||||||
print("\n没有可评估的项目")
|
print("\n没有可评估的项目")
|
||||||
|
|
@ -518,7 +516,7 @@ async def main():
|
||||||
llm_result = await llm_evaluate_expression(expression, llm)
|
llm_result = await llm_evaluate_expression(expression, llm)
|
||||||
|
|
||||||
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
|
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
|
||||||
if llm_result.get('error'):
|
if llm_result.get("error"):
|
||||||
print(f" 错误: {llm_result['error']}")
|
print(f" 错误: {llm_result['error']}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
@ -553,4 +551,3 @@ async def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -140,9 +140,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||||
|
|
||||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||||
prompt=prompt,
|
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=1024
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"LLM响应: {response}")
|
logger.debug(f"LLM响应: {response}")
|
||||||
|
|
@ -152,6 +150,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||||
evaluation = json.loads(response)
|
evaluation = json.loads(response)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
import re
|
import re
|
||||||
|
|
||||||
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
||||||
if json_match:
|
if json_match:
|
||||||
evaluation = json.loads(json_match.group())
|
evaluation = json.loads(json_match.group())
|
||||||
|
|
@ -196,7 +195,7 @@ async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -
|
||||||
"suitable": suitable,
|
"suitable": suitable,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"error": error,
|
"error": error,
|
||||||
"evaluator": "llm"
|
"evaluator": "llm",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -244,10 +243,16 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
||||||
false_negatives += 1
|
false_negatives += 1
|
||||||
|
|
||||||
accuracy = (matched / total * 100) if total > 0 else 0
|
accuracy = (matched / total * 100) if total > 0 else 0
|
||||||
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
precision = (
|
||||||
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
(true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||||
|
)
|
||||||
|
recall = (
|
||||||
|
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||||
|
)
|
||||||
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
|
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
|
||||||
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
specificity = (
|
||||||
|
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
# 计算人工效标的不合适率
|
# 计算人工效标的不合适率
|
||||||
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
|
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
|
||||||
|
|
@ -283,7 +288,7 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
||||||
"specificity": specificity,
|
"specificity": specificity,
|
||||||
"manual_unsuitable_rate": manual_unsuitable_rate,
|
"manual_unsuitable_rate": manual_unsuitable_rate,
|
||||||
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
|
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
|
||||||
"rate_difference": rate_difference
|
"rate_difference": rate_difference,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -334,13 +339,11 @@ async def main(count: int | None = None):
|
||||||
# 2. 创建LLM实例并评估
|
# 2. 创建LLM实例并评估
|
||||||
print("\n步骤2: 创建LLM实例")
|
print("\n步骤2: 创建LLM实例")
|
||||||
try:
|
try:
|
||||||
llm = LLMRequest(
|
llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
|
||||||
model_set=model_config.model_task_config.tool_use,
|
|
||||||
request_type="expression_evaluator_llm"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建LLM实例失败: {e}")
|
logger.error(f"创建LLM实例失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -348,11 +351,7 @@ async def main(count: int | None = None):
|
||||||
llm_results = []
|
llm_results = []
|
||||||
for i, manual_result in enumerate(valid_manual_results, 1):
|
for i, manual_result in enumerate(valid_manual_results, 1):
|
||||||
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
|
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
|
||||||
llm_results.append(await evaluate_expression_llm(
|
llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
|
||||||
manual_result["situation"],
|
|
||||||
manual_result["style"],
|
|
||||||
llm
|
|
||||||
))
|
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
# 5. 输出FP和FN项目(在评估结果之前)
|
# 5. 输出FP和FN项目(在评估结果之前)
|
||||||
|
|
@ -372,14 +371,16 @@ async def main(count: int | None = None):
|
||||||
|
|
||||||
# 人工评估不通过,但LLM评估通过(FP情况)
|
# 人工评估不通过,但LLM评估通过(FP情况)
|
||||||
if not manual_result["suitable"] and llm_result["suitable"]:
|
if not manual_result["suitable"] and llm_result["suitable"]:
|
||||||
fp_items.append({
|
fp_items.append(
|
||||||
|
{
|
||||||
"situation": manual_result["situation"],
|
"situation": manual_result["situation"],
|
||||||
"style": manual_result["style"],
|
"style": manual_result["style"],
|
||||||
"manual_suitable": manual_result["suitable"],
|
"manual_suitable": manual_result["suitable"],
|
||||||
"llm_suitable": llm_result["suitable"],
|
"llm_suitable": llm_result["suitable"],
|
||||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||||
"llm_error": llm_result.get("error")
|
"llm_error": llm_result.get("error"),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if fp_items:
|
if fp_items:
|
||||||
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
|
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
|
||||||
|
|
@ -389,7 +390,7 @@ async def main(count: int | None = None):
|
||||||
print(f"Style: {item['style']}")
|
print(f"Style: {item['style']}")
|
||||||
print("人工评估: 不通过 ❌")
|
print("人工评估: 不通过 ❌")
|
||||||
print("LLM评估: 通过 ✅ (误判)")
|
print("LLM评估: 通过 ✅ (误判)")
|
||||||
if item.get('llm_error'):
|
if item.get("llm_error"):
|
||||||
print(f"LLM错误: {item['llm_error']}")
|
print(f"LLM错误: {item['llm_error']}")
|
||||||
print(f"LLM理由: {item['llm_reason']}")
|
print(f"LLM理由: {item['llm_reason']}")
|
||||||
print()
|
print()
|
||||||
|
|
@ -410,14 +411,16 @@ async def main(count: int | None = None):
|
||||||
|
|
||||||
# 人工评估通过,但LLM评估不通过(FN情况)
|
# 人工评估通过,但LLM评估不通过(FN情况)
|
||||||
if manual_result["suitable"] and not llm_result["suitable"]:
|
if manual_result["suitable"] and not llm_result["suitable"]:
|
||||||
fn_items.append({
|
fn_items.append(
|
||||||
|
{
|
||||||
"situation": manual_result["situation"],
|
"situation": manual_result["situation"],
|
||||||
"style": manual_result["style"],
|
"style": manual_result["style"],
|
||||||
"manual_suitable": manual_result["suitable"],
|
"manual_suitable": manual_result["suitable"],
|
||||||
"llm_suitable": llm_result["suitable"],
|
"llm_suitable": llm_result["suitable"],
|
||||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||||
"llm_error": llm_result.get("error")
|
"llm_error": llm_result.get("error"),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if fn_items:
|
if fn_items:
|
||||||
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
|
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
|
||||||
|
|
@ -427,7 +430,7 @@ async def main(count: int | None = None):
|
||||||
print(f"Style: {item['style']}")
|
print(f"Style: {item['style']}")
|
||||||
print("人工评估: 通过 ✅")
|
print("人工评估: 通过 ✅")
|
||||||
print("LLM评估: 不通过 ❌ (误删)")
|
print("LLM评估: 不通过 ❌ (误删)")
|
||||||
if item.get('llm_error'):
|
if item.get("llm_error"):
|
||||||
print(f"LLM错误: {item['llm_error']}")
|
print(f"LLM错误: {item['llm_error']}")
|
||||||
print(f"LLM理由: {item['llm_reason']}")
|
print(f"LLM理由: {item['llm_reason']}")
|
||||||
print()
|
print()
|
||||||
|
|
@ -447,13 +450,21 @@ async def main(count: int | None = None):
|
||||||
print()
|
print()
|
||||||
# print(" 【核心能力指标】")
|
# print(" 【核心能力指标】")
|
||||||
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
|
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
|
||||||
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
|
print(
|
||||||
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
|
f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个"
|
||||||
|
)
|
||||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
|
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
|
||||||
print()
|
print()
|
||||||
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
|
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
|
||||||
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
|
print(
|
||||||
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
|
f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个"
|
||||||
|
)
|
||||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
|
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
|
||||||
print()
|
print()
|
||||||
print(" 【其他指标】")
|
print(" 【其他指标】")
|
||||||
|
|
@ -464,12 +475,18 @@ async def main(count: int | None = None):
|
||||||
print()
|
print()
|
||||||
print(" 【不合适率分析】")
|
print(" 【不合适率分析】")
|
||||||
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
|
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
|
||||||
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
|
print(
|
||||||
|
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
|
||||||
|
)
|
||||||
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
|
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
|
||||||
print()
|
print()
|
||||||
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||||
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
|
print(
|
||||||
print(f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
|
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
|
||||||
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||||
|
|
@ -486,11 +503,12 @@ async def main(count: int | None = None):
|
||||||
try:
|
try:
|
||||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
json.dump({
|
json.dump(
|
||||||
"manual_results": valid_manual_results,
|
{"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
|
||||||
"llm_results": llm_results,
|
f,
|
||||||
"comparison": comparison
|
ensure_ascii=False,
|
||||||
}, f, ensure_ascii=False, indent=2)
|
indent=2,
|
||||||
|
)
|
||||||
logger.info(f"\n评估结果已保存到: {output_file}")
|
logger.info(f"\n评估结果已保存到: {output_file}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"保存结果到文件失败: {e}")
|
logger.warning(f"保存结果到文件失败: {e}")
|
||||||
|
|
@ -509,15 +527,9 @@ if __name__ == "__main__":
|
||||||
python evaluate_expressions_llm_v6.py # 使用全部数据
|
python evaluate_expressions_llm_v6.py # 使用全部数据
|
||||||
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
|
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
|
||||||
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
|
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
|
||||||
"""
|
""",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-n", "--count",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="随机选取的数据条数(默认:使用全部数据)"
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
asyncio.run(main(count=args.count))
|
asyncio.run(main(count=args.count))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ def save_results(manual_results: List[Dict]):
|
||||||
data = {
|
data = {
|
||||||
"last_updated": datetime.now().isoformat(),
|
"last_updated": datetime.now().isoformat(),
|
||||||
"total_count": len(manual_results),
|
"total_count": len(manual_results),
|
||||||
"manual_results": manual_results
|
"manual_results": manual_results,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
|
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
|
||||||
|
|
@ -98,10 +98,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 过滤出未评估的项目:匹配 situation 和 style 均一致
|
# 过滤出未评估的项目:匹配 situation 和 style 均一致
|
||||||
unevaluated = [
|
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
|
||||||
expr for expr in all_expressions
|
|
||||||
if (expr.situation, expr.style) not in evaluated_pairs
|
|
||||||
]
|
|
||||||
|
|
||||||
if not unevaluated:
|
if not unevaluated:
|
||||||
logger.info("所有项目都已评估完成")
|
logger.info("所有项目都已评估完成")
|
||||||
|
|
@ -120,6 +117,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取未评估表达方式失败: {e}")
|
logger.error(f"获取未评估表达方式失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -150,18 +148,18 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
|
||||||
while True:
|
while True:
|
||||||
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
|
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
|
||||||
|
|
||||||
if user_input in ['q', 'quit']:
|
if user_input in ["q", "quit"]:
|
||||||
print("退出评估")
|
print("退出评估")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if user_input in ['s', 'skip']:
|
if user_input in ["s", "skip"]:
|
||||||
print("跳过当前项目")
|
print("跳过当前项目")
|
||||||
return "skip"
|
return "skip"
|
||||||
|
|
||||||
if user_input in ['y', 'yes', '1', '是', '通过']:
|
if user_input in ["y", "yes", "1", "是", "通过"]:
|
||||||
suitable = True
|
suitable = True
|
||||||
break
|
break
|
||||||
elif user_input in ['n', 'no', '0', '否', '不通过']:
|
elif user_input in ["n", "no", "0", "否", "不通过"]:
|
||||||
suitable = False
|
suitable = False
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
@ -173,7 +171,7 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
|
||||||
"suitable": suitable,
|
"suitable": suitable,
|
||||||
"reason": None,
|
"reason": None,
|
||||||
"evaluator": "manual",
|
"evaluator": "manual",
|
||||||
"evaluated_at": datetime.now().isoformat()
|
"evaluated_at": datetime.now().isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
|
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
|
||||||
|
|
@ -257,9 +255,9 @@ def main():
|
||||||
# 询问是否继续
|
# 询问是否继续
|
||||||
while True:
|
while True:
|
||||||
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
|
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
|
||||||
if continue_input in ['y', 'yes', '1', '是', '继续']:
|
if continue_input in ["y", "yes", "1", "是", "继续"]:
|
||||||
break
|
break
|
||||||
elif continue_input in ['n', 'no', '0', '否', '退出']:
|
elif continue_input in ["n", "no", "0", "否", "退出"]:
|
||||||
print("\n评估结束")
|
print("\n评估结束")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
|
@ -275,4 +273,3 @@ def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -134,9 +134,7 @@ def handle_import_openie(
|
||||||
# 在非交互模式下,不再询问用户,而是直接报错终止
|
# 在非交互模式下,不再询问用户,而是直接报错终止
|
||||||
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
||||||
if non_interactive:
|
if non_interactive:
|
||||||
logger.error(
|
logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
|
||||||
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
||||||
user_choice = input().strip().lower()
|
user_choice = input().strip().lower()
|
||||||
|
|
@ -189,9 +187,7 @@ def handle_import_openie(
|
||||||
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
|
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
|
||||||
# 新增确认提示
|
# 新增确认提示
|
||||||
if non_interactive:
|
if non_interactive:
|
||||||
logger.warning(
|
logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
|
||||||
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print("=== 重要操作确认 ===")
|
print("=== 重要操作确认 ===")
|
||||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||||
|
|
@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
|
||||||
def main(argv: Optional[list[str]] = None) -> None:
|
def main(argv: Optional[list[str]] = None) -> None:
|
||||||
"""主函数 - 解析参数并运行异步主流程。"""
|
"""主函数 - 解析参数并运行异步主流程。"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description=(
|
description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
|
||||||
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
|
|
||||||
"将其导入到 LPMM 的向量库与知识图中。"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--non-interactive",
|
"--non-interactive",
|
||||||
|
|
|
||||||
|
|
@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
|
||||||
ensure_dirs() # 确保目录存在
|
ensure_dirs() # 确保目录存在
|
||||||
# 新增用户确认提示
|
# 新增用户确认提示
|
||||||
if non_interactive:
|
if non_interactive:
|
||||||
logger.warning(
|
logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
|
||||||
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||||
|
|
|
||||||
|
|
@ -68,4 +68,3 @@ def main() -> None:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ except ImportError as e:
|
||||||
|
|
||||||
logger = get_logger("lpmm_interactive_manager")
|
logger = get_logger("lpmm_interactive_manager")
|
||||||
|
|
||||||
|
|
||||||
async def interactive_add():
|
async def interactive_add():
|
||||||
"""交互式导入知识"""
|
"""交互式导入知识"""
|
||||||
print("\n" + "=" * 40)
|
print("\n" + "=" * 40)
|
||||||
|
|
@ -68,6 +69,7 @@ async def interactive_add():
|
||||||
print(f"\n[×] 发生异常: {e}")
|
print(f"\n[×] 发生异常: {e}")
|
||||||
logger.error(f"add_content 异常: {e}", exc_info=True)
|
logger.error(f"add_content 异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def interactive_delete():
|
async def interactive_delete():
|
||||||
"""交互式删除知识"""
|
"""交互式删除知识"""
|
||||||
print("\n" + "=" * 40)
|
print("\n" + "=" * 40)
|
||||||
|
|
@ -108,8 +110,12 @@ async def interactive_delete():
|
||||||
return
|
return
|
||||||
|
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower()
|
confirm = (
|
||||||
if confirm != 'y':
|
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
|
||||||
|
.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
|
if confirm != "y":
|
||||||
print("\n[!] 已取消删除操作。")
|
print("\n[!] 已取消删除操作。")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -129,6 +135,7 @@ async def interactive_delete():
|
||||||
print(f"\n[×] 发生异常: {e}")
|
print(f"\n[×] 发生异常: {e}")
|
||||||
logger.error(f"delete 异常: {e}", exc_info=True)
|
logger.error(f"delete 异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def interactive_clear():
|
async def interactive_clear():
|
||||||
"""交互式清空知识库"""
|
"""交互式清空知识库"""
|
||||||
print("\n" + "=" * 40)
|
print("\n" + "=" * 40)
|
||||||
|
|
@ -165,16 +172,21 @@ async def interactive_clear():
|
||||||
before = stats.get("before", {})
|
before = stats.get("before", {})
|
||||||
after = stats.get("after", {})
|
after = stats.get("after", {})
|
||||||
print("\n[统计信息]")
|
print("\n[统计信息]")
|
||||||
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
print(
|
||||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}")
|
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||||
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
|
||||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}")
|
)
|
||||||
|
print(
|
||||||
|
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||||
|
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"\n[×] 失败:{result['message']}")
|
print(f"\n[×] 失败:{result['message']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n[×] 发生异常: {e}")
|
print(f"\n[×] 发生异常: {e}")
|
||||||
logger.error(f"clear_all 异常: {e}", exc_info=True)
|
logger.error(f"clear_all 异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def interactive_search():
|
async def interactive_search():
|
||||||
"""交互式查询知识"""
|
"""交互式查询知识"""
|
||||||
print("\n" + "=" * 40)
|
print("\n" + "=" * 40)
|
||||||
|
|
@ -224,6 +236,7 @@ async def interactive_search():
|
||||||
print(f"\n[×] 查询失败: {e}")
|
print(f"\n[×] 查询失败: {e}")
|
||||||
logger.error(f"查询异常: {e}", exc_info=True)
|
logger.error(f"查询异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""主循环"""
|
"""主循环"""
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -253,6 +266,7 @@ async def main():
|
||||||
else:
|
else:
|
||||||
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
|
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
# 运行主循环
|
# 运行主循环
|
||||||
|
|
@ -262,4 +276,3 @@ if __name__ == "__main__":
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n[!] 程序运行出错: {e}")
|
print(f"\n[!] 程序运行出错: {e}")
|
||||||
logger.error(f"Main loop 异常: {e}", exc_info=True)
|
logger.error(f"Main loop 异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
|
||||||
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
|
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
|
||||||
txt_files = list(raw_dir.glob("*.txt"))
|
txt_files = list(raw_dir.glob("*.txt"))
|
||||||
if not txt_files:
|
if not txt_files:
|
||||||
msg = (
|
msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,info_extraction 可能立即退出或无数据可处理。"
|
||||||
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
|
|
||||||
"info_extraction 可能立即退出或无数据可处理。"
|
|
||||||
)
|
|
||||||
print(msg)
|
print(msg)
|
||||||
if non_interactive:
|
if non_interactive:
|
||||||
logger.error(
|
logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
|
||||||
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
|
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
|
||||||
return cont == "y"
|
return cont == "y"
|
||||||
|
|
@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
|
||||||
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
|
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
|
||||||
json_files = list(openie_dir.glob("*.json"))
|
json_files = list(openie_dir.glob("*.json"))
|
||||||
if not json_files:
|
if not json_files:
|
||||||
msg = (
|
msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,import_openie 可能会因为找不到批次而失败。"
|
||||||
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
|
|
||||||
"import_openie 可能会因为找不到批次而失败。"
|
|
||||||
)
|
|
||||||
print(msg)
|
print(msg)
|
||||||
if non_interactive:
|
if non_interactive:
|
||||||
logger.error(
|
logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
|
||||||
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
|
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
|
||||||
return cont == "y"
|
return cont == "y"
|
||||||
|
|
@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
|
||||||
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
|
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
|
||||||
try:
|
try:
|
||||||
if not getattr(global_config.lpmm_knowledge, "enable", False):
|
if not getattr(global_config.lpmm_knowledge, "enable", False):
|
||||||
print(
|
print("[WARN] 当前配置 lpmm_knowledge.enable = false,刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
|
||||||
"[WARN] 当前配置 lpmm_knowledge.enable = false,"
|
|
||||||
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# 配置异常时不阻断主流程,仅忽略提示
|
# 配置异常时不阻断主流程,仅忽略提示
|
||||||
pass
|
pass
|
||||||
|
|
@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
||||||
if action == "prepare_raw":
|
if action == "prepare_raw":
|
||||||
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
|
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
|
||||||
sha_list, raw_data = load_raw_data()
|
sha_list, raw_data = load_raw_data()
|
||||||
print(
|
print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||||
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
|
|
||||||
f"去重后哈希数 {len(sha_list)}。"
|
|
||||||
)
|
|
||||||
elif action == "info_extract":
|
elif action == "info_extract":
|
||||||
if not _check_before_info_extract("--non-interactive" in extra_args):
|
if not _check_before_info_extract("--non-interactive" in extra_args):
|
||||||
print("已根据用户选择,取消执行信息提取。")
|
print("已根据用户选择,取消执行信息提取。")
|
||||||
|
|
@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
||||||
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
|
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
|
||||||
logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
|
logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
|
||||||
sha_list, raw_data = load_raw_data()
|
sha_list, raw_data = load_raw_data()
|
||||||
print(
|
print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||||
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
|
|
||||||
f"去重后哈希数 {len(sha_list)}。"
|
|
||||||
)
|
|
||||||
non_interactive = "--non-interactive" in extra_args
|
non_interactive = "--non-interactive" in extra_args
|
||||||
if not _check_before_info_extract(non_interactive):
|
if not _check_before_info_extract(non_interactive):
|
||||||
print("已根据用户选择,取消 full_import(信息提取阶段被取消)。")
|
print("已根据用户选择,取消 full_import(信息提取阶段被取消)。")
|
||||||
|
|
@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 快速选项:按推荐方式清理所有相关实体/关系
|
# 快速选项:按推荐方式清理所有相关实体/关系
|
||||||
quick_all = input(
|
quick_all = (
|
||||||
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
|
input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
|
||||||
).strip().lower()
|
)
|
||||||
if quick_all in ("", "y", "yes"):
|
if quick_all in ("", "y", "yes"):
|
||||||
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
|
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
|
||||||
else:
|
else:
|
||||||
|
|
@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
|
||||||
|
|
||||||
def _interactive_build_batch_inspect_args() -> List[str]:
|
def _interactive_build_batch_inspect_args() -> List[str]:
|
||||||
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
|
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
|
||||||
path = _interactive_choose_openie_file(
|
path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
|
||||||
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
|
|
||||||
)
|
|
||||||
if not path:
|
if not path:
|
||||||
return []
|
return []
|
||||||
return ["--openie-file", path]
|
return ["--openie-file", path]
|
||||||
|
|
@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
|
||||||
|
|
||||||
def _interactive_build_test_args() -> List[str]:
|
def _interactive_build_test_args() -> List[str]:
|
||||||
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
|
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
|
||||||
print(
|
print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
|
||||||
"\n[TEST] 你可以:\n"
|
|
||||||
"- 直接回车使用内置的默认测试用例;\n"
|
|
||||||
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
|
|
||||||
)
|
|
||||||
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
|
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
|
||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
|
|
@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
|
||||||
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
|
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
|
||||||
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
|
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
|
||||||
|
|
||||||
new_dim = input(
|
new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
|
||||||
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
|
|
||||||
).strip()
|
|
||||||
if new_dim and not new_dim.isdigit():
|
if new_dim and not new_dim.isdigit():
|
||||||
print("输入的维度不是纯数字,已取消操作。")
|
print("输入的维度不是纯数字,已取消操作。")
|
||||||
return
|
return
|
||||||
|
|
@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from maim_message import UserInfo, GroupInfo
|
||||||
|
|
||||||
logger = get_logger("test_memory_retrieval")
|
logger = get_logger("test_memory_retrieval")
|
||||||
|
|
||||||
|
|
||||||
# 使用 importlib 动态导入,避免循环导入问题
|
# 使用 importlib 动态导入,避免循环导入问题
|
||||||
def _import_memory_retrieval():
|
def _import_memory_retrieval():
|
||||||
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
||||||
|
|
@ -44,7 +45,7 @@ def _import_memory_retrieval():
|
||||||
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
||||||
if prompt_already_init and module_name in sys.modules:
|
if prompt_already_init and module_name in sys.modules:
|
||||||
existing_module = sys.modules[module_name]
|
existing_module = sys.modules[module_name]
|
||||||
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
if hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||||
return (
|
return (
|
||||||
existing_module.init_memory_retrieval_prompt,
|
existing_module.init_memory_retrieval_prompt,
|
||||||
existing_module._react_agent_solve_question,
|
existing_module._react_agent_solve_question,
|
||||||
|
|
@ -54,14 +55,14 @@ def _import_memory_retrieval():
|
||||||
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
|
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
|
||||||
if module_name in sys.modules:
|
if module_name in sys.modules:
|
||||||
existing_module = sys.modules[module_name]
|
existing_module = sys.modules[module_name]
|
||||||
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||||
# 模块部分初始化,移除它
|
# 模块部分初始化,移除它
|
||||||
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
|
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
|
||||||
del sys.modules[module_name]
|
del sys.modules[module_name]
|
||||||
# 清理可能相关的部分初始化模块
|
# 清理可能相关的部分初始化模块
|
||||||
keys_to_remove = []
|
keys_to_remove = []
|
||||||
for key in sys.modules.keys():
|
for key in sys.modules.keys():
|
||||||
if key.startswith('src.memory_system.') and key != 'src.memory_system':
|
if key.startswith("src.memory_system.") and key != "src.memory_system":
|
||||||
keys_to_remove.append(key)
|
keys_to_remove.append(key)
|
||||||
for key in keys_to_remove:
|
for key in keys_to_remove:
|
||||||
try:
|
try:
|
||||||
|
|
@ -75,6 +76,7 @@ def _import_memory_retrieval():
|
||||||
# 先导入可能触发循环导入的模块,让它们完成初始化
|
# 先导入可能触发循环导入的模块,让它们完成初始化
|
||||||
import src.config.config
|
import src.config.config
|
||||||
import src.chat.utils.prompt_builder
|
import src.chat.utils.prompt_builder
|
||||||
|
|
||||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||||
# 如果它们已经导入,就确保它们完全初始化
|
# 如果它们已经导入,就确保它们完全初始化
|
||||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||||
|
|
@ -253,7 +255,7 @@ async def test_memory_retrieval(
|
||||||
包含测试结果的字典
|
包含测试结果的字典
|
||||||
"""
|
"""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print(f"[测试] 记忆检索测试")
|
print("[测试] 记忆检索测试")
|
||||||
print(f"[问题] {question}")
|
print(f"[问题] {question}")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
|
|
@ -267,6 +269,7 @@ async def test_memory_retrieval(
|
||||||
|
|
||||||
# 检查 prompt 是否已经初始化,避免重复初始化
|
# 检查 prompt 是否已经初始化,避免重复初始化
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
|
|
||||||
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
|
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
|
||||||
init_memory_retrieval_prompt()
|
init_memory_retrieval_prompt()
|
||||||
else:
|
else:
|
||||||
|
|
@ -284,7 +287,7 @@ async def test_memory_retrieval(
|
||||||
|
|
||||||
timeout = global_config.memory.agent_timeout_seconds
|
timeout = global_config.memory.agent_timeout_seconds
|
||||||
|
|
||||||
print(f"\n[配置]")
|
print("\n[配置]")
|
||||||
print(f" 最大迭代次数: {max_iterations}")
|
print(f" 最大迭代次数: {max_iterations}")
|
||||||
print(f" 超时时间: {timeout}秒")
|
print(f" 超时时间: {timeout}秒")
|
||||||
print(f" 聊天ID: {chat_id}")
|
print(f" 聊天ID: {chat_id}")
|
||||||
|
|
@ -321,26 +324,26 @@ async def test_memory_retrieval(
|
||||||
|
|
||||||
# 输出结果
|
# 输出结果
|
||||||
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||||
print(f"\n[结果]")
|
print("\n[结果]")
|
||||||
print(f" 是否找到答案: {'是' if found_answer else '否'}")
|
print(f" 是否找到答案: {'是' if found_answer else '否'}")
|
||||||
if found_answer and answer:
|
if found_answer and answer:
|
||||||
print(f" 答案: {answer}")
|
print(f" 答案: {answer}")
|
||||||
else:
|
else:
|
||||||
print(f" 答案: (未找到答案)")
|
print(" 答案: (未找到答案)")
|
||||||
print(f" 是否超时: {'是' if is_timeout else '否'}")
|
print(f" 是否超时: {'是' if is_timeout else '否'}")
|
||||||
print(f" 迭代次数: {len(thinking_steps)}")
|
print(f" 迭代次数: {len(thinking_steps)}")
|
||||||
print(f" 总耗时: {elapsed_time:.2f}秒")
|
print(f" 总耗时: {elapsed_time:.2f}秒")
|
||||||
|
|
||||||
print(f"\n[Token使用情况]")
|
print("\n[Token使用情况]")
|
||||||
print(f" 总请求数: {token_usage['request_count']}")
|
print(f" 总请求数: {token_usage['request_count']}")
|
||||||
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
|
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
|
||||||
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
|
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
|
||||||
print(f" 总Tokens: {token_usage['total_tokens']:,}")
|
print(f" 总Tokens: {token_usage['total_tokens']:,}")
|
||||||
print(f" 总成本: ${token_usage['total_cost']:.6f}")
|
print(f" 总成本: ${token_usage['total_cost']:.6f}")
|
||||||
|
|
||||||
if token_usage['model_usage']:
|
if token_usage["model_usage"]:
|
||||||
print(f"\n[按模型统计]")
|
print("\n[按模型统计]")
|
||||||
for model_name, usage in token_usage['model_usage'].items():
|
for model_name, usage in token_usage["model_usage"].items():
|
||||||
print(f" {model_name}:")
|
print(f" {model_name}:")
|
||||||
print(f" 请求数: {usage['request_count']}")
|
print(f" 请求数: {usage['request_count']}")
|
||||||
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
|
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
|
||||||
|
|
@ -348,7 +351,7 @@ async def test_memory_retrieval(
|
||||||
print(f" 总Tokens: {usage['total_tokens']:,}")
|
print(f" 总Tokens: {usage['total_tokens']:,}")
|
||||||
print(f" 成本: ${usage['cost']:.6f}")
|
print(f" 成本: ${usage['cost']:.6f}")
|
||||||
|
|
||||||
print(f"\n[迭代详情]")
|
print("\n[迭代详情]")
|
||||||
print(format_thinking_steps(thinking_steps))
|
print(format_thinking_steps(thinking_steps))
|
||||||
|
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
|
|
@ -444,4 +447,3 @@ def main() -> None:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -455,6 +455,7 @@ class ExpressionSelector:
|
||||||
expr_obj.save()
|
expr_obj.save()
|
||||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
expression_selector = ExpressionSelector()
|
expression_selector = ExpressionSelector()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
|
||||||
|
|
||||||
logger = get_logger("jargon")
|
logger = get_logger("jargon")
|
||||||
|
|
||||||
|
|
||||||
class JargonExplainer:
|
class JargonExplainer:
|
||||||
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""
|
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Any
|
from typing import List, Any
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
from src.chat.utils.common_utils import TempMethodsExpression
|
from src.chat.utils.common_utils import TempMethodsExpression
|
||||||
|
|
@ -119,9 +118,7 @@ class MessageRecorder:
|
||||||
|
|
||||||
# 触发 expression_learner 和 jargon_miner 的处理
|
# 触发 expression_learner 和 jargon_miner 的处理
|
||||||
if self.enable_expression_learning:
|
if self.enable_expression_learning:
|
||||||
asyncio.create_task(
|
asyncio.create_task(self._trigger_expression_learning(messages))
|
||||||
self._trigger_expression_learning(messages)
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||||
|
|
@ -130,9 +127,7 @@ class MessageRecorder:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# 即使失败也保持时间戳更新,避免频繁重试
|
# 即使失败也保持时间戳更新,避免频繁重试
|
||||||
|
|
||||||
async def _trigger_expression_learning(
|
async def _trigger_expression_learning(self, messages: List[Any]) -> None:
|
||||||
self, messages: List[Any]
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
触发 expression 学习,使用指定的消息列表
|
触发 expression 学习,使用指定的消息列表
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import time
|
import time
|
||||||
from typing import Tuple, Optional, Dict, Any # 增加了 Optional
|
from typing import Tuple, Optional # 增加了 Optional
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
|
|
@ -170,13 +170,10 @@ class ActionPlanner:
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
|
||||||
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||||
|
|
||||||
|
|
||||||
# --- 获取超时提示信息 ---
|
# --- 获取超时提示信息 ---
|
||||||
# (这部分逻辑不变)
|
# (这部分逻辑不变)
|
||||||
timeout_context = ""
|
timeout_context = ""
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ class Conversation:
|
||||||
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
|
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
|
||||||
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
|
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
|
||||||
"platform": msg.user_info.platform if msg.user_info else "",
|
"platform": msg.user_info.platform if msg.user_info else "",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
initial_messages_dict.append(msg_dict)
|
initial_messages_dict.append(msg_dict)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from src.common.logger import get_logger
|
||||||
from .chat_observer import ChatObserver
|
from .chat_observer import ChatObserver
|
||||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
import traceback # 导入 traceback 用于调试
|
import traceback # 导入 traceback 用于调试
|
||||||
|
|
||||||
logger = get_logger("observation_info")
|
logger = get_logger("observation_info")
|
||||||
|
|
|
||||||
|
|
@ -42,9 +42,7 @@ class GoalAnalyzer:
|
||||||
"""对话目标分析器"""
|
"""对话目标分析器"""
|
||||||
|
|
||||||
def __init__(self, stream_id: str, private_name: str):
|
def __init__(self, stream_id: str, private_name: str):
|
||||||
self.llm = LLMRequest(
|
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
|
||||||
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.personality_info = self._get_personality_prompt()
|
self.personality_info = self._get_personality_prompt()
|
||||||
self.name = global_config.bot.nickname
|
self.name = global_config.bot.nickname
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from typing import List, Tuple, Dict, Any
|
from typing import List, Tuple, Dict, Any
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
||||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import model_config
|
||||||
from src.chat.message_receive.message import Message
|
|
||||||
from src.chat.knowledge import qa_manager
|
from src.chat.knowledge import qa_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
|
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
|
||||||
|
|
@ -16,9 +16,7 @@ class KnowledgeFetcher:
|
||||||
"""知识调取器"""
|
"""知识调取器"""
|
||||||
|
|
||||||
def __init__(self, private_name: str):
|
def __init__(self, private_name: str):
|
||||||
self.llm = LLMRequest(
|
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
|
||||||
model_set=model_config.model_task_config.utils
|
|
||||||
)
|
|
||||||
self.private_name = private_name
|
self.private_name = private_name
|
||||||
|
|
||||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,7 @@ class ReplyChecker:
|
||||||
"""回复检查器"""
|
"""回复检查器"""
|
||||||
|
|
||||||
def __init__(self, stream_id: str, private_name: str):
|
def __init__(self, stream_id: str, private_name: str):
|
||||||
self.llm = LLMRequest(
|
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
|
||||||
model_set=model_config.model_task_config.utils,
|
|
||||||
request_type="reply_check"
|
|
||||||
)
|
|
||||||
self.personality_info = self._get_personality_prompt()
|
self.personality_info = self._get_personality_prompt()
|
||||||
self.name = global_config.bot.nickname
|
self.name = global_config.bot.nickname
|
||||||
self.private_name = private_name
|
self.private_name = private_name
|
||||||
|
|
|
||||||
|
|
@ -704,10 +704,7 @@ class BrainChatting:
|
||||||
|
|
||||||
# 等待指定时间,但可被新消息打断
|
# 等待指定时间,但可被新消息打断
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||||
self._new_message_event.wait(),
|
|
||||||
timeout=wait_seconds
|
|
||||||
)
|
|
||||||
# 如果事件被触发,说明有新消息到达
|
# 如果事件被触发,说明有新消息到达
|
||||||
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
|
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|
@ -731,7 +728,9 @@ class BrainChatting:
|
||||||
# 使用默认等待时间
|
# 使用默认等待时间
|
||||||
wait_seconds = 3
|
wait_seconds = 3
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)"
|
||||||
|
)
|
||||||
|
|
||||||
# 清除事件状态,准备等待新消息
|
# 清除事件状态,准备等待新消息
|
||||||
self._new_message_event.clear()
|
self._new_message_event.clear()
|
||||||
|
|
@ -749,10 +748,7 @@ class BrainChatting:
|
||||||
|
|
||||||
# 等待指定时间,但可被新消息打断
|
# 等待指定时间,但可被新消息打断
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||||
self._new_message_event.wait(),
|
|
||||||
timeout=wait_seconds
|
|
||||||
)
|
|
||||||
# 如果事件被触发,说明有新消息到达
|
# 如果事件被触发,说明有新消息到达
|
||||||
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
|
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|
|
||||||
|
|
@ -431,7 +431,9 @@ class BrainPlanner:
|
||||||
except Exception as req_e:
|
except Exception as req_e:
|
||||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||||
return extracted_reasoning, [
|
return (
|
||||||
|
extracted_reasoning,
|
||||||
|
[
|
||||||
ActionPlannerInfo(
|
ActionPlannerInfo(
|
||||||
action_type="complete_talk",
|
action_type="complete_talk",
|
||||||
reasoning=extracted_reasoning,
|
reasoning=extracted_reasoning,
|
||||||
|
|
@ -439,7 +441,11 @@ class BrainPlanner:
|
||||||
action_message=None,
|
action_message=None,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
)
|
)
|
||||||
], llm_content, llm_reasoning, llm_duration_ms
|
],
|
||||||
|
llm_content,
|
||||||
|
llm_reasoning,
|
||||||
|
llm_duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
# 解析LLM响应
|
# 解析LLM响应
|
||||||
if llm_content:
|
if llm_content:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import List, Union, Dict, Any
|
from typing import List, Union
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from . import prompt_template
|
from . import prompt_template
|
||||||
|
|
@ -200,9 +200,7 @@ class IEProcess:
|
||||||
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
|
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
|
||||||
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
|
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
|
||||||
try:
|
try:
|
||||||
entities, triples = await asyncio.to_thread(
|
entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
|
||||||
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
|
|
||||||
)
|
|
||||||
|
|
||||||
if entities is not None:
|
if entities is not None:
|
||||||
results.append(
|
results.append(
|
||||||
|
|
|
||||||
|
|
@ -395,8 +395,7 @@ class KGManager:
|
||||||
appear_cnt = self.ent_appear_cnt.get(ent_hash)
|
appear_cnt = self.ent_appear_cnt.get(ent_hash)
|
||||||
if not appear_cnt or appear_cnt <= 0:
|
if not appear_cnt or appear_cnt <= 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,"
|
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,将使用 1.0 作为默认出现次数参与权重计算"
|
||||||
f"将使用 1.0 作为默认出现次数参与权重计算"
|
|
||||||
)
|
)
|
||||||
appear_cnt = 1.0
|
appear_cnt = 1.0
|
||||||
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)
|
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
|
||||||
|
|
||||||
logger = get_logger("LPMM-Plugin-API")
|
logger = get_logger("LPMM-Plugin-API")
|
||||||
|
|
||||||
|
|
||||||
class LPMMOperations:
|
class LPMMOperations:
|
||||||
"""
|
"""
|
||||||
LPMM 内部操作接口。
|
LPMM 内部操作接口。
|
||||||
|
|
@ -20,9 +21,7 @@ class LPMMOperations:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
async def _run_cancellable_executor(
|
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
|
||||||
self, func: Callable, *args, **kwargs
|
|
||||||
) -> Any:
|
|
||||||
"""
|
"""
|
||||||
在线程池中执行可取消的同步操作。
|
在线程池中执行可取消的同步操作。
|
||||||
当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
|
当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
|
||||||
|
|
@ -79,7 +78,7 @@ class LPMMOperations:
|
||||||
# 1. 分段处理
|
# 1. 分段处理
|
||||||
if auto_split:
|
if auto_split:
|
||||||
# 自动按双换行符分割
|
# 自动按双换行符分割
|
||||||
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
|
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||||
else:
|
else:
|
||||||
# 不分割,作为完整一段
|
# 不分割,作为完整一段
|
||||||
text_stripped = text.strip()
|
text_stripped = text.strip()
|
||||||
|
|
@ -95,7 +94,9 @@ class LPMMOperations:
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
|
|
||||||
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
|
llm_ner = LLMRequest(
|
||||||
|
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||||
|
)
|
||||||
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||||
|
|
||||||
|
|
@ -128,25 +129,21 @@ class LPMMOperations:
|
||||||
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
|
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
|
||||||
# store_new_data_set 会自动处理嵌入生成和存储
|
# store_new_data_set 会自动处理嵌入生成和存储
|
||||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||||
await self._run_cancellable_executor(
|
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
|
||||||
embed_mgr.store_new_data_set,
|
|
||||||
new_raw_paragraphs,
|
|
||||||
new_triple_list_data
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 构建知识图谱(只需要三元组数据和embedding_manager)
|
# 3. 构建知识图谱(只需要三元组数据和embedding_manager)
|
||||||
await self._run_cancellable_executor(
|
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
|
||||||
kg_mgr.build_kg,
|
|
||||||
new_triple_list_data,
|
|
||||||
embed_mgr
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 持久化
|
# 4. 持久化
|
||||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||||
|
|
||||||
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"count": len(new_raw_paragraphs),
|
||||||
|
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
|
||||||
|
}
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.warning("[Plugin API] 导入操作被用户中断")
|
logger.warning("[Plugin API] 导入操作被用户中断")
|
||||||
|
|
@ -215,8 +212,7 @@ class LPMMOperations:
|
||||||
|
|
||||||
# a. 从向量库删除
|
# a. 从向量库删除
|
||||||
deleted_count, _ = await self._run_cancellable_executor(
|
deleted_count, _ = await self._run_cancellable_executor(
|
||||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
|
||||||
to_delete_keys
|
|
||||||
)
|
)
|
||||||
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
|
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||||
|
|
||||||
|
|
@ -224,10 +220,7 @@ class LPMMOperations:
|
||||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||||
delete_func = partial(
|
delete_func = partial(
|
||||||
kg_mgr.delete_paragraphs,
|
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||||
to_delete_hashes,
|
|
||||||
ent_hashes=None,
|
|
||||||
remove_orphan_entities=True
|
|
||||||
)
|
)
|
||||||
await self._run_cancellable_executor(delete_func)
|
await self._run_cancellable_executor(delete_func)
|
||||||
|
|
||||||
|
|
@ -237,7 +230,11 @@ class LPMMOperations:
|
||||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||||
|
|
||||||
match_type = "完整文段" if exact_match else "关键词"
|
match_type = "完整文段" if exact_match else "关键词"
|
||||||
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"deleted_count": deleted_count,
|
||||||
|
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
|
||||||
|
}
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.warning("[Plugin API] 删除操作被用户中断")
|
logger.warning("[Plugin API] 删除操作被用户中断")
|
||||||
|
|
@ -275,16 +272,14 @@ class LPMMOperations:
|
||||||
|
|
||||||
# 删除所有段落向量
|
# 删除所有段落向量
|
||||||
para_deleted, _ = await self._run_cancellable_executor(
|
para_deleted, _ = await self._run_cancellable_executor(
|
||||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
|
||||||
para_keys
|
|
||||||
)
|
)
|
||||||
embed_mgr.stored_pg_hashes.clear()
|
embed_mgr.stored_pg_hashes.clear()
|
||||||
|
|
||||||
# 删除所有实体向量
|
# 删除所有实体向量
|
||||||
if ent_keys:
|
if ent_keys:
|
||||||
ent_deleted, _ = await self._run_cancellable_executor(
|
ent_deleted, _ = await self._run_cancellable_executor(
|
||||||
embed_mgr.entities_embedding_store.delete_items,
|
embed_mgr.entities_embedding_store.delete_items, ent_keys
|
||||||
ent_keys
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ent_deleted = 0
|
ent_deleted = 0
|
||||||
|
|
@ -292,8 +287,7 @@ class LPMMOperations:
|
||||||
# 删除所有关系向量
|
# 删除所有关系向量
|
||||||
if rel_keys:
|
if rel_keys:
|
||||||
rel_deleted, _ = await self._run_cancellable_executor(
|
rel_deleted, _ = await self._run_cancellable_executor(
|
||||||
embed_mgr.relation_embedding_store.delete_items,
|
embed_mgr.relation_embedding_store.delete_items, rel_keys
|
||||||
rel_keys
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rel_deleted = 0
|
rel_deleted = 0
|
||||||
|
|
@ -341,15 +335,13 @@ class LPMMOperations:
|
||||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||||
delete_func = partial(
|
delete_func = partial(
|
||||||
kg_mgr.delete_paragraphs,
|
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||||
all_pg_hashes,
|
|
||||||
ent_hashes=None,
|
|
||||||
remove_orphan_entities=True
|
|
||||||
)
|
)
|
||||||
await self._run_cancellable_executor(delete_func)
|
await self._run_cancellable_executor(delete_func)
|
||||||
|
|
||||||
# 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
|
# 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
|
||||||
from quick_algo import di_graph
|
from quick_algo import di_graph
|
||||||
|
|
||||||
kg_mgr.graph = di_graph.DiGraph()
|
kg_mgr.graph = di_graph.DiGraph()
|
||||||
kg_mgr.stored_paragraph_hashes.clear()
|
kg_mgr.stored_paragraph_hashes.clear()
|
||||||
kg_mgr.ent_appear_cnt.clear()
|
kg_mgr.ent_appear_cnt.clear()
|
||||||
|
|
@ -373,7 +365,7 @@ class LPMMOperations:
|
||||||
"stats": {
|
"stats": {
|
||||||
"before": before_stats,
|
"before": before_stats,
|
||||||
"after": after_stats,
|
"after": after_stats,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|
@ -383,6 +375,6 @@ class LPMMOperations:
|
||||||
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
|
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
# 内部使用的单例
|
# 内部使用的单例
|
||||||
lpmm_ops = LPMMOperations()
|
lpmm_ops = LPMMOperations()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -136,4 +136,3 @@ class PlanReplyLogger:
|
||||||
return str(value)
|
return str(value)
|
||||||
# Fallback to string for other complex types
|
# Fallback to string for other complex types
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||||
|
|
||||||
# 如果未开启 API Server,直接跳过 Fallback
|
# 如果未开启 API Server,直接跳过 Fallback
|
||||||
if not global_config.maim_message.enable_api_server:
|
if not global_config.maim_message.enable_api_server:
|
||||||
logger.debug(f"[API Server Fallback] API Server未开启,跳过fallback")
|
logger.debug("[API Server Fallback] API Server未开启,跳过fallback")
|
||||||
if legacy_exception:
|
if legacy_exception:
|
||||||
raise legacy_exception
|
raise legacy_exception
|
||||||
return False
|
return False
|
||||||
|
|
@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||||
extra_server = getattr(global_api, "extra_server", None)
|
extra_server = getattr(global_api, "extra_server", None)
|
||||||
|
|
||||||
if not extra_server:
|
if not extra_server:
|
||||||
logger.warning(f"[API Server Fallback] extra_server不存在")
|
logger.warning("[API Server Fallback] extra_server不存在")
|
||||||
if legacy_exception:
|
if legacy_exception:
|
||||||
raise legacy_exception
|
raise legacy_exception
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not extra_server.is_running():
|
if not extra_server.is_running():
|
||||||
logger.warning(f"[API Server Fallback] extra_server未运行")
|
logger.warning("[API Server Fallback] extra_server未运行")
|
||||||
if legacy_exception:
|
if legacy_exception:
|
||||||
raise legacy_exception
|
raise legacy_exception
|
||||||
return False
|
return False
|
||||||
|
|
@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||||
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...")
|
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
|
||||||
results = await extra_server.send_message(api_message)
|
results = await extra_server.send_message(api_message)
|
||||||
logger.debug(f"[API Server Fallback] 发送结果: {results}")
|
logger.debug(f"[API Server Fallback] 发送结果: {results}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ logger = get_logger("planner")
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
class ActionPlanner:
|
class ActionPlanner:
|
||||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
|
|
@ -111,20 +112,29 @@ class ActionPlanner:
|
||||||
|
|
||||||
# 替换 [picid:xxx] 为 [图片:描述]
|
# 替换 [picid:xxx] 为 [图片:描述]
|
||||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||||
|
|
||||||
def replace_pic_id(pic_match: re.Match) -> str:
|
def replace_pic_id(pic_match: re.Match) -> str:
|
||||||
pic_id = pic_match.group(1)
|
pic_id = pic_match.group(1)
|
||||||
description = translate_pid_to_description(pic_id)
|
description = translate_pid_to_description(pic_id)
|
||||||
return f"[图片:{description}]"
|
return f"[图片:{description}]"
|
||||||
|
|
||||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||||
|
|
||||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||||
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
|
platform = (
|
||||||
|
getattr(message, "user_info", None)
|
||||||
|
and message.user_info.platform
|
||||||
|
or getattr(message, "chat_info", None)
|
||||||
|
and message.chat_info.platform
|
||||||
|
or "qq"
|
||||||
|
)
|
||||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||||
|
|
||||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||||
# 这里匹配到的应该都是单独的格式
|
# 这里匹配到的应该都是单独的格式
|
||||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||||
|
|
||||||
def replace_user_ref(user_match: re.Match) -> str:
|
def replace_user_ref(user_match: re.Match) -> str:
|
||||||
user_name = user_match.group(1)
|
user_name = user_match.group(1)
|
||||||
user_id = user_match.group(2)
|
user_id = user_match.group(2)
|
||||||
|
|
@ -137,6 +147,7 @@ class ActionPlanner:
|
||||||
except Exception:
|
except Exception:
|
||||||
# 如果解析失败,使用原始昵称
|
# 如果解析失败,使用原始昵称
|
||||||
return user_name
|
return user_name
|
||||||
|
|
||||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||||
|
|
||||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||||
|
|
@ -306,9 +317,7 @@ class ActionPlanner:
|
||||||
|
|
||||||
return merged_words
|
return merged_words
|
||||||
|
|
||||||
def _process_unknown_words_cache(
|
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
|
||||||
self, actions: List[ActionPlannerInfo]
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
处理黑话缓存逻辑:
|
处理黑话缓存逻辑:
|
||||||
1. 检查是否有 reply action 提取了 unknown_words
|
1. 检查是否有 reply action 提取了 unknown_words
|
||||||
|
|
@ -442,7 +451,11 @@ class ActionPlanner:
|
||||||
# 检查是否已经有回复该消息的 action
|
# 检查是否已经有回复该消息的 action
|
||||||
has_reply_to_force_message = False
|
has_reply_to_force_message = False
|
||||||
for action in actions:
|
for action in actions:
|
||||||
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
|
if (
|
||||||
|
action.action_type == "reply"
|
||||||
|
and action.action_message
|
||||||
|
and action.action_message.message_id == force_reply_message.message_id
|
||||||
|
):
|
||||||
has_reply_to_force_message = True
|
has_reply_to_force_message = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -577,10 +590,11 @@ class ActionPlanner:
|
||||||
if global_config.chat.think_mode == "classic":
|
if global_config.chat.think_mode == "classic":
|
||||||
reply_action_example = ""
|
reply_action_example = ""
|
||||||
if global_config.chat.llm_quote:
|
if global_config.chat.llm_quote:
|
||||||
reply_action_example += "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
|
||||||
reply_action_example += (
|
reply_action_example += (
|
||||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
|
"5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||||
'"unknown_words":["词语1","词语2"]'
|
)
|
||||||
|
reply_action_example += (
|
||||||
|
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
|
||||||
)
|
)
|
||||||
if global_config.chat.llm_quote:
|
if global_config.chat.llm_quote:
|
||||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||||
|
|
@ -590,7 +604,9 @@ class ActionPlanner:
|
||||||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||||
)
|
)
|
||||||
if global_config.chat.llm_quote:
|
if global_config.chat.llm_quote:
|
||||||
reply_action_example += "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
reply_action_example += (
|
||||||
|
"6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||||
|
)
|
||||||
reply_action_example += (
|
reply_action_example += (
|
||||||
'{{"action":"reply", "think_level":数值等级(0或1), '
|
'{{"action":"reply", "think_level":数值等级(0或1), '
|
||||||
'"target_message_id":"消息id(m+数字)", '
|
'"target_message_id":"消息id(m+数字)", '
|
||||||
|
|
@ -741,7 +757,9 @@ class ActionPlanner:
|
||||||
|
|
||||||
except Exception as req_e:
|
except Exception as req_e:
|
||||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||||
return f"LLM 请求失败,模型出现问题: {req_e}", [
|
return (
|
||||||
|
f"LLM 请求失败,模型出现问题: {req_e}",
|
||||||
|
[
|
||||||
ActionPlannerInfo(
|
ActionPlannerInfo(
|
||||||
action_type="no_reply",
|
action_type="no_reply",
|
||||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||||
|
|
@ -749,7 +767,11 @@ class ActionPlanner:
|
||||||
action_message=None,
|
action_message=None,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
)
|
)
|
||||||
], llm_content, llm_reasoning, llm_duration_ms
|
],
|
||||||
|
llm_content,
|
||||||
|
llm_reasoning,
|
||||||
|
llm_duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
# 解析LLM响应
|
# 解析LLM响应
|
||||||
extracted_reasoning = ""
|
extracted_reasoning = ""
|
||||||
|
|
|
||||||
|
|
@ -1071,7 +1071,6 @@ class DefaultReplyer:
|
||||||
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
|
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
|
||||||
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
|
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
|
||||||
|
|
||||||
|
|
||||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||||
reply_style = global_config.personality.reply_style
|
reply_style = global_config.personality.reply_style
|
||||||
multi_styles = global_config.personality.multiple_reply_style
|
multi_styles = global_config.personality.multiple_reply_style
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
|
||||||
)
|
)
|
||||||
from src.bw_learner.expression_selector import expression_selector
|
from src.bw_learner.expression_selector import expression_selector
|
||||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||||
|
|
||||||
# from src.memory_system.memory_activator import MemoryActivator
|
# from src.memory_system.memory_activator import MemoryActivator
|
||||||
from src.person_info.person_info import Person, is_person_known
|
from src.person_info.person_info import Person, is_person_known
|
||||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("common_utils")
|
logger = get_logger("common_utils")
|
||||||
|
|
||||||
|
|
||||||
class TempMethodsExpression:
|
class TempMethodsExpression:
|
||||||
"""用于临时存放一些方法的类"""
|
"""用于临时存放一些方法的类"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
|
||||||
|
|
||||||
from . import BaseDatabaseDataModel
|
from . import BaseDatabaseDataModel
|
||||||
|
|
||||||
|
|
||||||
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
||||||
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
|
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,8 +96,12 @@ class Images(SQLModel, table=True):
|
||||||
|
|
||||||
no_file_flag: bool = Field(default=False) # 文件不存在标记,如果为True表示文件已经不存在,仅保留描述字段
|
no_file_flag: bool = Field(default=False) # 文件不存在标记,如果为True表示文件已经不存在,仅保留描述字段
|
||||||
|
|
||||||
record_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间(数据库记录被创建的时间)
|
record_time: datetime = Field(
|
||||||
register_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 注册时间(被注册为可用表情包的时间)
|
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||||
|
) # 记录时间(数据库记录被创建的时间)
|
||||||
|
register_time: Optional[datetime] = Field(
|
||||||
|
default=None, sa_column=Column(DateTime, nullable=True)
|
||||||
|
) # 注册时间(被注册为可用表情包的时间)
|
||||||
last_used_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 上次使用时间
|
last_used_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 上次使用时间
|
||||||
|
|
||||||
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
|
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
|
||||||
|
|
@ -171,7 +175,9 @@ class Expression(SQLModel, table=True):
|
||||||
|
|
||||||
content_list: str # 内容列表,JSON格式存储
|
content_list: str # 内容列表,JSON格式存储
|
||||||
count: int = Field(default=0) # 使用次数
|
count: int = Field(default=0) # 使用次数
|
||||||
last_active_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 上次使用时间
|
last_active_time: datetime = Field(
|
||||||
|
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||||
|
) # 上次使用时间
|
||||||
create_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime)) # 创建时间
|
create_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime)) # 创建时间
|
||||||
session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID,区分是否为全局表达方式
|
session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID,区分是否为全局表达方式
|
||||||
|
|
||||||
|
|
@ -232,8 +238,12 @@ class ThinkingQuestion(SQLModel, table=True):
|
||||||
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
||||||
|
|
||||||
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
||||||
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
|
created_timestamp: datetime = Field(
|
||||||
updated_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后更新时间
|
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||||
|
) # 创建时间
|
||||||
|
updated_timestamp: datetime = Field(
|
||||||
|
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||||
|
) # 最后更新时间
|
||||||
|
|
||||||
|
|
||||||
class BinaryData(SQLModel, table=True):
|
class BinaryData(SQLModel, table=True):
|
||||||
|
|
@ -272,7 +282,9 @@ class PersonInfo(SQLModel, table=True):
|
||||||
|
|
||||||
# 认识次数和时间
|
# 认识次数和时间
|
||||||
know_counts: int = Field(default=0) # 认识次数
|
know_counts: int = Field(default=0) # 认识次数
|
||||||
first_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 首次认识时间
|
first_known_time: Optional[datetime] = Field(
|
||||||
|
default=None, sa_column=Column(DateTime, nullable=True)
|
||||||
|
) # 首次认识时间
|
||||||
last_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 最后认识时间
|
last_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 最后认识时间
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -285,8 +297,12 @@ class ChatSession(SQLModel, table=True):
|
||||||
|
|
||||||
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
|
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
|
||||||
|
|
||||||
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
|
created_timestamp: datetime = Field(
|
||||||
last_active_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后活跃时间
|
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||||
|
) # 创建时间
|
||||||
|
last_active_timestamp: datetime = Field(
|
||||||
|
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||||
|
) # 最后活跃时间
|
||||||
|
|
||||||
# 身份元数据
|
# 身份元数据
|
||||||
user_id: Optional[str] = Field(index=True, max_length=255, nullable=True) # 用户ID
|
user_id: Optional[str] = Field(index=True, max_length=255, nullable=True) # 用户ID
|
||||||
|
|
|
||||||
|
|
@ -221,5 +221,7 @@ if not supports_truecolor():
|
||||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||||
else:
|
else:
|
||||||
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
|
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
|
||||||
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold)
|
escape_str = rgb_pair_to_ansi_truecolor(
|
||||||
|
hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
|
||||||
|
)
|
||||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||||
|
|
@ -9,6 +9,7 @@ from .server import get_global_server
|
||||||
|
|
||||||
global_api = None
|
global_api = None
|
||||||
|
|
||||||
|
|
||||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||||
"""获取全局MessageServer实例"""
|
"""获取全局MessageServer实例"""
|
||||||
global global_api
|
global global_api
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
|
||||||
|
|
||||||
logger = get_logger("file_utils")
|
logger = get_logger("file_utils")
|
||||||
|
|
||||||
|
|
||||||
class FileUtils:
|
class FileUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_binary_to_file(file_path: Path, data: bytes):
|
def save_binary_to_file(file_path: Path, data: bytes):
|
||||||
|
|
|
||||||
|
|
@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||||
|
|
||||||
reason = ",".join(reasons)
|
reason = ",".join(reasons)
|
||||||
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,8 +54,6 @@ async def generate_dream_summary(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
|
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
||||||
# 第一步:建立工具调用结果映射 (call_id -> result)
|
# 第一步:建立工具调用结果映射 (call_id -> result)
|
||||||
tool_results_map: dict[str, str] = {}
|
tool_results_map: dict[str, str] = {}
|
||||||
for msg in conversation_messages:
|
for msg in conversation_messages:
|
||||||
|
|
|
||||||
|
|
@ -4,4 +4,3 @@ dream agent 工具实现模块。
|
||||||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
|
||||||
return f"create_chat_history 执行失败: {e}"
|
return f"create_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return create_chat_history
|
return create_chat_history
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
||||||
return f"delete_chat_history 执行失败: {e}"
|
return f"delete_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return delete_chat_history
|
return delete_chat_history
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
||||||
return f"delete_jargon 执行失败: {e}"
|
return f"delete_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return delete_jargon
|
return delete_jargon
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
return finish_maintenance
|
return finish_maintenance
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
||||||
return f"get_chat_history_detail 执行失败: {e}"
|
return f"get_chat_history_detail 执行失败: {e}"
|
||||||
|
|
||||||
return get_chat_history_detail
|
return get_chat_history_detail
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
|
||||||
return f"search_chat_history 执行失败: {e}"
|
return f"search_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return search_chat_history
|
return search_chat_history
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
||||||
return f"update_chat_history 执行失败: {e}"
|
return f"update_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return update_chat_history
|
return update_chat_history
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
||||||
return f"update_jargon 执行失败: {e}"
|
return f"update_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return update_jargon
|
return update_jargon
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -459,7 +459,7 @@ def _default_normal_response_parser(
|
||||||
# 此时为了调试方便,建议打印出 arguments 的类型
|
# 此时为了调试方便,建议打印出 arguments 的类型
|
||||||
raise RespParseException(
|
raise RespParseException(
|
||||||
resp,
|
resp,
|
||||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
|
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
|
||||||
)
|
)
|
||||||
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
|
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import time
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
|
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,7 @@ async def search_chat_history(
|
||||||
if start_time:
|
if start_time:
|
||||||
try:
|
try:
|
||||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||||
|
|
||||||
start_timestamp = parse_datetime_to_timestamp(start_time)
|
start_timestamp = parse_datetime_to_timestamp(start_time)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||||
|
|
@ -120,6 +121,7 @@ async def search_chat_history(
|
||||||
if end_time:
|
if end_time:
|
||||||
try:
|
try:
|
||||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||||
|
|
||||||
end_timestamp = parse_datetime_to_timestamp(end_time)
|
end_timestamp = parse_datetime_to_timestamp(end_time)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||||
|
|
@ -165,16 +167,13 @@ async def search_chat_history(
|
||||||
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
|
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
|
||||||
query = query.where(
|
query = query.where(
|
||||||
(
|
(
|
||||||
(ChatHistory.start_time >= start_timestamp)
|
(ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
|
||||||
& (ChatHistory.start_time <= end_timestamp)
|
|
||||||
) # 记录开始时间在查询时间段内
|
) # 记录开始时间在查询时间段内
|
||||||
| (
|
| (
|
||||||
(ChatHistory.end_time >= start_timestamp)
|
(ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
|
||||||
& (ChatHistory.end_time <= end_timestamp)
|
|
||||||
) # 记录结束时间在查询时间段内
|
) # 记录结束时间在查询时间段内
|
||||||
| (
|
| (
|
||||||
(ChatHistory.start_time <= start_timestamp)
|
(ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
|
||||||
& (ChatHistory.end_time >= end_timestamp)
|
|
||||||
) # 记录完全包含查询时间段
|
) # 记录完全包含查询时间段
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
||||||
|
|
@ -76,4 +76,3 @@ def register_tool():
|
||||||
],
|
],
|
||||||
execute_func=query_words,
|
execute_func=query_words,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -558,7 +558,9 @@ class PluginBase(ABC):
|
||||||
if version_spec:
|
if version_spec:
|
||||||
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
|
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
|
||||||
if not is_ok:
|
if not is_ok:
|
||||||
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
|
logger.error(
|
||||||
|
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if min_version or max_version:
|
if min_version or max_version:
|
||||||
|
|
|
||||||
|
|
@ -751,9 +751,7 @@ class ComponentRegistry:
|
||||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||||
"workflow_steps": workflow_step_count,
|
"workflow_steps": workflow_step_count,
|
||||||
"enabled_workflow_steps": enabled_workflow_step_count,
|
"enabled_workflow_steps": enabled_workflow_step_count,
|
||||||
"workflow_steps_by_stage": {
|
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
|
||||||
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -429,7 +429,9 @@ class PluginManager:
|
||||||
|
|
||||||
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
|
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
|
||||||
"""根据依赖图计算加载顺序,并检测循环依赖。"""
|
"""根据依赖图计算加载顺序,并检测循环依赖。"""
|
||||||
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
|
indegree: Dict[str, int] = {
|
||||||
|
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
|
||||||
|
}
|
||||||
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
|
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
|
||||||
|
|
||||||
for plugin_name, dependencies in dependency_graph.items():
|
for plugin_name, dependencies in dependency_graph.items():
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,9 @@ class PluginServiceRegistry:
|
||||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
full_name = self._resolve_full_name(service_name, plugin_name)
|
||||||
return self._service_handlers.get(full_name) if full_name else None
|
return self._service_handlers.get(full_name) if full_name else None
|
||||||
|
|
||||||
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
|
def list_services(
|
||||||
|
self, plugin_name: Optional[str] = None, enabled_only: bool = False
|
||||||
|
) -> Dict[str, PluginServiceInfo]:
|
||||||
"""列出插件服务。"""
|
"""列出插件服务。"""
|
||||||
services = self._services.copy()
|
services = self._services.copy()
|
||||||
if plugin_name:
|
if plugin_name:
|
||||||
|
|
@ -120,7 +122,12 @@ class PluginServiceRegistry:
|
||||||
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
|
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
|
||||||
raise ValueError(f"插件服务未注册: {target_name}")
|
raise ValueError(f"插件服务未注册: {target_name}")
|
||||||
|
|
||||||
if "." not in service_name and plugin_name is None and caller_plugin and service_info.plugin_name != caller_plugin:
|
if (
|
||||||
|
"." not in service_name
|
||||||
|
and plugin_name is None
|
||||||
|
and caller_plugin
|
||||||
|
and service_info.plugin_name != caller_plugin
|
||||||
|
):
|
||||||
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
|
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
|
||||||
|
|
||||||
if not self._is_call_authorized(service_info, caller_plugin):
|
if not self._is_call_authorized(service_info, caller_plugin):
|
||||||
|
|
@ -153,7 +160,9 @@ class PluginServiceRegistry:
|
||||||
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
|
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
|
||||||
return "*" in allowed_callers or caller_plugin in allowed_callers
|
return "*" in allowed_callers or caller_plugin in allowed_callers
|
||||||
|
|
||||||
def _validate_input_contract(self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
|
def _validate_input_contract(
|
||||||
|
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""校验服务入参契约。"""
|
"""校验服务入参契约。"""
|
||||||
schema = service_info.params_schema
|
schema = service_info.params_schema
|
||||||
if not schema:
|
if not schema:
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,9 @@ class WorkflowEngine:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
||||||
workflow_context.errors.append(f"{stage_key}: {e}")
|
workflow_context.errors.append(f"{stage_key}: {e}")
|
||||||
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
|
logger.error(
|
||||||
|
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
|
||||||
|
)
|
||||||
self._execution_history[workflow_context.trace_id]["status"] = "failed"
|
self._execution_history[workflow_context.trace_id]["status"] = "failed"
|
||||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
||||||
return (
|
return (
|
||||||
|
|
@ -195,7 +197,9 @@ class WorkflowEngine:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||||
context.errors.append(f"{step_info.full_name}: {e}")
|
context.errors.append(f"{step_info.full_name}: {e}")
|
||||||
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
|
logger.error(
|
||||||
|
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
|
||||||
|
)
|
||||||
return WorkflowStepResult(
|
return WorkflowStepResult(
|
||||||
status="failed",
|
status="failed",
|
||||||
return_message=str(e),
|
return_message=str(e),
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||||
3. 详情按需加载
|
3. 详情按需加载
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
|
||||||
|
|
||||||
class ChatSummary(BaseModel):
|
class ChatSummary(BaseModel):
|
||||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||||
|
|
||||||
chat_id: str
|
chat_id: str
|
||||||
plan_count: int
|
plan_count: int
|
||||||
latest_timestamp: float
|
latest_timestamp: float
|
||||||
|
|
@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
|
||||||
|
|
||||||
class PlanLogSummary(BaseModel):
|
class PlanLogSummary(BaseModel):
|
||||||
"""规划日志摘要"""
|
"""规划日志摘要"""
|
||||||
|
|
||||||
chat_id: str
|
chat_id: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
filename: str
|
filename: str
|
||||||
|
|
@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
|
||||||
|
|
||||||
class PlanLogDetail(BaseModel):
|
class PlanLogDetail(BaseModel):
|
||||||
"""规划日志详情"""
|
"""规划日志详情"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
chat_id: str
|
chat_id: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
|
|
@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
|
||||||
|
|
||||||
class PlannerOverview(BaseModel):
|
class PlannerOverview(BaseModel):
|
||||||
"""规划器总览 - 轻量级统计"""
|
"""规划器总览 - 轻量级统计"""
|
||||||
|
|
||||||
total_chats: int
|
total_chats: int
|
||||||
total_plans: int
|
total_plans: int
|
||||||
chats: List[ChatSummary]
|
chats: List[ChatSummary]
|
||||||
|
|
@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
|
||||||
|
|
||||||
class PaginatedChatLogs(BaseModel):
|
class PaginatedChatLogs(BaseModel):
|
||||||
"""分页的聊天日志列表"""
|
"""分页的聊天日志列表"""
|
||||||
|
|
||||||
data: List[PlanLogSummary]
|
data: List[PlanLogSummary]
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
|
|
@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
|
||||||
def parse_timestamp_from_filename(filename: str) -> float:
|
def parse_timestamp_from_filename(filename: str) -> float:
|
||||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||||
try:
|
try:
|
||||||
timestamp_str = filename.split('_')[0]
|
timestamp_str = filename.split("_")[0]
|
||||||
# 时间戳是毫秒级,需要转换为秒
|
# 时间戳是毫秒级,需要转换为秒
|
||||||
return float(timestamp_str) / 1000
|
return float(timestamp_str) / 1000
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
|
|
@ -106,21 +112,19 @@ async def get_planner_overview():
|
||||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||||
|
|
||||||
chats.append(ChatSummary(
|
chats.append(
|
||||||
|
ChatSummary(
|
||||||
chat_id=chat_dir.name,
|
chat_id=chat_dir.name,
|
||||||
plan_count=plan_count,
|
plan_count=plan_count,
|
||||||
latest_timestamp=latest_timestamp,
|
latest_timestamp=latest_timestamp,
|
||||||
latest_filename=latest_file.name
|
latest_filename=latest_file.name,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 按最新时间戳排序
|
# 按最新时间戳排序
|
||||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||||
|
|
||||||
return PlannerOverview(
|
return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
|
||||||
total_chats=len(chats),
|
|
||||||
total_plans=total_plans,
|
|
||||||
chats=chats
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
||||||
|
|
@ -128,7 +132,7 @@ async def get_chat_plan_logs(
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: int = Query(20, ge=1, le=100),
|
page_size: int = Query(20, ge=1, le=100),
|
||||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定聊天的规划日志列表(分页)
|
获取指定聊天的规划日志列表(分页)
|
||||||
|
|
@ -137,9 +141,7 @@ async def get_chat_plan_logs(
|
||||||
"""
|
"""
|
||||||
chat_dir = PLAN_LOG_DIR / chat_id
|
chat_dir = PLAN_LOG_DIR / chat_id
|
||||||
if not chat_dir.exists():
|
if not chat_dir.exists():
|
||||||
return PaginatedChatLogs(
|
return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 先获取所有文件并按时间戳排序
|
# 先获取所有文件并按时间戳排序
|
||||||
json_files = list(chat_dir.glob("*.json"))
|
json_files = list(chat_dir.glob("*.json"))
|
||||||
|
|
@ -151,9 +153,9 @@ async def get_chat_plan_logs(
|
||||||
filtered_files = []
|
filtered_files = []
|
||||||
for log_file in json_files:
|
for log_file in json_files:
|
||||||
try:
|
try:
|
||||||
with open(log_file, 'r', encoding='utf-8') as f:
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
prompt = data.get('prompt', '')
|
prompt = data.get("prompt", "")
|
||||||
if search_lower in prompt.lower():
|
if search_lower in prompt.lower():
|
||||||
filtered_files.append(log_file)
|
filtered_files.append(log_file)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -169,24 +171,27 @@ async def get_chat_plan_logs(
|
||||||
logs = []
|
logs = []
|
||||||
for log_file in page_files:
|
for log_file in page_files:
|
||||||
try:
|
try:
|
||||||
with open(log_file, 'r', encoding='utf-8') as f:
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
reasoning = data.get('reasoning', '')
|
reasoning = data.get("reasoning", "")
|
||||||
actions = data.get('actions', [])
|
actions = data.get("actions", [])
|
||||||
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
|
action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
|
||||||
logs.append(PlanLogSummary(
|
logs.append(
|
||||||
chat_id=data.get('chat_id', chat_id),
|
PlanLogSummary(
|
||||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
chat_id=data.get("chat_id", chat_id),
|
||||||
|
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||||
filename=log_file.name,
|
filename=log_file.name,
|
||||||
action_count=len(actions),
|
action_count=len(actions),
|
||||||
action_types=action_types,
|
action_types=action_types,
|
||||||
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
|
total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
|
||||||
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
|
llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
|
||||||
reasoning_preview=reasoning[:100] if reasoning else ''
|
reasoning_preview=reasoning[:100] if reasoning else "",
|
||||||
))
|
)
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 文件读取失败时使用文件名信息
|
# 文件读取失败时使用文件名信息
|
||||||
logs.append(PlanLogSummary(
|
logs.append(
|
||||||
|
PlanLogSummary(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||||
filename=log_file.name,
|
filename=log_file.name,
|
||||||
|
|
@ -194,16 +199,11 @@ async def get_chat_plan_logs(
|
||||||
action_types=[],
|
action_types=[],
|
||||||
total_plan_ms=0,
|
total_plan_ms=0,
|
||||||
llm_duration_ms=0,
|
llm_duration_ms=0,
|
||||||
reasoning_preview='[读取失败]'
|
reasoning_preview="[读取失败]",
|
||||||
))
|
|
||||||
|
|
||||||
return PaginatedChatLogs(
|
|
||||||
data=logs,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
chat_id=chat_id
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
||||||
|
|
@ -214,7 +214,7 @@ async def get_log_detail(chat_id: str, filename: str):
|
||||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(log_file, 'r', encoding='utf-8') as f:
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return PlanLogDetail(**data)
|
return PlanLogDetail(**data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -223,6 +223,7 @@ async def get_log_detail(chat_id: str, filename: str):
|
||||||
|
|
||||||
# ========== 兼容旧接口 ==========
|
# ========== 兼容旧接口 ==========
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats")
|
@router.get("/stats")
|
||||||
async def get_planner_stats():
|
async def get_planner_stats():
|
||||||
"""获取规划器统计信息 - 兼容旧接口"""
|
"""获取规划器统计信息 - 兼容旧接口"""
|
||||||
|
|
@ -246,7 +247,7 @@ async def get_planner_stats():
|
||||||
"total_plans": overview.total_plans,
|
"total_plans": overview.total_plans,
|
||||||
"avg_plan_time_ms": 0,
|
"avg_plan_time_ms": 0,
|
||||||
"avg_llm_time_ms": 0,
|
"avg_llm_time_ms": 0,
|
||||||
"recent_plans": recent_plans
|
"recent_plans": recent_plans,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -258,10 +259,7 @@ async def get_chat_list():
|
||||||
|
|
||||||
|
|
||||||
@router.get("/all-logs")
|
@router.get("/all-logs")
|
||||||
async def get_all_logs(
|
async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
|
||||||
page: int = Query(1, ge=1),
|
|
||||||
page_size: int = Query(20, ge=1, le=100)
|
|
||||||
):
|
|
||||||
"""获取所有规划日志 - 兼容旧接口"""
|
"""获取所有规划日志 - 兼容旧接口"""
|
||||||
if not PLAN_LOG_DIR.exists():
|
if not PLAN_LOG_DIR.exists():
|
||||||
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
||||||
|
|
@ -283,18 +281,20 @@ async def get_all_logs(
|
||||||
logs = []
|
logs = []
|
||||||
for chat_id, log_file in page_files:
|
for chat_id, log_file in page_files:
|
||||||
try:
|
try:
|
||||||
with open(log_file, 'r', encoding='utf-8') as f:
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
reasoning = data.get('reasoning', '')
|
reasoning = data.get("reasoning", "")
|
||||||
logs.append({
|
logs.append(
|
||||||
"chat_id": data.get('chat_id', chat_id),
|
{
|
||||||
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
"chat_id": data.get("chat_id", chat_id),
|
||||||
|
"timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||||
"filename": log_file.name,
|
"filename": log_file.name,
|
||||||
"action_count": len(data.get('actions', [])),
|
"action_count": len(data.get("actions", [])),
|
||||||
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
|
"total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
|
||||||
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
|
"llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
|
||||||
"reasoning_preview": reasoning[:100] if reasoning else ''
|
"reasoning_preview": reasoning[:100] if reasoning else "",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||||
3. 详情按需加载
|
3. 详情按需加载
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
|
||||||
|
|
||||||
class ReplierChatSummary(BaseModel):
|
class ReplierChatSummary(BaseModel):
|
||||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||||
|
|
||||||
chat_id: str
|
chat_id: str
|
||||||
reply_count: int
|
reply_count: int
|
||||||
latest_timestamp: float
|
latest_timestamp: float
|
||||||
|
|
@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
|
||||||
|
|
||||||
class ReplyLogSummary(BaseModel):
|
class ReplyLogSummary(BaseModel):
|
||||||
"""回复日志摘要"""
|
"""回复日志摘要"""
|
||||||
|
|
||||||
chat_id: str
|
chat_id: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
filename: str
|
filename: str
|
||||||
|
|
@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
|
||||||
|
|
||||||
class ReplyLogDetail(BaseModel):
|
class ReplyLogDetail(BaseModel):
|
||||||
"""回复日志详情"""
|
"""回复日志详情"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
chat_id: str
|
chat_id: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
|
|
@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
|
||||||
|
|
||||||
class ReplierOverview(BaseModel):
|
class ReplierOverview(BaseModel):
|
||||||
"""回复器总览 - 轻量级统计"""
|
"""回复器总览 - 轻量级统计"""
|
||||||
|
|
||||||
total_chats: int
|
total_chats: int
|
||||||
total_replies: int
|
total_replies: int
|
||||||
chats: List[ReplierChatSummary]
|
chats: List[ReplierChatSummary]
|
||||||
|
|
@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
|
||||||
|
|
||||||
class PaginatedReplyLogs(BaseModel):
|
class PaginatedReplyLogs(BaseModel):
|
||||||
"""分页的回复日志列表"""
|
"""分页的回复日志列表"""
|
||||||
|
|
||||||
data: List[ReplyLogSummary]
|
data: List[ReplyLogSummary]
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
|
|
@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
|
||||||
def parse_timestamp_from_filename(filename: str) -> float:
|
def parse_timestamp_from_filename(filename: str) -> float:
|
||||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||||
try:
|
try:
|
||||||
timestamp_str = filename.split('_')[0]
|
timestamp_str = filename.split("_")[0]
|
||||||
# 时间戳是毫秒级,需要转换为秒
|
# 时间戳是毫秒级,需要转换为秒
|
||||||
return float(timestamp_str) / 1000
|
return float(timestamp_str) / 1000
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
|
|
@ -109,21 +115,19 @@ async def get_replier_overview():
|
||||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||||
|
|
||||||
chats.append(ReplierChatSummary(
|
chats.append(
|
||||||
|
ReplierChatSummary(
|
||||||
chat_id=chat_dir.name,
|
chat_id=chat_dir.name,
|
||||||
reply_count=reply_count,
|
reply_count=reply_count,
|
||||||
latest_timestamp=latest_timestamp,
|
latest_timestamp=latest_timestamp,
|
||||||
latest_filename=latest_file.name
|
latest_filename=latest_file.name,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 按最新时间戳排序
|
# 按最新时间戳排序
|
||||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||||
|
|
||||||
return ReplierOverview(
|
return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
|
||||||
total_chats=len(chats),
|
|
||||||
total_replies=total_replies,
|
|
||||||
chats=chats
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
||||||
|
|
@ -131,7 +135,7 @@ async def get_chat_reply_logs(
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: int = Query(20, ge=1, le=100),
|
page_size: int = Query(20, ge=1, le=100),
|
||||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定聊天的回复日志列表(分页)
|
获取指定聊天的回复日志列表(分页)
|
||||||
|
|
@ -140,9 +144,7 @@ async def get_chat_reply_logs(
|
||||||
"""
|
"""
|
||||||
chat_dir = REPLY_LOG_DIR / chat_id
|
chat_dir = REPLY_LOG_DIR / chat_id
|
||||||
if not chat_dir.exists():
|
if not chat_dir.exists():
|
||||||
return PaginatedReplyLogs(
|
return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 先获取所有文件并按时间戳排序
|
# 先获取所有文件并按时间戳排序
|
||||||
json_files = list(chat_dir.glob("*.json"))
|
json_files = list(chat_dir.glob("*.json"))
|
||||||
|
|
@ -154,9 +156,9 @@ async def get_chat_reply_logs(
|
||||||
filtered_files = []
|
filtered_files = []
|
||||||
for log_file in json_files:
|
for log_file in json_files:
|
||||||
try:
|
try:
|
||||||
with open(log_file, 'r', encoding='utf-8') as f:
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
prompt = data.get('prompt', '')
|
prompt = data.get("prompt", "")
|
||||||
if search_lower in prompt.lower():
|
if search_lower in prompt.lower():
|
||||||
filtered_files.append(log_file)
|
filtered_files.append(log_file)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -172,39 +174,37 @@ async def get_chat_reply_logs(
|
||||||
logs = []
|
logs = []
|
||||||
for log_file in page_files:
|
for log_file in page_files:
|
||||||
try:
|
try:
|
||||||
with open(log_file, 'r', encoding='utf-8') as f:
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
output = data.get('output', '')
|
output = data.get("output", "")
|
||||||
logs.append(ReplyLogSummary(
|
logs.append(
|
||||||
chat_id=data.get('chat_id', chat_id),
|
ReplyLogSummary(
|
||||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
chat_id=data.get("chat_id", chat_id),
|
||||||
|
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||||
filename=log_file.name,
|
filename=log_file.name,
|
||||||
model=data.get('model', ''),
|
model=data.get("model", ""),
|
||||||
success=data.get('success', True),
|
success=data.get("success", True),
|
||||||
llm_ms=data.get('timing', {}).get('llm_ms', 0),
|
llm_ms=data.get("timing", {}).get("llm_ms", 0),
|
||||||
overall_ms=data.get('timing', {}).get('overall_ms', 0),
|
overall_ms=data.get("timing", {}).get("overall_ms", 0),
|
||||||
output_preview=output[:100] if output else ''
|
output_preview=output[:100] if output else "",
|
||||||
))
|
)
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 文件读取失败时使用文件名信息
|
# 文件读取失败时使用文件名信息
|
||||||
logs.append(ReplyLogSummary(
|
logs.append(
|
||||||
|
ReplyLogSummary(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||||
filename=log_file.name,
|
filename=log_file.name,
|
||||||
model='',
|
model="",
|
||||||
success=False,
|
success=False,
|
||||||
llm_ms=0,
|
llm_ms=0,
|
||||||
overall_ms=0,
|
overall_ms=0,
|
||||||
output_preview='[读取失败]'
|
output_preview="[读取失败]",
|
||||||
))
|
|
||||||
|
|
||||||
return PaginatedReplyLogs(
|
|
||||||
data=logs,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
chat_id=chat_id
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
||||||
|
|
@ -215,21 +215,21 @@ async def get_reply_log_detail(chat_id: str, filename: str):
|
||||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(log_file, 'r', encoding='utf-8') as f:
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return ReplyLogDetail(
|
return ReplyLogDetail(
|
||||||
type=data.get('type', 'reply'),
|
type=data.get("type", "reply"),
|
||||||
chat_id=data.get('chat_id', chat_id),
|
chat_id=data.get("chat_id", chat_id),
|
||||||
timestamp=data.get('timestamp', 0),
|
timestamp=data.get("timestamp", 0),
|
||||||
prompt=data.get('prompt', ''),
|
prompt=data.get("prompt", ""),
|
||||||
output=data.get('output', ''),
|
output=data.get("output", ""),
|
||||||
processed_output=data.get('processed_output', []),
|
processed_output=data.get("processed_output", []),
|
||||||
model=data.get('model', ''),
|
model=data.get("model", ""),
|
||||||
reasoning=data.get('reasoning', ''),
|
reasoning=data.get("reasoning", ""),
|
||||||
think_level=data.get('think_level', 0),
|
think_level=data.get("think_level", 0),
|
||||||
timing=data.get('timing', {}),
|
timing=data.get("timing", {}),
|
||||||
error=data.get('error'),
|
error=data.get("error"),
|
||||||
success=data.get('success', True)
|
success=data.get("success", True),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||||
|
|
@ -237,6 +237,7 @@ async def get_reply_log_detail(chat_id: str, filename: str):
|
||||||
|
|
||||||
# ========== 兼容接口 ==========
|
# ========== 兼容接口 ==========
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats")
|
@router.get("/stats")
|
||||||
async def get_replier_stats():
|
async def get_replier_stats():
|
||||||
"""获取回复器统计信息"""
|
"""获取回复器统计信息"""
|
||||||
|
|
@ -258,7 +259,7 @@ async def get_replier_stats():
|
||||||
return {
|
return {
|
||||||
"total_chats": overview.total_chats,
|
"total_chats": overview.total_chats,
|
||||||
"total_replies": overview.total_replies,
|
"total_replies": overview.total_replies,
|
||||||
"recent_replies": recent_replies
|
"recent_replies": recent_replies,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import Depends, Cookie, Header, Request, HTTPException
|
from fastapi import Depends, Cookie, Header, Request
|
||||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
|
from .core import get_current_token, get_token_manager, check_auth_rate_limit
|
||||||
|
|
||||||
|
|
||||||
async def require_auth(
|
async def require_auth(
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
|
||||||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||||
|
|
||||||
|
|
||||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||||
# 支持格式:
|
# 支持格式:
|
||||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||||
|
|
@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||||
def _get_anti_crawler_config():
|
def _get_anti_crawler_config():
|
||||||
"""获取防爬虫配置"""
|
"""获取防爬虫配置"""
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'mode': global_config.webui.anti_crawler_mode,
|
"mode": global_config.webui.anti_crawler_mode,
|
||||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
"allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
"trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||||
'trust_xff': global_config.webui.trust_xff
|
"trust_xff": global_config.webui.trust_xff,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 初始化配置(将在模块加载时执行)
|
# 初始化配置(将在模块加载时执行)
|
||||||
_config = _get_anti_crawler_config()
|
_config = _get_anti_crawler_config()
|
||||||
ANTI_CRAWLER_MODE = _config['mode']
|
ANTI_CRAWLER_MODE = _config["mode"]
|
||||||
ALLOWED_IPS = _config['allowed_ips']
|
ALLOWED_IPS = _config["allowed_ips"]
|
||||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
TRUSTED_PROXIES = _config["trusted_proxies"]
|
||||||
TRUST_XFF = _config['trust_xff']
|
TRUST_XFF = _config["trust_xff"]
|
||||||
|
|
||||||
|
|
||||||
def _get_mode_config(mode: str) -> dict:
|
def _get_mode_config(mode: str) -> dict:
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ def _get_paragraph_store():
|
||||||
namespace="paragraph",
|
namespace="paragraph",
|
||||||
dir_path=embedding_dir,
|
dir_path=embedding_dir,
|
||||||
max_workers=1, # 只读不需要多线程
|
max_workers=1, # 只读不需要多线程
|
||||||
chunk_size=100
|
chunk_size=100,
|
||||||
)
|
)
|
||||||
paragraph_store.load_from_file()
|
paragraph_store.load_from_file()
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
||||||
paragraph_item = paragraph_store.store.get(node_id)
|
paragraph_item = paragraph_store.store.get(node_id)
|
||||||
if paragraph_item is not None:
|
if paragraph_item is not None:
|
||||||
# paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本
|
# paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本
|
||||||
content: str = getattr(paragraph_item, 'str', '')
|
content: str = getattr(paragraph_item, "str", "")
|
||||||
if content:
|
if content:
|
||||||
return content, True
|
return content, True
|
||||||
return None, True
|
return None, True
|
||||||
|
|
@ -160,7 +160,11 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||||
if node_type == "paragraph":
|
if node_type == "paragraph":
|
||||||
full_content, _ = _get_paragraph_content(node_id)
|
full_content, _ = _get_paragraph_content(node_id)
|
||||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
content = (
|
||||||
|
full_content
|
||||||
|
if full_content is not None
|
||||||
|
else (node_data["content"] if "content" in node_data else node_id)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
content = node_data["content"] if "content" in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
|
|
||||||
|
|
@ -249,7 +253,11 @@ async def get_knowledge_graph(
|
||||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||||
if node_type_val == "paragraph":
|
if node_type_val == "paragraph":
|
||||||
full_content, _ = _get_paragraph_content(node_id)
|
full_content, _ = _get_paragraph_content(node_id)
|
||||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
content = (
|
||||||
|
full_content
|
||||||
|
if full_content is not None
|
||||||
|
else (node_data["content"] if "content" in node_data else node_id)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
content = node_data["content"] if "content" in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
|
|
||||||
|
|
@ -372,7 +380,11 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
|
||||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||||
if node_type == "paragraph":
|
if node_type == "paragraph":
|
||||||
full_content, _ = _get_paragraph_content(node_id)
|
full_content, _ = _get_paragraph_content(node_id)
|
||||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
content = (
|
||||||
|
full_content
|
||||||
|
if full_content is not None
|
||||||
|
else (node_data["content"] if "content" in node_data else node_id)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
content = node_data["content"] if "content" in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue