mirror of https://github.com/Mai-with-u/MaiBot.git
Ruff Format
parent
2cb512120b
commit
eaef7f0e98
1
bot.py
1
bot.py
|
|
@ -50,6 +50,7 @@ print("警告:Dev进入不稳定开发状态,任何插件与WebUI均可能
|
|||
print("\n\n\n\n\n")
|
||||
print("-----------------------------------------")
|
||||
|
||||
|
||||
def run_runner_process():
|
||||
"""
|
||||
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Core helpers for MCP Bridge Plugin."""
|
||||
|
||||
|
|
|
|||
|
|
@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
|
|||
if not mcp_servers:
|
||||
return ""
|
||||
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,28 +34,31 @@ from enum import Enum
|
|||
# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging
|
||||
try:
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("mcp_client")
|
||||
except ImportError:
|
||||
# Fallback: 使用标准 logging
|
||||
logger = logging.getLogger("mcp_client")
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s'))
|
||||
handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class TransportType(Enum):
|
||||
"""MCP 传输类型"""
|
||||
STDIO = "stdio" # 本地进程通信
|
||||
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
||||
HTTP = "http" # HTTP Streamable (新版,推荐)
|
||||
|
||||
STDIO = "stdio" # 本地进程通信
|
||||
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
||||
HTTP = "http" # HTTP Streamable (新版,推荐)
|
||||
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPToolInfo:
|
||||
"""MCP 工具信息"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: Dict[str, Any]
|
||||
|
|
@ -65,6 +68,7 @@ class MCPToolInfo:
|
|||
@dataclass
|
||||
class MCPResourceInfo:
|
||||
"""MCP 资源信息"""
|
||||
|
||||
uri: str
|
||||
name: str
|
||||
description: str
|
||||
|
|
@ -75,6 +79,7 @@ class MCPResourceInfo:
|
|||
@dataclass
|
||||
class MCPPromptInfo:
|
||||
"""MCP 提示模板信息"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
arguments: List[Dict[str, Any]] # [{name, description, required}]
|
||||
|
|
@ -84,6 +89,7 @@ class MCPPromptInfo:
|
|||
@dataclass
|
||||
class MCPServerConfig:
|
||||
"""MCP 服务器配置"""
|
||||
|
||||
name: str
|
||||
enabled: bool = True
|
||||
transport: TransportType = TransportType.STDIO
|
||||
|
|
@ -99,6 +105,7 @@ class MCPServerConfig:
|
|||
@dataclass
|
||||
class MCPCallResult:
|
||||
"""MCP 工具调用结果"""
|
||||
|
||||
success: bool
|
||||
content: Any
|
||||
error: Optional[str] = None
|
||||
|
|
@ -108,8 +115,9 @@ class MCPCallResult:
|
|||
|
||||
class CircuitState(Enum):
|
||||
"""断路器状态"""
|
||||
CLOSED = "closed" # 正常状态,允许请求
|
||||
OPEN = "open" # 熔断状态,拒绝请求
|
||||
|
||||
CLOSED = "closed" # 正常状态,允许请求
|
||||
OPEN = "open" # 熔断状态,拒绝请求
|
||||
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
|
||||
|
||||
|
||||
|
|
@ -125,9 +133,9 @@ class CircuitBreaker:
|
|||
"""
|
||||
|
||||
# 配置
|
||||
failure_threshold: int = 5 # 连续失败多少次后熔断
|
||||
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
|
||||
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
|
||||
failure_threshold: int = 5 # 连续失败多少次后熔断
|
||||
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
|
||||
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
|
||||
|
||||
# 状态
|
||||
state: CircuitState = field(default=CircuitState.CLOSED)
|
||||
|
|
@ -232,6 +240,7 @@ class CircuitBreaker:
|
|||
@dataclass
|
||||
class ToolCallStats:
|
||||
"""工具调用统计"""
|
||||
|
||||
tool_key: str
|
||||
total_calls: int = 0
|
||||
success_calls: int = 0
|
||||
|
|
@ -282,6 +291,7 @@ class ToolCallStats:
|
|||
@dataclass
|
||||
class ServerStats:
|
||||
"""服务器统计"""
|
||||
|
||||
server_name: str
|
||||
connect_count: int = 0 # 连接次数
|
||||
disconnect_count: int = 0 # 断开次数
|
||||
|
|
@ -442,9 +452,7 @@ class MCPClientSession:
|
|||
return False
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=self.config.command,
|
||||
args=self.config.args,
|
||||
env=self.config.env if self.config.env else None
|
||||
command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
|
||||
)
|
||||
|
||||
self._stdio_context = stdio_client(server_params)
|
||||
|
|
@ -506,6 +514,7 @@ class MCPClientSession:
|
|||
except Exception as e:
|
||||
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||
await self._cleanup()
|
||||
return False
|
||||
|
|
@ -551,6 +560,7 @@ class MCPClientSession:
|
|||
except Exception as e:
|
||||
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||
await self._cleanup()
|
||||
return False
|
||||
|
|
@ -568,8 +578,8 @@ class MCPClientSession:
|
|||
tool_info = MCPToolInfo(
|
||||
name=tool.name,
|
||||
description=tool.description or f"MCP tool: {tool.name}",
|
||||
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {},
|
||||
server_name=self.server_name
|
||||
input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||
server_name=self.server_name,
|
||||
)
|
||||
self._tools.append(tool_info)
|
||||
# 初始化工具统计
|
||||
|
|
@ -591,10 +601,7 @@ class MCPClientSession:
|
|||
return False
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.list_resources(),
|
||||
timeout=self.call_timeout
|
||||
)
|
||||
result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
|
||||
self._resources = []
|
||||
|
||||
for resource in result.resources:
|
||||
|
|
@ -602,8 +609,8 @@ class MCPClientSession:
|
|||
uri=str(resource.uri),
|
||||
name=resource.name or str(resource.uri),
|
||||
description=resource.description or "",
|
||||
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None,
|
||||
server_name=self.server_name
|
||||
mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
|
||||
server_name=self.server_name,
|
||||
)
|
||||
self._resources.append(resource_info)
|
||||
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
|
||||
|
|
@ -633,28 +640,27 @@ class MCPClientSession:
|
|||
return False
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.list_prompts(),
|
||||
timeout=self.call_timeout
|
||||
)
|
||||
result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
|
||||
self._prompts = []
|
||||
|
||||
for prompt in result.prompts:
|
||||
# 解析参数
|
||||
arguments = []
|
||||
if hasattr(prompt, 'arguments') and prompt.arguments:
|
||||
if hasattr(prompt, "arguments") and prompt.arguments:
|
||||
for arg in prompt.arguments:
|
||||
arguments.append({
|
||||
"name": arg.name,
|
||||
"description": arg.description or "",
|
||||
"required": arg.required if hasattr(arg, 'required') else False,
|
||||
})
|
||||
arguments.append(
|
||||
{
|
||||
"name": arg.name,
|
||||
"description": arg.description or "",
|
||||
"required": arg.required if hasattr(arg, "required") else False,
|
||||
}
|
||||
)
|
||||
|
||||
prompt_info = MCPPromptInfo(
|
||||
name=prompt.name,
|
||||
description=prompt.description or f"MCP prompt: {prompt.name}",
|
||||
arguments=arguments,
|
||||
server_name=self.server_name
|
||||
server_name=self.server_name,
|
||||
)
|
||||
self._prompts.append(prompt_info)
|
||||
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
|
||||
|
|
@ -686,35 +692,25 @@ class MCPClientSession:
|
|||
start_time = time.time()
|
||||
|
||||
if not self._connected or not self._session:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"服务器 {self.server_name} 未连接"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||
|
||||
if not self._supports_resources:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"服务器 {self.server_name} 不支持 Resources 功能"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.read_resource(uri),
|
||||
timeout=self.call_timeout
|
||||
)
|
||||
result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# 处理返回内容
|
||||
content_parts = []
|
||||
for content in result.contents:
|
||||
if hasattr(content, 'text'):
|
||||
if hasattr(content, "text"):
|
||||
content_parts.append(content.text)
|
||||
elif hasattr(content, 'blob'):
|
||||
elif hasattr(content, "blob"):
|
||||
# 二进制数据,返回 base64 或提示
|
||||
import base64
|
||||
|
||||
blob_data = content.blob
|
||||
if len(blob_data) < 10000: # 小于 10KB 返回 base64
|
||||
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
|
||||
|
|
@ -724,28 +720,18 @@ class MCPClientSession:
|
|||
content_parts.append(str(content))
|
||||
|
||||
return MCPCallResult(
|
||||
success=True,
|
||||
content="\n".join(content_parts) if content_parts else "",
|
||||
duration_ms=duration_ms
|
||||
success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"读取资源超时({self.call_timeout}秒)",
|
||||
duration_ms=duration_ms
|
||||
success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||
)
|
||||
except Exception as e:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||
|
||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
|
||||
"""v1.2.0: 获取提示模板的内容
|
||||
|
|
@ -760,23 +746,14 @@ class MCPClientSession:
|
|||
start_time = time.time()
|
||||
|
||||
if not self._connected or not self._session:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"服务器 {self.server_name} 未连接"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||
|
||||
if not self._supports_prompts:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.get_prompt(name, arguments=arguments or {}),
|
||||
timeout=self.call_timeout
|
||||
self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
|
||||
)
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
|
@ -784,10 +761,10 @@ class MCPClientSession:
|
|||
# 处理返回的消息
|
||||
messages = []
|
||||
for msg in result.messages:
|
||||
role = msg.role if hasattr(msg, 'role') else "unknown"
|
||||
role = msg.role if hasattr(msg, "role") else "unknown"
|
||||
content_text = ""
|
||||
if hasattr(msg, 'content'):
|
||||
if hasattr(msg.content, 'text'):
|
||||
if hasattr(msg, "content"):
|
||||
if hasattr(msg.content, "text"):
|
||||
content_text = msg.content.text
|
||||
elif isinstance(msg.content, str):
|
||||
content_text = msg.content
|
||||
|
|
@ -796,28 +773,18 @@ class MCPClientSession:
|
|||
messages.append(f"[{role}]: {content_text}")
|
||||
|
||||
return MCPCallResult(
|
||||
success=True,
|
||||
content="\n\n".join(messages) if messages else "",
|
||||
duration_ms=duration_ms
|
||||
success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"获取提示模板超时({self.call_timeout}秒)",
|
||||
duration_ms=duration_ms
|
||||
success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||
)
|
||||
except Exception as e:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||
|
||||
async def check_health(self) -> bool:
|
||||
"""检查连接健康状态(心跳检测)
|
||||
|
|
@ -829,10 +796,7 @@ class MCPClientSession:
|
|||
|
||||
try:
|
||||
# 使用 list_tools 作为心跳检测
|
||||
await asyncio.wait_for(
|
||||
self._session.list_tools(),
|
||||
timeout=10.0
|
||||
)
|
||||
await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
|
||||
self.stats.record_heartbeat()
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
@ -849,12 +813,7 @@ class MCPClientSession:
|
|||
# v1.7.0: 断路器检查
|
||||
can_execute, reject_reason = self._circuit_breaker.can_execute()
|
||||
if not can_execute:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"⚡ {reject_reason}",
|
||||
circuit_broken=True
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True)
|
||||
|
||||
# 半开状态下增加试探计数
|
||||
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
|
||||
|
|
@ -870,8 +829,7 @@ class MCPClientSession:
|
|||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.call_tool(tool_name, arguments=arguments),
|
||||
timeout=self.call_timeout
|
||||
self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
|
||||
)
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
|
@ -879,9 +837,9 @@ class MCPClientSession:
|
|||
# 处理返回内容
|
||||
content_parts = []
|
||||
for content in result.content:
|
||||
if hasattr(content, 'text'):
|
||||
if hasattr(content, "text"):
|
||||
content_parts.append(content.text)
|
||||
elif hasattr(content, 'data'):
|
||||
elif hasattr(content, "data"):
|
||||
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
|
||||
else:
|
||||
content_parts.append(str(content))
|
||||
|
|
@ -896,7 +854,7 @@ class MCPClientSession:
|
|||
return MCPCallResult(
|
||||
success=True,
|
||||
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
|
||||
duration_ms=duration_ms
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -939,25 +897,25 @@ class MCPClientSession:
|
|||
self._supports_prompts = False # v1.2.0
|
||||
|
||||
try:
|
||||
if hasattr(self, '_session_context') and self._session_context:
|
||||
if hasattr(self, "_session_context") and self._session_context:
|
||||
await self._session_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
|
||||
|
||||
try:
|
||||
if hasattr(self, '_stdio_context') and self._stdio_context:
|
||||
if hasattr(self, "_stdio_context") and self._stdio_context:
|
||||
await self._stdio_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
|
||||
|
||||
try:
|
||||
if hasattr(self, '_http_context') and self._http_context:
|
||||
if hasattr(self, "_http_context") and self._http_context:
|
||||
await self._http_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
|
||||
|
||||
try:
|
||||
if hasattr(self, '_sse_context') and self._sse_context:
|
||||
if hasattr(self, "_sse_context") and self._sse_context:
|
||||
await self._sse_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
|
||||
|
|
@ -1082,7 +1040,9 @@ class MCPClientManager:
|
|||
return True
|
||||
|
||||
if attempt < retry_attempts:
|
||||
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
|
||||
logger.warning(
|
||||
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
|
||||
)
|
||||
await asyncio.sleep(retry_interval)
|
||||
|
||||
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
|
||||
|
|
@ -1213,11 +1173,7 @@ class MCPClientManager:
|
|||
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
|
||||
"""调用 MCP 工具"""
|
||||
if tool_key not in self._all_tools:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"工具 {tool_key} 不存在"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
|
||||
|
||||
tool_info, client = self._all_tools[tool_key]
|
||||
|
||||
|
|
@ -1273,11 +1229,7 @@ class MCPClientManager:
|
|||
# 如果指定了服务器
|
||||
if server_name:
|
||||
if server_name not in self._clients:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"服务器 {server_name} 不存在"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||
client = self._clients[server_name]
|
||||
return await client.read_resource(uri)
|
||||
|
||||
|
|
@ -1293,14 +1245,11 @@ class MCPClientManager:
|
|||
if result.success:
|
||||
return result
|
||||
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"未找到资源: {uri}"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
|
||||
|
||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None,
|
||||
server_name: Optional[str] = None) -> MCPCallResult:
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
|
||||
) -> MCPCallResult:
|
||||
"""v1.2.0: 获取提示模板内容
|
||||
|
||||
Args:
|
||||
|
|
@ -1311,11 +1260,7 @@ class MCPClientManager:
|
|||
# 如果指定了服务器
|
||||
if server_name:
|
||||
if server_name not in self._clients:
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"服务器 {server_name} 不存在"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||
client = self._clients[server_name]
|
||||
return await client.get_prompt(name, arguments)
|
||||
|
||||
|
|
@ -1324,11 +1269,7 @@ class MCPClientManager:
|
|||
if prompt_info.name == name:
|
||||
return await client.get_prompt(name, arguments)
|
||||
|
||||
return MCPCallResult(
|
||||
success=False,
|
||||
content=None,
|
||||
error=f"未找到提示模板: {name}"
|
||||
)
|
||||
return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
|
||||
|
||||
# ==================== 心跳检测 ====================
|
||||
|
||||
|
|
@ -1489,7 +1430,9 @@ class MCPClientManager:
|
|||
"global": {
|
||||
**self._global_stats,
|
||||
"uptime_seconds": round(uptime, 2),
|
||||
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0,
|
||||
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
|
||||
if uptime > 0
|
||||
else 0,
|
||||
},
|
||||
"servers": server_stats,
|
||||
"tools": tool_stats,
|
||||
|
|
|
|||
|
|
@ -123,9 +123,11 @@ logger = get_logger("mcp_bridge_plugin")
|
|||
# v1.4.0: 调用链路追踪
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRecord:
|
||||
"""工具调用记录"""
|
||||
|
||||
call_id: str
|
||||
timestamp: float
|
||||
tool_name: str
|
||||
|
|
@ -208,9 +210,11 @@ tool_call_tracer = ToolCallTracer()
|
|||
# v1.4.0: 工具调用缓存
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""缓存条目"""
|
||||
|
||||
tool_name: str
|
||||
args_hash: str
|
||||
result: str
|
||||
|
|
@ -347,6 +351,7 @@ tool_call_cache = ToolCallCache()
|
|||
# v1.4.0: 工具权限控制
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class PermissionChecker:
|
||||
"""工具权限检查器"""
|
||||
|
||||
|
|
@ -479,6 +484,7 @@ permission_checker = PermissionChecker()
|
|||
# 工具类型转换
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
||||
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
|
||||
type_mapping = {
|
||||
|
|
@ -492,7 +498,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
|||
return type_mapping.get(json_type, ToolParamType.STRING)
|
||||
|
||||
|
||||
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
||||
def parse_mcp_parameters(
|
||||
input_schema: Dict[str, Any],
|
||||
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
||||
"""解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式"""
|
||||
parameters = []
|
||||
|
||||
|
|
@ -534,6 +542,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
|
|||
# MCP 工具代理
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MCPToolProxy(BaseTool):
|
||||
"""MCP 工具代理基类"""
|
||||
|
||||
|
|
@ -576,10 +585,7 @@ class MCPToolProxy(BaseTool):
|
|||
# v1.4.0: 权限检查
|
||||
if not permission_checker.check(self.name, chat_id, user_id, is_group):
|
||||
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
|
||||
}
|
||||
return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
|
||||
|
||||
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
|
||||
|
||||
|
|
@ -749,11 +755,7 @@ class MCPToolProxy(BaseTool):
|
|||
return None
|
||||
|
||||
async def _call_post_process_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
settings: Dict[str, Any],
|
||||
server_config: Optional[Dict[str, Any]]
|
||||
self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
|
||||
) -> Optional[str]:
|
||||
"""调用 LLM 进行后处理"""
|
||||
from src.config.config import model_config
|
||||
|
|
@ -811,10 +813,7 @@ class MCPToolProxy(BaseTool):
|
|||
|
||||
|
||||
def create_mcp_tool_class(
|
||||
tool_key: str,
|
||||
tool_info: MCPToolInfo,
|
||||
tool_prefix: str,
|
||||
disabled: bool = False
|
||||
tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||
) -> Type[MCPToolProxy]:
|
||||
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
|
||||
parameters = parse_mcp_parameters(tool_info.input_schema)
|
||||
|
|
@ -837,7 +836,7 @@ def create_mcp_tool_class(
|
|||
"_mcp_tool_key": tool_key,
|
||||
"_mcp_original_name": tool_info.name,
|
||||
"_mcp_server_name": tool_info.server_name,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return tool_class
|
||||
|
|
@ -851,11 +850,7 @@ class MCPToolRegistry:
|
|||
self._tool_infos: Dict[str, ToolInfo] = {}
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
tool_key: str,
|
||||
tool_info: MCPToolInfo,
|
||||
tool_prefix: str,
|
||||
disabled: bool = False
|
||||
self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
|
||||
"""注册 MCP 工具"""
|
||||
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
|
||||
|
|
@ -902,6 +897,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
|
|||
# 内置工具
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MCPReadResourceTool(BaseTool):
|
||||
"""v1.2.0: MCP 资源读取工具"""
|
||||
|
||||
|
|
@ -973,6 +969,7 @@ class MCPGetPromptTool(BaseTool):
|
|||
# v1.8.0: 工具链代理工具
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ToolChainProxyBase(BaseTool):
|
||||
"""工具链代理基类"""
|
||||
|
||||
|
|
@ -1037,7 +1034,7 @@ def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBa
|
|||
"parameters": parameters,
|
||||
"available_for_llm": True,
|
||||
"_chain_name": chain.name,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return tool_class
|
||||
|
|
@ -1095,7 +1092,13 @@ class MCPStatusTool(BaseTool):
|
|||
name = "mcp_status"
|
||||
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
|
||||
parameters = [
|
||||
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"]),
|
||||
(
|
||||
"query_type",
|
||||
ToolParamType.STRING,
|
||||
"查询类型",
|
||||
False,
|
||||
["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"],
|
||||
),
|
||||
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
|
||||
]
|
||||
available_for_llm = True
|
||||
|
|
@ -1132,10 +1135,7 @@ class MCPStatusTool(BaseTool):
|
|||
if query_type in ("cache",):
|
||||
result_parts.append(self._format_cache())
|
||||
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
|
||||
}
|
||||
return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
|
||||
|
||||
def _format_status(self, server_name: Optional[str] = None) -> str:
|
||||
status = mcp_manager.get_status()
|
||||
|
|
@ -1147,14 +1147,14 @@ class MCPStatusTool(BaseTool):
|
|||
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
|
||||
|
||||
lines.append("\n🔌 服务器详情:")
|
||||
for name, info in status['servers'].items():
|
||||
for name, info in status["servers"].items():
|
||||
if server_name and name != server_name:
|
||||
continue
|
||||
status_icon = "✅" if info['connected'] else "❌"
|
||||
enabled_text = "" if info['enabled'] else " (已禁用)"
|
||||
status_icon = "✅" if info["connected"] else "❌"
|
||||
enabled_text = "" if info["enabled"] else " (已禁用)"
|
||||
lines.append(f" {status_icon} {name}{enabled_text}")
|
||||
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
|
||||
if info['consecutive_failures'] > 0:
|
||||
if info["consecutive_failures"] > 0:
|
||||
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
|
@ -1184,11 +1184,11 @@ class MCPStatusTool(BaseTool):
|
|||
stats = mcp_manager.get_all_stats()
|
||||
lines = ["📈 调用统计"]
|
||||
|
||||
g = stats['global']
|
||||
g = stats["global"]
|
||||
lines.append(f" 总调用次数: {g['total_tool_calls']}")
|
||||
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
|
||||
if g['total_tool_calls'] > 0:
|
||||
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100
|
||||
if g["total_tool_calls"] > 0:
|
||||
success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
|
||||
lines.append(f" 成功率: {success_rate:.1f}%")
|
||||
lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒")
|
||||
|
||||
|
|
@ -1277,7 +1277,7 @@ class MCPStatusTool(BaseTool):
|
|||
lines.append(f" 描述: {chain.description[:50]}...")
|
||||
lines.append(f" 步骤: {len(chain.steps)} 个")
|
||||
for i, step in enumerate(chain.steps[:3]):
|
||||
lines.append(f" {i+1}. {step.tool_name}")
|
||||
lines.append(f" {i + 1}. {step.tool_name}")
|
||||
if len(chain.steps) > 3:
|
||||
lines.append(f" ... 还有 {len(chain.steps) - 3} 个步骤")
|
||||
if chain.input_params:
|
||||
|
|
@ -1294,6 +1294,7 @@ class MCPStatusTool(BaseTool):
|
|||
# 命令处理
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MCPStatusCommand(BaseCommand):
|
||||
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
|
||||
|
||||
|
|
@ -1644,6 +1645,7 @@ class MCPStatusCommand(BaseCommand):
|
|||
_plugin_instance._load_tool_chains()
|
||||
chains = tool_chain_manager.get_all_chains()
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
registered = 0
|
||||
for name, chain in tool_chain_manager.get_enabled_chains().items():
|
||||
tool_name = f"chain_{name}".replace("-", "_").replace(".", "_")
|
||||
|
|
@ -1735,7 +1737,7 @@ class MCPStatusCommand(BaseCommand):
|
|||
lines.append(f"📋 执行步骤 ({len(chain.steps)} 个):")
|
||||
for i, step in enumerate(chain.steps):
|
||||
optional_tag = " (可选)" if step.optional else ""
|
||||
lines.append(f" {i+1}. {step.tool_name}{optional_tag}")
|
||||
lines.append(f" {i + 1}. {step.tool_name}{optional_tag}")
|
||||
if step.description:
|
||||
lines.append(f" {step.description}")
|
||||
if step.output_key:
|
||||
|
|
@ -1983,6 +1985,7 @@ class MCPImportCommand(BaseCommand):
|
|||
# 事件处理器
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MCPStartupHandler(BaseEventHandler):
|
||||
"""MCP 启动事件处理器"""
|
||||
|
||||
|
|
@ -2037,6 +2040,7 @@ class MCPStopHandler(BaseEventHandler):
|
|||
# 主插件类
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@register_plugin
|
||||
class MCPBridgePlugin(BasePlugin):
|
||||
"""MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
|
||||
|
|
@ -2505,7 +2509,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||
label="📋 工具链列表",
|
||||
input_type="textarea",
|
||||
rows=20,
|
||||
placeholder='''[
|
||||
placeholder="""[
|
||||
{
|
||||
"name": "search_and_detail",
|
||||
"description": "先搜索再获取详情",
|
||||
|
|
@ -2515,7 +2519,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||
{"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}}
|
||||
]
|
||||
}
|
||||
]''',
|
||||
]""",
|
||||
hint="每个工具链包含 name、description、input_params、steps",
|
||||
order=30,
|
||||
),
|
||||
|
|
@ -2653,9 +2657,9 @@ mcp_bing_*""",
|
|||
label="📜 高级权限规则(可选)",
|
||||
input_type="textarea",
|
||||
rows=10,
|
||||
placeholder='''[
|
||||
placeholder="""[
|
||||
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
||||
]''',
|
||||
]""",
|
||||
hint="格式: qq:ID:group/private/user,工具名支持通配符 *",
|
||||
order=10,
|
||||
),
|
||||
|
|
@ -2754,7 +2758,9 @@ mcp_bing_*""",
|
|||
value = match1.group(2)
|
||||
suffix = match1.group(3)
|
||||
# 将转义的换行符还原为实际换行符
|
||||
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
||||
unescaped = (
|
||||
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
||||
)
|
||||
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
|
||||
fixed_lines.append(fixed_line)
|
||||
modified = True
|
||||
|
|
@ -2948,11 +2954,13 @@ mcp_bing_*""",
|
|||
logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}")
|
||||
args_template = {}
|
||||
|
||||
steps.append({
|
||||
"tool_name": tool_name,
|
||||
"args_template": args_template,
|
||||
"output_key": output_key,
|
||||
})
|
||||
steps.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"args_template": args_template,
|
||||
"output_key": output_key,
|
||||
}
|
||||
)
|
||||
|
||||
if not steps:
|
||||
logger.warning("快速添加工具链: 没有有效的步骤")
|
||||
|
|
@ -3056,7 +3064,9 @@ mcp_bing_*""",
|
|||
if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict):
|
||||
self.config["tool_chains"] = {}
|
||||
self.config["tool_chains"]["chains_list"] = chains_json
|
||||
logger.info("检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)")
|
||||
logger.info(
|
||||
"检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)"
|
||||
)
|
||||
|
||||
chains_config = self.config.get("tool_chains", {})
|
||||
if not isinstance(chains_config, dict):
|
||||
|
|
@ -3153,10 +3163,7 @@ mcp_bing_*""",
|
|||
|
||||
# 应用过滤器
|
||||
if filter_patterns:
|
||||
matched = any(
|
||||
fnmatch.fnmatch(tool_name, p) or tool_name == p
|
||||
for p in filter_patterns
|
||||
)
|
||||
matched = any(fnmatch.fnmatch(tool_name, p) or tool_name == p for p in filter_patterns)
|
||||
|
||||
if filter_mode == "whitelist":
|
||||
# 白名单模式:只注册匹配的
|
||||
|
|
@ -3179,6 +3186,7 @@ mcp_bing_*""",
|
|||
return result.content or "(无返回内容)"
|
||||
else:
|
||||
return f"工具调用失败: {result.error}"
|
||||
|
||||
return execute_func
|
||||
|
||||
execute_func = make_execute_func(tool_key)
|
||||
|
|
@ -3207,7 +3215,9 @@ mcp_bing_*""",
|
|||
|
||||
return registered_count
|
||||
|
||||
def _update_react_status_display(self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]) -> None:
|
||||
def _update_react_status_display(
|
||||
self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]
|
||||
) -> None:
|
||||
"""更新 ReAct 工具状态显示"""
|
||||
if not registered_tools:
|
||||
status_text = "(未注册任何工具)"
|
||||
|
|
@ -3243,12 +3253,14 @@ mcp_bing_*""",
|
|||
description = param_info.get("description", f"参数 {param_name}")
|
||||
is_required = param_name in required
|
||||
|
||||
parameters.append({
|
||||
"name": param_name,
|
||||
"type": param_type,
|
||||
"description": description,
|
||||
"required": is_required,
|
||||
})
|
||||
parameters.append(
|
||||
{
|
||||
"name": param_name,
|
||||
"type": param_type,
|
||||
"description": description,
|
||||
"required": is_required,
|
||||
}
|
||||
)
|
||||
|
||||
return parameters
|
||||
|
||||
|
|
@ -3276,7 +3288,7 @@ mcp_bing_*""",
|
|||
for i, step in enumerate(chain.steps):
|
||||
opt = " (可选)" if step.optional else ""
|
||||
out = f" → {step.output_key}" if step.output_key else ""
|
||||
lines.append(f" {i+1}. {step.tool_name}{out}{opt}")
|
||||
lines.append(f" {i + 1}. {step.tool_name}{out}{opt}")
|
||||
lines.append("")
|
||||
|
||||
status_text = "\n".join(lines)
|
||||
|
|
@ -3295,6 +3307,7 @@ mcp_bing_*""",
|
|||
async def _async_connect_servers(self) -> None:
|
||||
"""异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)"""
|
||||
import asyncio
|
||||
|
||||
settings = self.config.get("settings", {})
|
||||
|
||||
servers_config = self._load_mcp_servers_config()
|
||||
|
|
@ -3380,10 +3393,7 @@ mcp_bing_*""",
|
|||
|
||||
# 并行执行所有连接
|
||||
start_time = time.time()
|
||||
results = await asyncio.gather(
|
||||
*[connect_single_server(cfg) for cfg in enabled_configs],
|
||||
return_exceptions=True
|
||||
)
|
||||
results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
|
||||
connect_duration = time.time() - start_time
|
||||
|
||||
# 统计连接结果
|
||||
|
|
@ -3404,15 +3414,14 @@ mcp_bing_*""",
|
|||
|
||||
# 注册所有工具
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
registered_count = 0
|
||||
|
||||
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
|
||||
tool_name = tool_key.replace("-", "_").replace(".", "_")
|
||||
is_disabled = tool_name in disabled_tools
|
||||
|
||||
info, tool_class = mcp_tool_registry.register_tool(
|
||||
tool_key, tool_info, tool_prefix, disabled=is_disabled
|
||||
)
|
||||
info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
|
||||
info.plugin_name = self.plugin_name
|
||||
|
||||
if component_registry.register_component(info, tool_class):
|
||||
|
|
@ -3433,7 +3442,9 @@ mcp_bing_*""",
|
|||
react_count = self._register_tools_to_react()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具")
|
||||
logger.info(
|
||||
f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具"
|
||||
)
|
||||
|
||||
# 更新状态显示
|
||||
self._update_status_display()
|
||||
|
|
@ -3508,7 +3519,9 @@ mcp_bing_*""",
|
|||
logger.info("检测到旧版 servers.list,已自动迁移为 Claude mcpServers(请在 WebUI 保存一次以固化)")
|
||||
|
||||
if not claude_json.strip():
|
||||
self._last_servers_config_error = "未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)"
|
||||
self._last_servers_config_error = (
|
||||
"未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)"
|
||||
)
|
||||
return []
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -22,15 +22,18 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
try:
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("mcp_tool_chain")
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("mcp_tool_chain")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolChainStep:
|
||||
"""工具链步骤"""
|
||||
|
||||
tool_name: str # 要调用的工具名(如 mcp_server_tool)
|
||||
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
|
||||
output_key: str = "" # 输出存储的键名,供后续步骤引用
|
||||
|
|
@ -60,6 +63,7 @@ class ToolChainStep:
|
|||
@dataclass
|
||||
class ToolChainDefinition:
|
||||
"""工具链定义"""
|
||||
|
||||
name: str # 工具链名称(将作为组合工具的名称)
|
||||
description: str # 工具链描述(供 LLM 理解)
|
||||
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
|
||||
|
|
@ -90,6 +94,7 @@ class ToolChainDefinition:
|
|||
@dataclass
|
||||
class ChainExecutionResult:
|
||||
"""工具链执行结果"""
|
||||
|
||||
success: bool
|
||||
final_output: str # 最终输出(最后一个步骤的结果)
|
||||
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
|
||||
|
|
@ -103,7 +108,7 @@ class ChainExecutionResult:
|
|||
status = "✅" if step.get("success") else "❌"
|
||||
tool = step.get("tool_name", "unknown")
|
||||
duration = step.get("duration_ms", 0)
|
||||
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)")
|
||||
lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
|
||||
if not step.get("success") and step.get("error"):
|
||||
lines.append(f" 错误: {step['error'][:50]}")
|
||||
return "\n".join(lines)
|
||||
|
|
@ -113,7 +118,7 @@ class ToolChainExecutor:
|
|||
"""工具链执行器"""
|
||||
|
||||
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
|
||||
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
|
||||
VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
|
||||
|
||||
def __init__(self, mcp_manager):
|
||||
self._mcp_manager = mcp_manager
|
||||
|
|
@ -201,7 +206,7 @@ class ToolChainExecutor:
|
|||
tool_key = self._resolve_tool_key(step.tool_name)
|
||||
if not tool_key:
|
||||
step_result["error"] = f"工具 {step.tool_name} 不存在"
|
||||
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在")
|
||||
logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
|
|
@ -209,13 +214,13 @@ class ToolChainExecutor:
|
|||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在",
|
||||
error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
step_results.append(step_result)
|
||||
continue
|
||||
|
||||
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}")
|
||||
logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
|
||||
|
||||
# 调用工具
|
||||
result = await self._mcp_manager.call_tool(tool_key, resolved_args)
|
||||
|
|
@ -236,10 +241,10 @@ class ToolChainExecutor:
|
|||
|
||||
final_output = content
|
||||
content_preview = content[:100] if content else "(空)"
|
||||
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...")
|
||||
logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
|
||||
else:
|
||||
step_result["error"] = result.error or "未知错误"
|
||||
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}")
|
||||
logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
|
|
@ -247,7 +252,7 @@ class ToolChainExecutor:
|
|||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}",
|
||||
error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
|
|
@ -255,7 +260,7 @@ class ToolChainExecutor:
|
|||
step_duration = (time.time() - step_start) * 1000
|
||||
step_result["duration_ms"] = step_duration
|
||||
step_result["error"] = str(e)
|
||||
logger.error(f"工具链步骤 {i+1} 异常: {e}")
|
||||
logger.error(f"工具链步骤 {i + 1} 异常: {e}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
|
|
@ -263,7 +268,7 @@ class ToolChainExecutor:
|
|||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}",
|
||||
error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
|
|
@ -295,10 +300,7 @@ class ToolChainExecutor:
|
|||
elif isinstance(value, dict):
|
||||
resolved[key] = self._resolve_args(value, context)
|
||||
elif isinstance(value, list):
|
||||
resolved[key] = [
|
||||
self._substitute_vars(v, context) if isinstance(v, str) else v
|
||||
for v in value
|
||||
]
|
||||
resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
|
||||
else:
|
||||
resolved[key] = value
|
||||
|
||||
|
|
@ -306,6 +308,7 @@ class ToolChainExecutor:
|
|||
|
||||
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
|
||||
"""替换字符串中的变量"""
|
||||
|
||||
def replacer(match):
|
||||
var_path = match.group(1)
|
||||
return self._get_var_value(var_path, context)
|
||||
|
|
@ -552,7 +555,7 @@ class ToolChainManager:
|
|||
try:
|
||||
chain = ToolChainDefinition.from_dict(item)
|
||||
if not chain.name:
|
||||
errors.append(f"第 {i+1} 个工具链缺少名称")
|
||||
errors.append(f"第 {i + 1} 个工具链缺少名称")
|
||||
continue
|
||||
if not chain.steps:
|
||||
errors.append(f"工具链 {chain.name} 没有步骤")
|
||||
|
|
@ -561,7 +564,7 @@ class ToolChainManager:
|
|||
self.add_chain(chain)
|
||||
loaded += 1
|
||||
except Exception as e:
|
||||
errors.append(f"第 {i+1} 个工具链解析失败: {e}")
|
||||
errors.append(f"第 {i + 1} 个工具链解析失败: {e}")
|
||||
|
||||
return loaded, errors
|
||||
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ class TestCommand(BaseCommand):
|
|||
chat_stream=self.message.chat_stream,
|
||||
reply_reason=reply_reason,
|
||||
enable_chinese_typo=False,
|
||||
extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"",
|
||||
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
|
||||
)
|
||||
if result_status:
|
||||
# 发送生成的回复
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ def patch_attrdoc_post_init():
|
|||
|
||||
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
||||
|
||||
|
||||
class SimpleClass(ConfigBase):
|
||||
a: int = 1
|
||||
b: str = "test"
|
||||
|
|
@ -282,7 +283,7 @@ class TestConfigBase:
|
|||
True,
|
||||
"ConfigBase is not Hashable",
|
||||
id="listset-validation-set-configbase-element_reject",
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
||||
|
|
@ -340,7 +341,7 @@ class TestConfigBase:
|
|||
False,
|
||||
None,
|
||||
id="dict-validation-happy-configbase-value",
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
||||
|
|
@ -353,13 +354,11 @@ class TestConfigBase:
|
|||
field_name = "mapping"
|
||||
|
||||
if expect_error:
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
|
||||
# Act
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import importlib
|
|||
import pytest
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
import asyncio
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
|
|
@ -71,6 +70,7 @@ class DummyLLMRequest:
|
|||
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
|
||||
return ("dummy description", {})
|
||||
|
||||
|
||||
class DummySelect:
|
||||
def __init__(self, *a, **k):
|
||||
pass
|
||||
|
|
@ -81,6 +81,7 @@ class DummySelect:
|
|||
def limit(self, n):
|
||||
return self
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_external_dependencies(monkeypatch):
|
||||
# Provide dummy implementations as modules so that importing image_manager is safe
|
||||
|
|
@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None):
|
|||
if tmp_path is not None:
|
||||
tmpdir = Path(tmp_path)
|
||||
tmpdir.mkdir(parents=True, exist_ok=True)
|
||||
setattr(mod, "IMAGE_DIR", tmpdir)
|
||||
mod.IMAGE_DIR = tmpdir
|
||||
except Exception:
|
||||
pass
|
||||
return mod
|
||||
|
|
@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
|
|||
|
||||
# cleanup should run without error
|
||||
mgr.cleanup_invalid_descriptions_in_db()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import pytest
|
||||
|
||||
from src.config.official_configs import ChatConfig
|
||||
from src.config.config import Config
|
||||
from src.webui.config_schema import ConfigSchemaGenerator
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Expression routes pytest tests"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
|
@ -12,7 +11,6 @@ from sqlalchemy import text
|
|||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import get_db_session
|
||||
|
||||
|
||||
def create_test_app() -> FastAPI:
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def analyze_single_file(file_path: str) -> Dict:
|
|||
stats["date_range"] = {
|
||||
"start": min_date.isoformat(),
|
||||
"end": max_date.isoformat(),
|
||||
"duration_days": (max_date - min_date).days + 1
|
||||
"duration_days": (max_date - min_date).days + 1,
|
||||
}
|
||||
|
||||
# 检查字段存在性
|
||||
|
|
@ -151,8 +151,8 @@ def print_file_stats(stats: Dict, index: int = None):
|
|||
print(f" 文件中的 total_count: {stats['total_count']}")
|
||||
print(f" 实际记录数: {stats['actual_count']}")
|
||||
|
||||
if stats['total_count'] != stats['actual_count']:
|
||||
diff = stats['total_count'] - stats['actual_count']
|
||||
if stats["total_count"] != stats["actual_count"]:
|
||||
diff = stats["total_count"] - stats["actual_count"]
|
||||
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
|
||||
|
||||
print("\n【评估结果统计】")
|
||||
|
|
@ -161,21 +161,21 @@ def print_file_stats(stats: Dict, index: int = None):
|
|||
|
||||
print("\n【唯一性统计】")
|
||||
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']} 条")
|
||||
if stats['actual_count'] > 0:
|
||||
duplicate_count = stats['actual_count'] - stats['unique_pairs']
|
||||
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["actual_count"] > 0:
|
||||
duplicate_count = stats["actual_count"] - stats["unique_pairs"]
|
||||
duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
|
||||
|
||||
print("\n【评估者统计】")
|
||||
if stats['evaluators']:
|
||||
for evaluator, count in stats['evaluators'].most_common():
|
||||
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["evaluators"]:
|
||||
for evaluator, count in stats["evaluators"].most_common():
|
||||
rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
|
||||
else:
|
||||
print(" 无评估者信息")
|
||||
|
||||
print("\n【时间统计】")
|
||||
if stats['date_range']:
|
||||
if stats["date_range"]:
|
||||
print(f" 最早评估时间: {stats['date_range']['start']}")
|
||||
print(f" 最晚评估时间: {stats['date_range']['end']}")
|
||||
print(f" 评估时间跨度: {stats['date_range']['duration_days']} 天")
|
||||
|
|
@ -185,8 +185,8 @@ def print_file_stats(stats: Dict, index: int = None):
|
|||
print("\n【字段统计】")
|
||||
print(f" 包含 expression_id: {'是' if stats['has_expression_id'] else '否'}")
|
||||
print(f" 包含 reason: {'是' if stats['has_reason'] else '否'}")
|
||||
if stats['has_reason']:
|
||||
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["has_reason"]:
|
||||
rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
|
||||
|
||||
|
||||
|
|
@ -215,15 +215,15 @@ def print_summary(all_stats: List[Dict]):
|
|||
return
|
||||
|
||||
# 汇总记录统计
|
||||
total_records = sum(s['actual_count'] for s in valid_files)
|
||||
total_suitable = sum(s['suitable_count'] for s in valid_files)
|
||||
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
|
||||
total_records = sum(s["actual_count"] for s in valid_files)
|
||||
total_suitable = sum(s["suitable_count"] for s in valid_files)
|
||||
total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
|
||||
total_unique_pairs = set()
|
||||
|
||||
# 收集所有唯一的(situation, style)对
|
||||
for stats in valid_files:
|
||||
try:
|
||||
with open(stats['file_path'], "r", encoding="utf-8") as f:
|
||||
with open(stats["file_path"], "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
results = data.get("manual_results", [])
|
||||
for r in results:
|
||||
|
|
@ -234,8 +234,16 @@ def print_summary(all_stats: List[Dict]):
|
|||
|
||||
print("\n【记录汇总】")
|
||||
print(f" 总记录数: {total_records:,} 条")
|
||||
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
|
||||
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
|
||||
print(
|
||||
f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
|
||||
if total_records > 0
|
||||
else " 通过: 0 条"
|
||||
)
|
||||
print(
|
||||
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
|
||||
if total_records > 0
|
||||
else " 不通过: 0 条"
|
||||
)
|
||||
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,} 条")
|
||||
|
||||
if total_records > 0:
|
||||
|
|
@ -246,7 +254,7 @@ def print_summary(all_stats: List[Dict]):
|
|||
# 汇总评估者统计
|
||||
all_evaluators = Counter()
|
||||
for stats in valid_files:
|
||||
all_evaluators.update(stats['evaluators'])
|
||||
all_evaluators.update(stats["evaluators"])
|
||||
|
||||
print("\n【评估者汇总】")
|
||||
if all_evaluators:
|
||||
|
|
@ -259,7 +267,7 @@ def print_summary(all_stats: List[Dict]):
|
|||
# 汇总时间范围
|
||||
all_dates = []
|
||||
for stats in valid_files:
|
||||
all_dates.extend(stats['evaluation_dates'])
|
||||
all_dates.extend(stats["evaluation_dates"])
|
||||
|
||||
if all_dates:
|
||||
min_date = min(all_dates)
|
||||
|
|
@ -270,7 +278,7 @@ def print_summary(all_stats: List[Dict]):
|
|||
print(f" 总时间跨度: {(max_date - min_date).days + 1} 天")
|
||||
|
||||
# 文件大小汇总
|
||||
total_size = sum(s['file_size'] for s in valid_files)
|
||||
total_size = sum(s["file_size"] for s in valid_files)
|
||||
avg_size = total_size / len(valid_files) if valid_files else 0
|
||||
print("\n【文件大小汇总】")
|
||||
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
|
||||
|
|
@ -318,5 +326,3 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -171,7 +171,9 @@ def main():
|
|||
sys.exit(1)
|
||||
|
||||
if not args.raw_index:
|
||||
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
|
||||
logger.info(
|
||||
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# 解析索引列表(1-based)
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def save_results(evaluation_results: List[Dict]):
|
|||
data = {
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"total_count": len(evaluation_results),
|
||||
"evaluation_results": evaluation_results
|
||||
"evaluation_results": evaluation_results,
|
||||
}
|
||||
|
||||
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
|
||||
|
|
@ -84,9 +84,7 @@ def save_results(evaluation_results: List[Dict]):
|
|||
print(f"\n✗ 保存评估结果失败: {e}")
|
||||
|
||||
|
||||
def select_expressions_for_evaluation(
|
||||
evaluated_pairs: Set[Tuple[str, str]] = None
|
||||
) -> List[Expression]:
|
||||
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
|
||||
"""
|
||||
选择用于评估的表达方式
|
||||
选择所有count>1的项目,然后选择两倍数量的count=1的项目
|
||||
|
|
@ -109,10 +107,7 @@ def select_expressions_for_evaluation(
|
|||
return []
|
||||
|
||||
# 过滤出未评估的项目
|
||||
unevaluated = [
|
||||
expr for expr in all_expressions
|
||||
if (expr.situation, expr.style) not in evaluated_pairs
|
||||
]
|
||||
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
|
||||
|
||||
if not unevaluated:
|
||||
logger.warning("所有项目都已评估完成")
|
||||
|
|
@ -132,7 +127,9 @@ def select_expressions_for_evaluation(
|
|||
count_eq1_needed = count_gt1_count * 2
|
||||
|
||||
if len(count_eq1) < count_eq1_needed:
|
||||
logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条")
|
||||
logger.warning(
|
||||
f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条"
|
||||
)
|
||||
count_eq1_needed = len(count_eq1)
|
||||
|
||||
# 随机选择count=1的项目
|
||||
|
|
@ -141,13 +138,16 @@ def select_expressions_for_evaluation(
|
|||
selected = selected_count_gt1 + selected_count_eq1
|
||||
random.shuffle(selected) # 打乱顺序
|
||||
|
||||
logger.info(f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)")
|
||||
logger.info(
|
||||
f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)"
|
||||
)
|
||||
|
||||
return selected
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选择表达方式失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
|
|
@ -202,9 +202,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
|||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.6,
|
||||
max_tokens=1024
|
||||
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||
)
|
||||
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
|
@ -241,7 +239,9 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
|
|||
Returns:
|
||||
评估结果字典
|
||||
"""
|
||||
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
|
||||
logger.info(
|
||||
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
|
||||
)
|
||||
|
||||
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
|
||||
|
||||
|
|
@ -258,7 +258,7 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
|
|||
"reason": reason,
|
||||
"error": error,
|
||||
"evaluator": "llm",
|
||||
"evaluated_at": datetime.now().isoformat()
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -302,7 +302,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
print(f"Count = {count}:")
|
||||
print(f" 总数: {total}")
|
||||
print(f" 通过: {suitable} ({pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
# 比较count=1和count>1
|
||||
|
|
@ -340,12 +340,12 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
print("Count = 1:")
|
||||
print(f" 总数: {eq1_total}")
|
||||
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
|
||||
print()
|
||||
print("Count > 1:")
|
||||
print(f" 总数: {gt1_total}")
|
||||
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
# 进行卡方检验(简化版,使用2x2列联表)
|
||||
|
|
@ -427,7 +427,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
"count_groups": {str(k): v for k, v in count_groups.items()},
|
||||
"count_eq1": count_eq1_group,
|
||||
"count_gt1": count_gt1_group,
|
||||
"total_evaluated": len(evaluation_results)
|
||||
"total_evaluated": len(evaluation_results),
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -466,8 +466,7 @@ async def main():
|
|||
try:
|
||||
all_expressions = list(Expression.select())
|
||||
unevaluated_count_gt1 = [
|
||||
expr for expr in all_expressions
|
||||
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
]
|
||||
has_unevaluated = len(unevaluated_count_gt1) > 0
|
||||
except Exception as e:
|
||||
|
|
@ -485,21 +484,20 @@ async def main():
|
|||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_count_analysis_llm"
|
||||
request_type="expression_evaluator_count_analysis_llm",
|
||||
)
|
||||
print("✓ LLM实例创建成功\n")
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
print(f"\n✗ 创建LLM实例失败: {e}")
|
||||
db.close()
|
||||
return
|
||||
|
||||
# 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目)
|
||||
expressions = select_expressions_for_evaluation(
|
||||
evaluated_pairs=evaluated_pairs
|
||||
)
|
||||
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
|
||||
|
||||
if not expressions:
|
||||
print("\n没有可评估的项目")
|
||||
|
|
@ -518,7 +516,7 @@ async def main():
|
|||
llm_result = await llm_evaluate_expression(expression, llm)
|
||||
|
||||
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
|
||||
if llm_result.get('error'):
|
||||
if llm_result.get("error"):
|
||||
print(f" 错误: {llm_result['error']}")
|
||||
print()
|
||||
|
||||
|
|
@ -553,4 +551,3 @@ async def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
|
|||
|
|
@ -140,9 +140,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
|||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.6,
|
||||
max_tokens=1024
|
||||
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||
)
|
||||
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
|
@ -152,6 +150,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
|||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError as e:
|
||||
import re
|
||||
|
||||
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
||||
if json_match:
|
||||
evaluation = json.loads(json_match.group())
|
||||
|
|
@ -196,7 +195,7 @@ async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -
|
|||
"suitable": suitable,
|
||||
"reason": reason,
|
||||
"error": error,
|
||||
"evaluator": "llm"
|
||||
"evaluator": "llm",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -244,10 +243,16 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
|||
false_negatives += 1
|
||||
|
||||
accuracy = (matched / total * 100) if total > 0 else 0
|
||||
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||
precision = (
|
||||
(true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||
)
|
||||
recall = (
|
||||
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||
)
|
||||
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
|
||||
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||
specificity = (
|
||||
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||
)
|
||||
|
||||
# 计算人工效标的不合适率
|
||||
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
|
||||
|
|
@ -283,7 +288,7 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
|||
"specificity": specificity,
|
||||
"manual_unsuitable_rate": manual_unsuitable_rate,
|
||||
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
|
||||
"rate_difference": rate_difference
|
||||
"rate_difference": rate_difference,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -334,13 +339,11 @@ async def main(count: int | None = None):
|
|||
# 2. 创建LLM实例并评估
|
||||
print("\n步骤2: 创建LLM实例")
|
||||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_llm"
|
||||
)
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
|
|
@ -348,11 +351,7 @@ async def main(count: int | None = None):
|
|||
llm_results = []
|
||||
for i, manual_result in enumerate(valid_manual_results, 1):
|
||||
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
|
||||
llm_results.append(await evaluate_expression_llm(
|
||||
manual_result["situation"],
|
||||
manual_result["style"],
|
||||
llm
|
||||
))
|
||||
llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# 5. 输出FP和FN项目(在评估结果之前)
|
||||
|
|
@ -372,14 +371,16 @@ async def main(count: int | None = None):
|
|||
|
||||
# 人工评估不通过,但LLM评估通过(FP情况)
|
||||
if not manual_result["suitable"] and llm_result["suitable"]:
|
||||
fp_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
fp_items.append(
|
||||
{
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error"),
|
||||
}
|
||||
)
|
||||
|
||||
if fp_items:
|
||||
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
|
||||
|
|
@ -389,7 +390,7 @@ async def main(count: int | None = None):
|
|||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 不通过 ❌")
|
||||
print("LLM评估: 通过 ✅ (误判)")
|
||||
if item.get('llm_error'):
|
||||
if item.get("llm_error"):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
|
|
@ -410,14 +411,16 @@ async def main(count: int | None = None):
|
|||
|
||||
# 人工评估通过,但LLM评估不通过(FN情况)
|
||||
if manual_result["suitable"] and not llm_result["suitable"]:
|
||||
fn_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
fn_items.append(
|
||||
{
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error"),
|
||||
}
|
||||
)
|
||||
|
||||
if fn_items:
|
||||
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
|
||||
|
|
@ -427,7 +430,7 @@ async def main(count: int | None = None):
|
|||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 通过 ✅")
|
||||
print("LLM评估: 不通过 ❌ (误删)")
|
||||
if item.get('llm_error'):
|
||||
if item.get("llm_error"):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
|
|
@ -447,13 +450,21 @@ async def main(count: int | None = None):
|
|||
print()
|
||||
# print(" 【核心能力指标】")
|
||||
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
|
||||
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
|
||||
print(
|
||||
f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个"
|
||||
)
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
|
||||
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
|
||||
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
|
||||
print(
|
||||
f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个"
|
||||
)
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(" 【其他指标】")
|
||||
|
|
@ -464,12 +475,18 @@ async def main(count: int | None = None):
|
|||
print()
|
||||
print(" 【不合适率分析】")
|
||||
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
|
||||
print(
|
||||
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
|
||||
)
|
||||
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
|
||||
print()
|
||||
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print(
|
||||
f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
|
||||
)
|
||||
print()
|
||||
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
|
||||
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
|
|
@ -486,11 +503,12 @@ async def main(count: int | None = None):
|
|||
try:
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump({
|
||||
"manual_results": valid_manual_results,
|
||||
"llm_results": llm_results,
|
||||
"comparison": comparison
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
json.dump(
|
||||
{"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
logger.info(f"\n评估结果已保存到: {output_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存结果到文件失败: {e}")
|
||||
|
|
@ -509,15 +527,9 @@ if __name__ == "__main__":
|
|||
python evaluate_expressions_llm_v6.py # 使用全部数据
|
||||
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
|
||||
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--count",
|
||||
type=int,
|
||||
default=None,
|
||||
help="随机选取的数据条数(默认:使用全部数据)"
|
||||
""",
|
||||
)
|
||||
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
|
||||
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(count=args.count))
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def save_results(manual_results: List[Dict]):
|
|||
data = {
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"total_count": len(manual_results),
|
||||
"manual_results": manual_results
|
||||
"manual_results": manual_results,
|
||||
}
|
||||
|
||||
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
|
||||
|
|
@ -98,10 +98,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
|
|||
return []
|
||||
|
||||
# 过滤出未评估的项目:匹配 situation 和 style 均一致
|
||||
unevaluated = [
|
||||
expr for expr in all_expressions
|
||||
if (expr.situation, expr.style) not in evaluated_pairs
|
||||
]
|
||||
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
|
||||
|
||||
if not unevaluated:
|
||||
logger.info("所有项目都已评估完成")
|
||||
|
|
@ -120,6 +117,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
|
|||
except Exception as e:
|
||||
logger.error(f"获取未评估表达方式失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
|
|
@ -150,18 +148,18 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
|
|||
while True:
|
||||
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
|
||||
|
||||
if user_input in ['q', 'quit']:
|
||||
if user_input in ["q", "quit"]:
|
||||
print("退出评估")
|
||||
return None
|
||||
|
||||
if user_input in ['s', 'skip']:
|
||||
if user_input in ["s", "skip"]:
|
||||
print("跳过当前项目")
|
||||
return "skip"
|
||||
|
||||
if user_input in ['y', 'yes', '1', '是', '通过']:
|
||||
if user_input in ["y", "yes", "1", "是", "通过"]:
|
||||
suitable = True
|
||||
break
|
||||
elif user_input in ['n', 'no', '0', '否', '不通过']:
|
||||
elif user_input in ["n", "no", "0", "否", "不通过"]:
|
||||
suitable = False
|
||||
break
|
||||
else:
|
||||
|
|
@ -173,7 +171,7 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
|
|||
"suitable": suitable,
|
||||
"reason": None,
|
||||
"evaluator": "manual",
|
||||
"evaluated_at": datetime.now().isoformat()
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
|
||||
|
|
@ -257,9 +255,9 @@ def main():
|
|||
# 询问是否继续
|
||||
while True:
|
||||
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
|
||||
if continue_input in ['y', 'yes', '1', '是', '继续']:
|
||||
if continue_input in ["y", "yes", "1", "是", "继续"]:
|
||||
break
|
||||
elif continue_input in ['n', 'no', '0', '否', '退出']:
|
||||
elif continue_input in ["n", "no", "0", "否", "退出"]:
|
||||
print("\n评估结束")
|
||||
return
|
||||
else:
|
||||
|
|
@ -275,4 +273,3 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
|
|
@ -134,9 +134,7 @@ def handle_import_openie(
|
|||
# 在非交互模式下,不再询问用户,而是直接报错终止
|
||||
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
|
||||
)
|
||||
logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
|
||||
sys.exit(1)
|
||||
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
||||
user_choice = input().strip().lower()
|
||||
|
|
@ -189,9 +187,7 @@ def handle_import_openie(
|
|||
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
|
||||
# 新增确认提示
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
|
||||
)
|
||||
logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
|
||||
else:
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
|
|
@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
|
|||
def main(argv: Optional[list[str]] = None) -> None:
|
||||
"""主函数 - 解析参数并运行异步主流程。"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
|
||||
"将其导入到 LPMM 的向量库与知识图中。"
|
||||
)
|
||||
description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-interactive",
|
||||
|
|
|
|||
|
|
@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
|
|||
ensure_dirs() # 确保目录存在
|
||||
# 新增用户确认提示
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
|
||||
)
|
||||
logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
|
||||
else:
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
|
|
|
|||
|
|
@ -68,4 +68,3 @@ def main() -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ except ImportError as e:
|
|||
|
||||
logger = get_logger("lpmm_interactive_manager")
|
||||
|
||||
|
||||
async def interactive_add():
|
||||
"""交互式导入知识"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -68,6 +69,7 @@ async def interactive_add():
|
|||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"add_content 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_delete():
|
||||
"""交互式删除知识"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -108,8 +110,12 @@ async def interactive_delete():
|
|||
return
|
||||
|
||||
print("-" * 40)
|
||||
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower()
|
||||
if confirm != 'y':
|
||||
confirm = (
|
||||
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
if confirm != "y":
|
||||
print("\n[!] 已取消删除操作。")
|
||||
return
|
||||
|
||||
|
|
@ -129,6 +135,7 @@ async def interactive_delete():
|
|||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"delete 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_clear():
|
||||
"""交互式清空知识库"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -165,16 +172,21 @@ async def interactive_clear():
|
|||
before = stats.get("before", {})
|
||||
after = stats.get("after", {})
|
||||
print("\n[统计信息]")
|
||||
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}")
|
||||
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}")
|
||||
print(
|
||||
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
|
||||
)
|
||||
print(
|
||||
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
|
||||
)
|
||||
else:
|
||||
print(f"\n[×] 失败:{result['message']}")
|
||||
except Exception as e:
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"clear_all 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_search():
|
||||
"""交互式查询知识"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -224,6 +236,7 @@ async def interactive_search():
|
|||
print(f"\n[×] 查询失败: {e}")
|
||||
logger.error(f"查询异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主循环"""
|
||||
while True:
|
||||
|
|
@ -253,6 +266,7 @@ async def main():
|
|||
else:
|
||||
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 运行主循环
|
||||
|
|
@ -262,4 +276,3 @@ if __name__ == "__main__":
|
|||
except Exception as e:
|
||||
print(f"\n[!] 程序运行出错: {e}")
|
||||
logger.error(f"Main loop 异常: {e}", exc_info=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
|
|||
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
|
||||
txt_files = list(raw_dir.glob("*.txt"))
|
||||
if not txt_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
|
||||
"info_extraction 可能立即退出或无数据可处理。"
|
||||
)
|
||||
msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,info_extraction 可能立即退出或无数据可处理。"
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
|
||||
)
|
||||
logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
|
||||
return False
|
||||
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
|
|
@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
|
|||
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
|
||||
json_files = list(openie_dir.glob("*.json"))
|
||||
if not json_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
|
||||
"import_openie 可能会因为找不到批次而失败。"
|
||||
)
|
||||
msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,import_openie 可能会因为找不到批次而失败。"
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
|
||||
)
|
||||
logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
|
||||
return False
|
||||
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
|
|
@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
|
|||
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
|
||||
try:
|
||||
if not getattr(global_config.lpmm_knowledge, "enable", False):
|
||||
print(
|
||||
"[WARN] 当前配置 lpmm_knowledge.enable = false,"
|
||||
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
|
||||
)
|
||||
print("[WARN] 当前配置 lpmm_knowledge.enable = false,刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
|
||||
except Exception:
|
||||
# 配置异常时不阻断主流程,仅忽略提示
|
||||
pass
|
||||
|
|
@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
|||
if action == "prepare_raw":
|
||||
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||
elif action == "info_extract":
|
||||
if not _check_before_info_extract("--non-interactive" in extra_args):
|
||||
print("已根据用户选择,取消执行信息提取。")
|
||||
|
|
@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
|||
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
|
||||
logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||
non_interactive = "--non-interactive" in extra_args
|
||||
if not _check_before_info_extract(non_interactive):
|
||||
print("已根据用户选择,取消 full_import(信息提取阶段被取消)。")
|
||||
|
|
@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
|
|||
)
|
||||
|
||||
# 快速选项:按推荐方式清理所有相关实体/关系
|
||||
quick_all = input(
|
||||
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
|
||||
).strip().lower()
|
||||
quick_all = (
|
||||
input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
|
||||
)
|
||||
if quick_all in ("", "y", "yes"):
|
||||
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
|
||||
else:
|
||||
|
|
@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
|
|||
|
||||
def _interactive_build_batch_inspect_args() -> List[str]:
|
||||
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
|
||||
path = _interactive_choose_openie_file(
|
||||
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
|
||||
)
|
||||
path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
|
||||
if not path:
|
||||
return []
|
||||
return ["--openie-file", path]
|
||||
|
|
@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
|
|||
|
||||
def _interactive_build_test_args() -> List[str]:
|
||||
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
|
||||
print(
|
||||
"\n[TEST] 你可以:\n"
|
||||
"- 直接回车使用内置的默认测试用例;\n"
|
||||
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
|
||||
)
|
||||
print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
|
||||
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
|
||||
if not query:
|
||||
return []
|
||||
|
|
@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
|
|||
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
|
||||
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
|
||||
|
||||
new_dim = input(
|
||||
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
|
||||
).strip()
|
||||
new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
|
||||
if new_dim and not new_dim.isdigit():
|
||||
print("输入的维度不是纯数字,已取消操作。")
|
||||
return
|
||||
|
|
@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from maim_message import UserInfo, GroupInfo
|
|||
|
||||
logger = get_logger("test_memory_retrieval")
|
||||
|
||||
|
||||
# 使用 importlib 动态导入,避免循环导入问题
|
||||
def _import_memory_retrieval():
|
||||
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
||||
|
|
@ -44,7 +45,7 @@ def _import_memory_retrieval():
|
|||
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
||||
if prompt_already_init and module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
||||
if hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
return (
|
||||
existing_module.init_memory_retrieval_prompt,
|
||||
existing_module._react_agent_solve_question,
|
||||
|
|
@ -54,14 +55,14 @@ def _import_memory_retrieval():
|
|||
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
|
||||
if module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
||||
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
# 模块部分初始化,移除它
|
||||
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
|
||||
del sys.modules[module_name]
|
||||
# 清理可能相关的部分初始化模块
|
||||
keys_to_remove = []
|
||||
for key in sys.modules.keys():
|
||||
if key.startswith('src.memory_system.') and key != 'src.memory_system':
|
||||
if key.startswith("src.memory_system.") and key != "src.memory_system":
|
||||
keys_to_remove.append(key)
|
||||
for key in keys_to_remove:
|
||||
try:
|
||||
|
|
@ -75,6 +76,7 @@ def _import_memory_retrieval():
|
|||
# 先导入可能触发循环导入的模块,让它们完成初始化
|
||||
import src.config.config
|
||||
import src.chat.utils.prompt_builder
|
||||
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
# 如果它们已经导入,就确保它们完全初始化
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
|
|
@ -253,7 +255,7 @@ async def test_memory_retrieval(
|
|||
包含测试结果的字典
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"[测试] 记忆检索测试")
|
||||
print("[测试] 记忆检索测试")
|
||||
print(f"[问题] {question}")
|
||||
print("=" * 80)
|
||||
|
||||
|
|
@ -267,6 +269,7 @@ async def test_memory_retrieval(
|
|||
|
||||
# 检查 prompt 是否已经初始化,避免重复初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
|
||||
init_memory_retrieval_prompt()
|
||||
else:
|
||||
|
|
@ -284,7 +287,7 @@ async def test_memory_retrieval(
|
|||
|
||||
timeout = global_config.memory.agent_timeout_seconds
|
||||
|
||||
print(f"\n[配置]")
|
||||
print("\n[配置]")
|
||||
print(f" 最大迭代次数: {max_iterations}")
|
||||
print(f" 超时时间: {timeout}秒")
|
||||
print(f" 聊天ID: {chat_id}")
|
||||
|
|
@ -321,26 +324,26 @@ async def test_memory_retrieval(
|
|||
|
||||
# 输出结果
|
||||
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||
print(f"\n[结果]")
|
||||
print("\n[结果]")
|
||||
print(f" 是否找到答案: {'是' if found_answer else '否'}")
|
||||
if found_answer and answer:
|
||||
print(f" 答案: {answer}")
|
||||
else:
|
||||
print(f" 答案: (未找到答案)")
|
||||
print(" 答案: (未找到答案)")
|
||||
print(f" 是否超时: {'是' if is_timeout else '否'}")
|
||||
print(f" 迭代次数: {len(thinking_steps)}")
|
||||
print(f" 总耗时: {elapsed_time:.2f}秒")
|
||||
|
||||
print(f"\n[Token使用情况]")
|
||||
print("\n[Token使用情况]")
|
||||
print(f" 总请求数: {token_usage['request_count']}")
|
||||
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
|
||||
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
|
||||
print(f" 总Tokens: {token_usage['total_tokens']:,}")
|
||||
print(f" 总成本: ${token_usage['total_cost']:.6f}")
|
||||
|
||||
if token_usage['model_usage']:
|
||||
print(f"\n[按模型统计]")
|
||||
for model_name, usage in token_usage['model_usage'].items():
|
||||
if token_usage["model_usage"]:
|
||||
print("\n[按模型统计]")
|
||||
for model_name, usage in token_usage["model_usage"].items():
|
||||
print(f" {model_name}:")
|
||||
print(f" 请求数: {usage['request_count']}")
|
||||
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
|
||||
|
|
@ -348,7 +351,7 @@ async def test_memory_retrieval(
|
|||
print(f" 总Tokens: {usage['total_tokens']:,}")
|
||||
print(f" 成本: ${usage['cost']:.6f}")
|
||||
|
||||
print(f"\n[迭代详情]")
|
||||
print("\n[迭代详情]")
|
||||
print(format_thinking_steps(thinking_steps))
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
|
@ -444,4 +447,3 @@ def main() -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
|
|
@ -455,6 +455,7 @@ class ExpressionSelector:
|
|||
expr_obj.save()
|
||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
|
|||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
class JargonExplainer:
|
||||
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import time
|
|||
import asyncio
|
||||
from typing import List, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
|
@ -119,9 +118,7 @@ class MessageRecorder:
|
|||
|
||||
# 触发 expression_learner 和 jargon_miner 的处理
|
||||
if self.enable_expression_learning:
|
||||
asyncio.create_task(
|
||||
self._trigger_expression_learning(messages)
|
||||
)
|
||||
asyncio.create_task(self._trigger_expression_learning(messages))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||
|
|
@ -130,9 +127,7 @@ class MessageRecorder:
|
|||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def _trigger_expression_learning(
|
||||
self, messages: List[Any]
|
||||
) -> None:
|
||||
async def _trigger_expression_learning(self, messages: List[Any]) -> None:
|
||||
"""
|
||||
触发 expression 学习,使用指定的消息列表
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import time
|
||||
from typing import Tuple, Optional, Dict, Any # 增加了 Optional
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
|
|
@ -170,13 +170,10 @@ class ActionPlanner:
|
|||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
|
||||
)
|
||||
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ class Conversation:
|
|||
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
|
||||
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
|
||||
"platform": msg.user_info.platform if msg.user_info else "",
|
||||
}
|
||||
},
|
||||
}
|
||||
initial_messages_dict.append(msg_dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from src.common.logger import get_logger
|
|||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_logger("observation_info")
|
||||
|
|
|
|||
|
|
@ -42,9 +42,7 @@ class GoalAnalyzer:
|
|||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
|
||||
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import List, Tuple, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import Message
|
||||
from src.config.config import model_config
|
||||
from src.chat.knowledge import qa_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
|
||||
|
|
@ -16,9 +16,7 @@ class KnowledgeFetcher:
|
|||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
|
||||
self.private_name = private_name
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
|
|
|
|||
|
|
@ -14,10 +14,7 @@ class ReplyChecker:
|
|||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="reply_check"
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
|
|
|
|||
|
|
@ -704,10 +704,7 @@ class BrainChatting:
|
|||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._new_message_event.wait(),
|
||||
timeout=wait_seconds
|
||||
)
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -731,7 +728,9 @@ class BrainChatting:
|
|||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)"
|
||||
)
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
|
@ -749,10 +748,7 @@ class BrainChatting:
|
|||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._new_message_event.wait(),
|
||||
timeout=wait_seconds
|
||||
)
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
|
|
|
|||
|
|
@ -431,15 +431,21 @@ class BrainPlanner:
|
|||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
return extracted_reasoning, [
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
], llm_content, llm_reasoning, llm_duration_ms
|
||||
return (
|
||||
extracted_reasoning,
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union, Dict, Any
|
||||
from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
|
|
@ -200,9 +200,7 @@ class IEProcess:
|
|||
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
|
||||
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
|
||||
try:
|
||||
entities, triples = await asyncio.to_thread(
|
||||
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
|
||||
)
|
||||
entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
|
||||
|
||||
if entities is not None:
|
||||
results.append(
|
||||
|
|
|
|||
|
|
@ -395,8 +395,7 @@ class KGManager:
|
|||
appear_cnt = self.ent_appear_cnt.get(ent_hash)
|
||||
if not appear_cnt or appear_cnt <= 0:
|
||||
logger.debug(
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,"
|
||||
f"将使用 1.0 作为默认出现次数参与权重计算"
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,将使用 1.0 作为默认出现次数参与权重计算"
|
||||
)
|
||||
appear_cnt = 1.0
|
||||
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
|
|||
|
||||
logger = get_logger("LPMM-Plugin-API")
|
||||
|
||||
|
||||
class LPMMOperations:
|
||||
"""
|
||||
LPMM 内部操作接口。
|
||||
|
|
@ -20,9 +21,7 @@ class LPMMOperations:
|
|||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
async def _run_cancellable_executor(
|
||||
self, func: Callable, *args, **kwargs
|
||||
) -> Any:
|
||||
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
在线程池中执行可取消的同步操作。
|
||||
当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
|
||||
|
|
@ -79,7 +78,7 @@ class LPMMOperations:
|
|||
# 1. 分段处理
|
||||
if auto_split:
|
||||
# 自动按双换行符分割
|
||||
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
else:
|
||||
# 不分割,作为完整一段
|
||||
text_stripped = text.strip()
|
||||
|
|
@ -95,7 +94,9 @@ class LPMMOperations:
|
|||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
|
||||
llm_ner = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
)
|
||||
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||
|
||||
|
|
@ -128,25 +129,21 @@ class LPMMOperations:
|
|||
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
|
||||
# store_new_data_set 会自动处理嵌入生成和存储
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
await self._run_cancellable_executor(
|
||||
embed_mgr.store_new_data_set,
|
||||
new_raw_paragraphs,
|
||||
new_triple_list_data
|
||||
)
|
||||
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
|
||||
|
||||
# 3. 构建知识图谱(只需要三元组数据和embedding_manager)
|
||||
await self._run_cancellable_executor(
|
||||
kg_mgr.build_kg,
|
||||
new_triple_list_data,
|
||||
embed_mgr
|
||||
)
|
||||
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
|
||||
|
||||
# 4. 持久化
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(new_raw_paragraphs),
|
||||
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 导入操作被用户中断")
|
||||
|
|
@ -215,8 +212,7 @@ class LPMMOperations:
|
|||
|
||||
# a. 从向量库删除
|
||||
deleted_count, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
||||
to_delete_keys
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
|
||||
|
|
@ -224,10 +220,7 @@ class LPMMOperations:
|
|||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs,
|
||||
to_delete_hashes,
|
||||
ent_hashes=None,
|
||||
remove_orphan_entities=True
|
||||
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
|
|
@ -237,7 +230,11 @@ class LPMMOperations:
|
|||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
|
||||
return {
|
||||
"status": "success",
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 删除操作被用户中断")
|
||||
|
|
@ -275,16 +272,14 @@ class LPMMOperations:
|
|||
|
||||
# 删除所有段落向量
|
||||
para_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
||||
para_keys
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes.clear()
|
||||
|
||||
# 删除所有实体向量
|
||||
if ent_keys:
|
||||
ent_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.entities_embedding_store.delete_items,
|
||||
ent_keys
|
||||
embed_mgr.entities_embedding_store.delete_items, ent_keys
|
||||
)
|
||||
else:
|
||||
ent_deleted = 0
|
||||
|
|
@ -292,8 +287,7 @@ class LPMMOperations:
|
|||
# 删除所有关系向量
|
||||
if rel_keys:
|
||||
rel_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.relation_embedding_store.delete_items,
|
||||
rel_keys
|
||||
embed_mgr.relation_embedding_store.delete_items, rel_keys
|
||||
)
|
||||
else:
|
||||
rel_deleted = 0
|
||||
|
|
@ -341,15 +335,13 @@ class LPMMOperations:
|
|||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs,
|
||||
all_pg_hashes,
|
||||
ent_hashes=None,
|
||||
remove_orphan_entities=True
|
||||
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
# 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
|
||||
from quick_algo import di_graph
|
||||
|
||||
kg_mgr.graph = di_graph.DiGraph()
|
||||
kg_mgr.stored_paragraph_hashes.clear()
|
||||
kg_mgr.ent_appear_cnt.clear()
|
||||
|
|
@ -373,7 +365,7 @@ class LPMMOperations:
|
|||
"stats": {
|
||||
"before": before_stats,
|
||||
"after": after_stats,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
|
|
@ -383,6 +375,6 @@ class LPMMOperations:
|
|||
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# 内部使用的单例
|
||||
lpmm_ops = LPMMOperations()
|
||||
|
||||
|
|
|
|||
|
|
@ -136,4 +136,3 @@ class PlanReplyLogger:
|
|||
return str(value)
|
||||
# Fallback to string for other complex types
|
||||
return str(value)
|
||||
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
|||
|
||||
# 如果未开启 API Server,直接跳过 Fallback
|
||||
if not global_config.maim_message.enable_api_server:
|
||||
logger.debug(f"[API Server Fallback] API Server未开启,跳过fallback")
|
||||
logger.debug("[API Server Fallback] API Server未开启,跳过fallback")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
|
@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
|||
extra_server = getattr(global_api, "extra_server", None)
|
||||
|
||||
if not extra_server:
|
||||
logger.warning(f"[API Server Fallback] extra_server不存在")
|
||||
logger.warning("[API Server Fallback] extra_server不存在")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
if not extra_server.is_running():
|
||||
logger.warning(f"[API Server Fallback] extra_server未运行")
|
||||
logger.warning("[API Server Fallback] extra_server未运行")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
|
@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
|||
)
|
||||
|
||||
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...")
|
||||
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
|
||||
results = await extra_server.send_message(api_message)
|
||||
logger.debug(f"[API Server Fallback] 发送结果: {results}")
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ logger = get_logger("planner")
|
|||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
|
|
@ -111,20 +112,29 @@ class ActionPlanner:
|
|||
|
||||
# 替换 [picid:xxx] 为 [图片:描述]
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(pic_match: re.Match) -> str:
|
||||
pic_id = pic_match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
|
||||
platform = (
|
||||
getattr(message, "user_info", None)
|
||||
and message.user_info.platform
|
||||
or getattr(message, "chat_info", None)
|
||||
and message.chat_info.platform
|
||||
or "qq"
|
||||
)
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||
# 这里匹配到的应该都是单独的格式
|
||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||
|
||||
def replace_user_ref(user_match: re.Match) -> str:
|
||||
user_name = user_match.group(1)
|
||||
user_id = user_match.group(2)
|
||||
|
|
@ -137,6 +147,7 @@ class ActionPlanner:
|
|||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
return user_name
|
||||
|
||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
|
|
@ -306,9 +317,7 @@ class ActionPlanner:
|
|||
|
||||
return merged_words
|
||||
|
||||
def _process_unknown_words_cache(
|
||||
self, actions: List[ActionPlannerInfo]
|
||||
) -> None:
|
||||
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
|
||||
"""
|
||||
处理黑话缓存逻辑:
|
||||
1. 检查是否有 reply action 提取了 unknown_words
|
||||
|
|
@ -442,7 +451,11 @@ class ActionPlanner:
|
|||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = False
|
||||
for action in actions:
|
||||
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
|
||||
if (
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
):
|
||||
has_reply_to_force_message = True
|
||||
break
|
||||
|
||||
|
|
@ -577,10 +590,11 @@ class ActionPlanner:
|
|||
if global_config.chat.think_mode == "classic":
|
||||
reply_action_example = ""
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
reply_action_example += (
|
||||
"5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]'
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
|
|
@ -590,7 +604,9 @@ class ActionPlanner:
|
|||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
reply_action_example += (
|
||||
"6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "think_level":数值等级(0或1), '
|
||||
'"target_message_id":"消息id(m+数字)", '
|
||||
|
|
@ -741,15 +757,21 @@ class ActionPlanner:
|
|||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return f"LLM 请求失败,模型出现问题: {req_e}", [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
], llm_content, llm_reasoning, llm_duration_ms
|
||||
return (
|
||||
f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
extracted_reasoning = ""
|
||||
|
|
|
|||
|
|
@ -1071,7 +1071,6 @@ class DefaultReplyer:
|
|||
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
|
||||
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
|
||||
|
||||
|
||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||
reply_style = global_config.personality.reply_style
|
||||
multi_styles = global_config.personality.multiple_reply_style
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
|
|||
)
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from src.common.logger import get_logger
|
|||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
|
||||
class TempMethodsExpression:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
|
|||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
||||
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||
self.session_id = session_id
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
|
|
|
|||
|
|
@ -96,8 +96,12 @@ class Images(SQLModel, table=True):
|
|||
|
||||
no_file_flag: bool = Field(default=False) # 文件不存在标记,如果为True表示文件已经不存在,仅保留描述字段
|
||||
|
||||
record_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间(数据库记录被创建的时间)
|
||||
register_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 注册时间(被注册为可用表情包的时间)
|
||||
record_time: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 记录时间(数据库记录被创建的时间)
|
||||
register_time: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(DateTime, nullable=True)
|
||||
) # 注册时间(被注册为可用表情包的时间)
|
||||
last_used_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 上次使用时间
|
||||
|
||||
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
|
||||
|
|
@ -171,7 +175,9 @@ class Expression(SQLModel, table=True):
|
|||
|
||||
content_list: str # 内容列表,JSON格式存储
|
||||
count: int = Field(default=0) # 使用次数
|
||||
last_active_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 上次使用时间
|
||||
last_active_time: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 上次使用时间
|
||||
create_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime)) # 创建时间
|
||||
session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID,区分是否为全局表达方式
|
||||
|
||||
|
|
@ -232,8 +238,12 @@ class ThinkingQuestion(SQLModel, table=True):
|
|||
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
||||
|
||||
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
||||
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
|
||||
updated_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后更新时间
|
||||
created_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 创建时间
|
||||
updated_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 最后更新时间
|
||||
|
||||
|
||||
class BinaryData(SQLModel, table=True):
|
||||
|
|
@ -272,7 +282,9 @@ class PersonInfo(SQLModel, table=True):
|
|||
|
||||
# 认识次数和时间
|
||||
know_counts: int = Field(default=0) # 认识次数
|
||||
first_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 首次认识时间
|
||||
first_known_time: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(DateTime, nullable=True)
|
||||
) # 首次认识时间
|
||||
last_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 最后认识时间
|
||||
|
||||
|
||||
|
|
@ -285,8 +297,12 @@ class ChatSession(SQLModel, table=True):
|
|||
|
||||
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
|
||||
|
||||
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
|
||||
last_active_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后活跃时间
|
||||
created_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 创建时间
|
||||
last_active_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 最后活跃时间
|
||||
|
||||
# 身份元数据
|
||||
user_id: Optional[str] = Field(index=True, max_length=255, nullable=True) # 用户ID
|
||||
|
|
|
|||
|
|
@ -221,5 +221,7 @@ if not supports_truecolor():
|
|||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
else:
|
||||
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
|
||||
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold)
|
||||
escape_str = rgb_pair_to_ansi_truecolor(
|
||||
hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
|
||||
)
|
||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
|
|
@ -9,6 +9,7 @@ from .server import get_global_server
|
|||
|
||||
global_api = None
|
||||
|
||||
|
||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
"""获取全局MessageServer实例"""
|
||||
global global_api
|
||||
|
|
@ -80,12 +81,12 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
|
||||
return False
|
||||
|
||||
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
|
||||
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3. Setup Message Bridge
|
||||
# Initialize refined route map if not exists
|
||||
if not hasattr(global_api, "platform_map"):
|
||||
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
|
||||
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
|
||||
|
||||
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
|
||||
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
|
||||
|
|
@ -108,7 +109,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
|
||||
|
||||
if platform:
|
||||
global_api.platform_map[platform] = api_key # type: ignore
|
||||
global_api.platform_map[platform] = api_key # type: ignore
|
||||
api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to update platform map: {e}")
|
||||
|
|
@ -117,21 +118,21 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
if "raw_message" not in msg_dict:
|
||||
msg_dict["raw_message"] = None
|
||||
|
||||
await global_api.process_message(msg_dict) # type: ignore
|
||||
await global_api.process_message(msg_dict) # type: ignore
|
||||
|
||||
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
|
||||
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3.5. Register custom message handlers (bridge to Legacy handlers)
|
||||
# message_id_echo: handles message ID echo from adapters
|
||||
# 兼容新旧两个版本的 maim_message:
|
||||
# - 旧版: handler(payload)
|
||||
# - 新版: handler(payload, metadata)
|
||||
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
|
||||
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
|
||||
# Bridge to the Legacy custom handler registered in main.py
|
||||
try:
|
||||
# The Legacy handler expects the payload format directly
|
||||
if hasattr(global_api, "_custom_message_handlers"):
|
||||
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
|
||||
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
|
||||
if handler:
|
||||
await handler(payload)
|
||||
api_logger.debug(f"Processed message_id_echo: {payload}")
|
||||
|
|
@ -140,7 +141,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
except Exception as e:
|
||||
api_logger.warning(f"Failed to process message_id_echo: {e}")
|
||||
|
||||
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
|
||||
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 4. Initialize Server
|
||||
extra_server = WebSocketServer(config=server_config)
|
||||
|
|
@ -167,7 +168,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
global_api.stop = patched_stop
|
||||
|
||||
# Attach for reference
|
||||
global_api.extra_server = extra_server # type: ignore # 这是什么
|
||||
global_api.extra_server = extra_server # type: ignore # 这是什么
|
||||
|
||||
except ImportError:
|
||||
get_logger("maim_message").error(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
|
|||
|
||||
logger = get_logger("file_utils")
|
||||
|
||||
|
||||
class FileUtils:
|
||||
@staticmethod
|
||||
def save_binary_to_file(file_path: Path, data: bytes):
|
||||
|
|
|
|||
|
|
@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
|||
|
||||
reason = ",".join(reasons)
|
||||
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
||||
|
||||
|
|
|
|||
|
|
@ -54,8 +54,6 @@ async def generate_dream_summary(
|
|||
) -> None:
|
||||
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
|
||||
try:
|
||||
|
||||
|
||||
# 第一步:建立工具调用结果映射 (call_id -> result)
|
||||
tool_results_map: dict[str, str] = {}
|
||||
for msg in conversation_messages:
|
||||
|
|
|
|||
|
|
@ -4,4 +4,3 @@ dream agent 工具实现模块。
|
|||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
|
|||
return f"create_chat_history 执行失败: {e}"
|
||||
|
||||
return create_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||
return f"delete_chat_history 执行失败: {e}"
|
||||
|
||||
return delete_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||
return f"delete_jargon 执行失败: {e}"
|
||||
|
||||
return delete_jargon
|
||||
|
||||
|
|
|
|||
|
|
@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
|
|||
return msg
|
||||
|
||||
return finish_maintenance
|
||||
|
||||
|
|
|
|||
|
|
@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
|||
return f"get_chat_history_detail 执行失败: {e}"
|
||||
|
||||
return get_chat_history_detail
|
||||
|
||||
|
|
|
|||
|
|
@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
|
|||
return f"search_chat_history 执行失败: {e}"
|
||||
|
||||
return search_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||
return f"update_chat_history 执行失败: {e}"
|
||||
|
||||
return update_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||
return f"update_jargon 执行失败: {e}"
|
||||
|
||||
return update_jargon
|
||||
|
||||
|
|
|
|||
|
|
@ -459,7 +459,7 @@ def _default_normal_response_parser(
|
|||
# 此时为了调试方便,建议打印出 arguments 的类型
|
||||
raise RespParseException(
|
||||
resp,
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
|
||||
)
|
||||
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
|
||||
except json.JSONDecodeError as e:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import time
|
|||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ async def search_chat_history(
|
|||
if start_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
|
||||
start_timestamp = parse_datetime_to_timestamp(start_time)
|
||||
except ValueError as e:
|
||||
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
|
@ -120,6 +121,7 @@ async def search_chat_history(
|
|||
if end_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
|
||||
end_timestamp = parse_datetime_to_timestamp(end_time)
|
||||
except ValueError as e:
|
||||
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
|
@ -165,16 +167,13 @@ async def search_chat_history(
|
|||
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
|
||||
query = query.where(
|
||||
(
|
||||
(ChatHistory.start_time >= start_timestamp)
|
||||
& (ChatHistory.start_time <= end_timestamp)
|
||||
(ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
|
||||
) # 记录开始时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.end_time >= start_timestamp)
|
||||
& (ChatHistory.end_time <= end_timestamp)
|
||||
(ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
|
||||
) # 记录结束时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.start_time <= start_timestamp)
|
||||
& (ChatHistory.end_time >= end_timestamp)
|
||||
(ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
|
||||
) # 记录完全包含查询时间段
|
||||
)
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -76,4 +76,3 @@ def register_tool():
|
|||
],
|
||||
execute_func=query_words,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -558,7 +558,9 @@ class PluginBase(ABC):
|
|||
if version_spec:
|
||||
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
|
||||
if not is_ok:
|
||||
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
|
||||
logger.error(
|
||||
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
|
||||
)
|
||||
return False
|
||||
|
||||
if min_version or max_version:
|
||||
|
|
|
|||
|
|
@ -751,9 +751,7 @@ class ComponentRegistry:
|
|||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
"workflow_steps": workflow_step_count,
|
||||
"enabled_workflow_steps": enabled_workflow_step_count,
|
||||
"workflow_steps_by_stage": {
|
||||
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
|
||||
},
|
||||
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -429,7 +429,9 @@ class PluginManager:
|
|||
|
||||
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
|
||||
"""根据依赖图计算加载顺序,并检测循环依赖。"""
|
||||
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
|
||||
indegree: Dict[str, int] = {
|
||||
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
|
||||
}
|
||||
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
|
||||
|
||||
for plugin_name, dependencies in dependency_graph.items():
|
||||
|
|
|
|||
|
|
@ -55,7 +55,9 @@ class PluginServiceRegistry:
|
|||
full_name = self._resolve_full_name(service_name, plugin_name)
|
||||
return self._service_handlers.get(full_name) if full_name else None
|
||||
|
||||
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
|
||||
def list_services(
|
||||
self, plugin_name: Optional[str] = None, enabled_only: bool = False
|
||||
) -> Dict[str, PluginServiceInfo]:
|
||||
"""列出插件服务。"""
|
||||
services = self._services.copy()
|
||||
if plugin_name:
|
||||
|
|
@ -120,7 +122,12 @@ class PluginServiceRegistry:
|
|||
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
|
||||
raise ValueError(f"插件服务未注册: {target_name}")
|
||||
|
||||
if "." not in service_name and plugin_name is None and caller_plugin and service_info.plugin_name != caller_plugin:
|
||||
if (
|
||||
"." not in service_name
|
||||
and plugin_name is None
|
||||
and caller_plugin
|
||||
and service_info.plugin_name != caller_plugin
|
||||
):
|
||||
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
|
||||
|
||||
if not self._is_call_authorized(service_info, caller_plugin):
|
||||
|
|
@ -153,7 +160,9 @@ class PluginServiceRegistry:
|
|||
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
|
||||
return "*" in allowed_callers or caller_plugin in allowed_callers
|
||||
|
||||
def _validate_input_contract(self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
|
||||
def _validate_input_contract(
|
||||
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> None:
|
||||
"""校验服务入参契约。"""
|
||||
schema = service_info.params_schema
|
||||
if not schema:
|
||||
|
|
|
|||
|
|
@ -96,7 +96,9 @@ class WorkflowEngine:
|
|||
except Exception as e:
|
||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
||||
workflow_context.errors.append(f"{stage_key}: {e}")
|
||||
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
|
||||
)
|
||||
self._execution_history[workflow_context.trace_id]["status"] = "failed"
|
||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
||||
return (
|
||||
|
|
@ -195,7 +197,9 @@ class WorkflowEngine:
|
|||
except Exception as e:
|
||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||
context.errors.append(f"{step_info.full_name}: {e}")
|
||||
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
|
||||
)
|
||||
return WorkflowStepResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
|
|
@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
|
|||
|
||||
class ChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
|
||||
chat_id: str
|
||||
plan_count: int
|
||||
latest_timestamp: float
|
||||
|
|
@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
|
|||
|
||||
class PlanLogSummary(BaseModel):
|
||||
"""规划日志摘要"""
|
||||
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
|
|
@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
|
|||
|
||||
class PlanLogDetail(BaseModel):
|
||||
"""规划日志详情"""
|
||||
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
|
|
@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
|
|||
|
||||
class PlannerOverview(BaseModel):
|
||||
"""规划器总览 - 轻量级统计"""
|
||||
|
||||
total_chats: int
|
||||
total_plans: int
|
||||
chats: List[ChatSummary]
|
||||
|
|
@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
|
|||
|
||||
class PaginatedChatLogs(BaseModel):
|
||||
"""分页的聊天日志列表"""
|
||||
|
||||
data: List[PlanLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
|
|||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
timestamp_str = filename.split("_")[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
|
|
@ -106,21 +112,19 @@ async def get_planner_overview():
|
|||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
chats.append(
|
||||
ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name,
|
||||
)
|
||||
)
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return PlannerOverview(
|
||||
total_chats=len(chats),
|
||||
total_plans=total_plans,
|
||||
chats=chats
|
||||
)
|
||||
return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
||||
|
|
@ -128,7 +132,7 @@ async def get_chat_plan_logs(
|
|||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||
):
|
||||
"""
|
||||
获取指定聊天的规划日志列表(分页)
|
||||
|
|
@ -137,9 +141,7 @@ async def get_chat_plan_logs(
|
|||
"""
|
||||
chat_dir = PLAN_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedChatLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
|
|
@ -151,9 +153,9 @@ async def get_chat_plan_logs(
|
|||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
prompt = data.get("prompt", "")
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
|
|
@ -164,46 +166,44 @@ async def get_chat_plan_logs(
|
|||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
page_files = json_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
actions = data.get('actions', [])
|
||||
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
|
||||
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else ''
|
||||
))
|
||||
reasoning = data.get("reasoning", "")
|
||||
actions = data.get("actions", [])
|
||||
action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
|
||||
logs.append(
|
||||
PlanLogSummary(
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
|
||||
llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else "",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview='[读取失败]'
|
||||
))
|
||||
logs.append(
|
||||
PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview="[读取失败]",
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedChatLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
||||
|
|
@ -214,7 +214,7 @@ async def get_log_detail(chat_id: str, filename: str):
|
|||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return PlanLogDetail(**data)
|
||||
except Exception as e:
|
||||
|
|
@ -223,6 +223,7 @@ async def get_log_detail(chat_id: str, filename: str):
|
|||
|
||||
# ========== 兼容旧接口 ==========
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_planner_stats():
|
||||
"""获取规划器统计信息 - 兼容旧接口"""
|
||||
|
|
@ -246,7 +247,7 @@ async def get_planner_stats():
|
|||
"total_plans": overview.total_plans,
|
||||
"avg_plan_time_ms": 0,
|
||||
"avg_llm_time_ms": 0,
|
||||
"recent_plans": recent_plans
|
||||
"recent_plans": recent_plans,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -258,10 +259,7 @@ async def get_chat_list():
|
|||
|
||||
|
||||
@router.get("/all-logs")
|
||||
async def get_all_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100)
|
||||
):
|
||||
async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
|
||||
"""获取所有规划日志 - 兼容旧接口"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
|
@ -278,23 +276,25 @@ async def get_all_logs(
|
|||
|
||||
total = len(all_files)
|
||||
offset = (page - 1) * page_size
|
||||
page_files = all_files[offset:offset + page_size]
|
||||
page_files = all_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for chat_id, log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
logs.append({
|
||||
"chat_id": data.get('chat_id', chat_id),
|
||||
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get('actions', [])),
|
||||
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
|
||||
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else ''
|
||||
})
|
||||
reasoning = data.get("reasoning", "")
|
||||
logs.append(
|
||||
{
|
||||
"chat_id": data.get("chat_id", chat_id),
|
||||
"timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get("actions", [])),
|
||||
"total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
|
||||
"llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else "",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
|
|
@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
|
|||
|
||||
class ReplierChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
|
||||
chat_id: str
|
||||
reply_count: int
|
||||
latest_timestamp: float
|
||||
|
|
@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
|
|||
|
||||
class ReplyLogSummary(BaseModel):
|
||||
"""回复日志摘要"""
|
||||
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
|
|
@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
|
|||
|
||||
class ReplyLogDetail(BaseModel):
|
||||
"""回复日志详情"""
|
||||
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
|
|
@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
|
|||
|
||||
class ReplierOverview(BaseModel):
|
||||
"""回复器总览 - 轻量级统计"""
|
||||
|
||||
total_chats: int
|
||||
total_replies: int
|
||||
chats: List[ReplierChatSummary]
|
||||
|
|
@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
|
|||
|
||||
class PaginatedReplyLogs(BaseModel):
|
||||
"""分页的回复日志列表"""
|
||||
|
||||
data: List[ReplyLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
|
|||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
timestamp_str = filename.split("_")[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
|
|
@ -109,21 +115,19 @@ async def get_replier_overview():
|
|||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
chats.append(
|
||||
ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name,
|
||||
)
|
||||
)
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return ReplierOverview(
|
||||
total_chats=len(chats),
|
||||
total_replies=total_replies,
|
||||
chats=chats
|
||||
)
|
||||
return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
||||
|
|
@ -131,7 +135,7 @@ async def get_chat_reply_logs(
|
|||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||
):
|
||||
"""
|
||||
获取指定聊天的回复日志列表(分页)
|
||||
|
|
@ -140,9 +144,7 @@ async def get_chat_reply_logs(
|
|||
"""
|
||||
chat_dir = REPLY_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedReplyLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
|
|
@ -154,9 +156,9 @@ async def get_chat_reply_logs(
|
|||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
prompt = data.get("prompt", "")
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
|
|
@ -167,44 +169,42 @@ async def get_chat_reply_logs(
|
|||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
page_files = json_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
output = data.get('output', '')
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get('model', ''),
|
||||
success=data.get('success', True),
|
||||
llm_ms=data.get('timing', {}).get('llm_ms', 0),
|
||||
overall_ms=data.get('timing', {}).get('overall_ms', 0),
|
||||
output_preview=output[:100] if output else ''
|
||||
))
|
||||
output = data.get("output", "")
|
||||
logs.append(
|
||||
ReplyLogSummary(
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get("model", ""),
|
||||
success=data.get("success", True),
|
||||
llm_ms=data.get("timing", {}).get("llm_ms", 0),
|
||||
overall_ms=data.get("timing", {}).get("overall_ms", 0),
|
||||
output_preview=output[:100] if output else "",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model='',
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview='[读取失败]'
|
||||
))
|
||||
logs.append(
|
||||
ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model="",
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview="[读取失败]",
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedReplyLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
||||
|
|
@ -215,21 +215,21 @@ async def get_reply_log_detail(chat_id: str, filename: str):
|
|||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return ReplyLogDetail(
|
||||
type=data.get('type', 'reply'),
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', 0),
|
||||
prompt=data.get('prompt', ''),
|
||||
output=data.get('output', ''),
|
||||
processed_output=data.get('processed_output', []),
|
||||
model=data.get('model', ''),
|
||||
reasoning=data.get('reasoning', ''),
|
||||
think_level=data.get('think_level', 0),
|
||||
timing=data.get('timing', {}),
|
||||
error=data.get('error'),
|
||||
success=data.get('success', True)
|
||||
type=data.get("type", "reply"),
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", 0),
|
||||
prompt=data.get("prompt", ""),
|
||||
output=data.get("output", ""),
|
||||
processed_output=data.get("processed_output", []),
|
||||
model=data.get("model", ""),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
think_level=data.get("think_level", 0),
|
||||
timing=data.get("timing", {}),
|
||||
error=data.get("error"),
|
||||
success=data.get("success", True),
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
|
|
@ -237,6 +237,7 @@ async def get_reply_log_detail(chat_id: str, filename: str):
|
|||
|
||||
# ========== 兼容接口 ==========
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_replier_stats():
|
||||
"""获取回复器统计信息"""
|
||||
|
|
@ -258,7 +259,7 @@ async def get_replier_stats():
|
|||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_replies": overview.total_replies,
|
||||
"recent_replies": recent_replies
|
||||
"recent_replies": recent_replies,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
from fastapi import Depends, Cookie, Header, Request, HTTPException
|
||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
|
||||
from fastapi import Depends, Cookie, Header, Request
|
||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit
|
||||
|
||||
|
||||
async def require_auth(
|
||||
|
|
|
|||
|
|
@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
|
|||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
|
||||
|
||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
|
|
@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
|||
def _get_anti_crawler_config():
|
||||
"""获取防爬虫配置"""
|
||||
from src.config.config import global_config
|
||||
|
||||
return {
|
||||
'mode': global_config.webui.anti_crawler_mode,
|
||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
'trust_xff': global_config.webui.trust_xff
|
||||
"mode": global_config.webui.anti_crawler_mode,
|
||||
"allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
"trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
"trust_xff": global_config.webui.trust_xff,
|
||||
}
|
||||
|
||||
|
||||
# 初始化配置(将在模块加载时执行)
|
||||
_config = _get_anti_crawler_config()
|
||||
ANTI_CRAWLER_MODE = _config['mode']
|
||||
ALLOWED_IPS = _config['allowed_ips']
|
||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
||||
TRUST_XFF = _config['trust_xff']
|
||||
ANTI_CRAWLER_MODE = _config["mode"]
|
||||
ALLOWED_IPS = _config["allowed_ips"]
|
||||
TRUSTED_PROXIES = _config["trusted_proxies"]
|
||||
TRUST_XFF = _config["trust_xff"]
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def _get_paragraph_store():
|
|||
namespace="paragraph",
|
||||
dir_path=embedding_dir,
|
||||
max_workers=1, # 只读不需要多线程
|
||||
chunk_size=100
|
||||
chunk_size=100,
|
||||
)
|
||||
paragraph_store.load_from_file()
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
|||
paragraph_item = paragraph_store.store.get(node_id)
|
||||
if paragraph_item is not None:
|
||||
# paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本
|
||||
content: str = getattr(paragraph_item, 'str', '')
|
||||
content: str = getattr(paragraph_item, "str", "")
|
||||
if content:
|
||||
return content, True
|
||||
return None, True
|
||||
|
|
@ -160,7 +160,11 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
|||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
|
@ -249,7 +253,11 @@ async def get_knowledge_graph(
|
|||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type_val == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
|
@ -372,7 +380,11 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
|
|||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue