Ruff Format

pull/1496/head
DrSmoothl 2026-02-21 16:24:24 +08:00
parent 2cb512120b
commit eaef7f0e98
82 changed files with 1881 additions and 1900 deletions

1
bot.py
View File

@ -50,6 +50,7 @@ print("警告Dev进入不稳定开发状态任何插件与WebUI均可能
print("\n\n\n\n\n")
print("-----------------------------------------")
def run_runner_process():
"""
Runner 进程逻辑作为守护进程运行负责启动和监控 Worker 进程

View File

@ -1,2 +1 @@
"""Core helpers for MCP Bridge Plugin."""

View File

@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
if not mcp_servers:
return ""
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)

View File

@ -34,28 +34,31 @@ from enum import Enum
# 尝试导入 MaiBot 的 logger如果失败则使用标准 logging
try:
from src.common.logger import get_logger
logger = get_logger("mcp_client")
except ImportError:
# Fallback: 使用标准 logging
logger = logging.getLogger("mcp_client")
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s'))
handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
class TransportType(Enum):
"""MCP 传输类型"""
STDIO = "stdio" # 本地进程通信
SSE = "sse" # Server-Sent Events (旧版 HTTP)
HTTP = "http" # HTTP Streamable (新版,推荐)
STDIO = "stdio" # 本地进程通信
SSE = "sse" # Server-Sent Events (旧版 HTTP)
HTTP = "http" # HTTP Streamable (新版,推荐)
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
@dataclass
class MCPToolInfo:
"""MCP 工具信息"""
name: str
description: str
input_schema: Dict[str, Any]
@ -65,6 +68,7 @@ class MCPToolInfo:
@dataclass
class MCPResourceInfo:
"""MCP 资源信息"""
uri: str
name: str
description: str
@ -75,6 +79,7 @@ class MCPResourceInfo:
@dataclass
class MCPPromptInfo:
"""MCP 提示模板信息"""
name: str
description: str
arguments: List[Dict[str, Any]] # [{name, description, required}]
@ -84,6 +89,7 @@ class MCPPromptInfo:
@dataclass
class MCPServerConfig:
"""MCP 服务器配置"""
name: str
enabled: bool = True
transport: TransportType = TransportType.STDIO
@ -99,6 +105,7 @@ class MCPServerConfig:
@dataclass
class MCPCallResult:
"""MCP 工具调用结果"""
success: bool
content: Any
error: Optional[str] = None
@ -108,8 +115,9 @@ class MCPCallResult:
class CircuitState(Enum):
"""断路器状态"""
CLOSED = "closed" # 正常状态,允许请求
OPEN = "open" # 熔断状态,拒绝请求
CLOSED = "closed" # 正常状态,允许请求
OPEN = "open" # 熔断状态,拒绝请求
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
@ -125,9 +133,9 @@ class CircuitBreaker:
"""
# 配置
failure_threshold: int = 5 # 连续失败多少次后熔断
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
failure_threshold: int = 5 # 连续失败多少次后熔断
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
# 状态
state: CircuitState = field(default=CircuitState.CLOSED)
@ -232,6 +240,7 @@ class CircuitBreaker:
@dataclass
class ToolCallStats:
"""工具调用统计"""
tool_key: str
total_calls: int = 0
success_calls: int = 0
@ -282,6 +291,7 @@ class ToolCallStats:
@dataclass
class ServerStats:
"""服务器统计"""
server_name: str
connect_count: int = 0 # 连接次数
disconnect_count: int = 0 # 断开次数
@ -442,9 +452,7 @@ class MCPClientSession:
return False
server_params = StdioServerParameters(
command=self.config.command,
args=self.config.args,
env=self.config.env if self.config.env else None
command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
)
self._stdio_context = stdio_client(server_params)
@ -506,6 +514,7 @@ class MCPClientSession:
except Exception as e:
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup()
return False
@ -551,6 +560,7 @@ class MCPClientSession:
except Exception as e:
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup()
return False
@ -568,8 +578,8 @@ class MCPClientSession:
tool_info = MCPToolInfo(
name=tool.name,
description=tool.description or f"MCP tool: {tool.name}",
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {},
server_name=self.server_name
input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
server_name=self.server_name,
)
self._tools.append(tool_info)
# 初始化工具统计
@ -591,10 +601,7 @@ class MCPClientSession:
return False
try:
result = await asyncio.wait_for(
self._session.list_resources(),
timeout=self.call_timeout
)
result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
self._resources = []
for resource in result.resources:
@ -602,8 +609,8 @@ class MCPClientSession:
uri=str(resource.uri),
name=resource.name or str(resource.uri),
description=resource.description or "",
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None,
server_name=self.server_name
mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
server_name=self.server_name,
)
self._resources.append(resource_info)
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
@ -633,28 +640,27 @@ class MCPClientSession:
return False
try:
result = await asyncio.wait_for(
self._session.list_prompts(),
timeout=self.call_timeout
)
result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
self._prompts = []
for prompt in result.prompts:
# 解析参数
arguments = []
if hasattr(prompt, 'arguments') and prompt.arguments:
if hasattr(prompt, "arguments") and prompt.arguments:
for arg in prompt.arguments:
arguments.append({
"name": arg.name,
"description": arg.description or "",
"required": arg.required if hasattr(arg, 'required') else False,
})
arguments.append(
{
"name": arg.name,
"description": arg.description or "",
"required": arg.required if hasattr(arg, "required") else False,
}
)
prompt_info = MCPPromptInfo(
name=prompt.name,
description=prompt.description or f"MCP prompt: {prompt.name}",
arguments=arguments,
server_name=self.server_name
server_name=self.server_name,
)
self._prompts.append(prompt_info)
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
@ -686,35 +692,25 @@ class MCPClientSession:
start_time = time.time()
if not self._connected or not self._session:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
if not self._supports_resources:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Resources 功能"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
try:
result = await asyncio.wait_for(
self._session.read_resource(uri),
timeout=self.call_timeout
)
result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
duration_ms = (time.time() - start_time) * 1000
# 处理返回内容
content_parts = []
for content in result.contents:
if hasattr(content, 'text'):
if hasattr(content, "text"):
content_parts.append(content.text)
elif hasattr(content, 'blob'):
elif hasattr(content, "blob"):
# 二进制数据,返回 base64 或提示
import base64
blob_data = content.blob
if len(blob_data) < 10000: # 小于 10KB 返回 base64
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
@ -724,28 +720,18 @@ class MCPClientSession:
content_parts.append(str(content))
return MCPCallResult(
success=True,
content="\n".join(content_parts) if content_parts else "",
duration_ms=duration_ms
success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
)
except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000
return MCPCallResult(
success=False,
content=None,
error=f"读取资源超时({self.call_timeout}秒)",
duration_ms=duration_ms
success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
)
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
return MCPCallResult(
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
"""v1.2.0: 获取提示模板的内容
@ -760,23 +746,14 @@ class MCPClientSession:
start_time = time.time()
if not self._connected or not self._session:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
if not self._supports_prompts:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
try:
result = await asyncio.wait_for(
self._session.get_prompt(name, arguments=arguments or {}),
timeout=self.call_timeout
self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
)
duration_ms = (time.time() - start_time) * 1000
@ -784,10 +761,10 @@ class MCPClientSession:
# 处理返回的消息
messages = []
for msg in result.messages:
role = msg.role if hasattr(msg, 'role') else "unknown"
role = msg.role if hasattr(msg, "role") else "unknown"
content_text = ""
if hasattr(msg, 'content'):
if hasattr(msg.content, 'text'):
if hasattr(msg, "content"):
if hasattr(msg.content, "text"):
content_text = msg.content.text
elif isinstance(msg.content, str):
content_text = msg.content
@ -796,28 +773,18 @@ class MCPClientSession:
messages.append(f"[{role}]: {content_text}")
return MCPCallResult(
success=True,
content="\n\n".join(messages) if messages else "",
duration_ms=duration_ms
success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
)
except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000
return MCPCallResult(
success=False,
content=None,
error=f"获取提示模板超时({self.call_timeout}秒)",
duration_ms=duration_ms
success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
)
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
return MCPCallResult(
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
async def check_health(self) -> bool:
"""检查连接健康状态(心跳检测)
@ -829,10 +796,7 @@ class MCPClientSession:
try:
# 使用 list_tools 作为心跳检测
await asyncio.wait_for(
self._session.list_tools(),
timeout=10.0
)
await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
self.stats.record_heartbeat()
return True
except Exception as e:
@ -849,12 +813,7 @@ class MCPClientSession:
# v1.7.0: 断路器检查
can_execute, reject_reason = self._circuit_breaker.can_execute()
if not can_execute:
return MCPCallResult(
success=False,
content=None,
error=f"{reject_reason}",
circuit_broken=True
)
return MCPCallResult(success=False, content=None, error=f"{reject_reason}", circuit_broken=True)
# 半开状态下增加试探计数
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
@ -870,8 +829,7 @@ class MCPClientSession:
try:
result = await asyncio.wait_for(
self._session.call_tool(tool_name, arguments=arguments),
timeout=self.call_timeout
self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
)
duration_ms = (time.time() - start_time) * 1000
@ -879,9 +837,9 @@ class MCPClientSession:
# 处理返回内容
content_parts = []
for content in result.content:
if hasattr(content, 'text'):
if hasattr(content, "text"):
content_parts.append(content.text)
elif hasattr(content, 'data'):
elif hasattr(content, "data"):
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
else:
content_parts.append(str(content))
@ -896,7 +854,7 @@ class MCPClientSession:
return MCPCallResult(
success=True,
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
duration_ms=duration_ms
duration_ms=duration_ms,
)
except asyncio.TimeoutError:
@ -939,25 +897,25 @@ class MCPClientSession:
self._supports_prompts = False # v1.2.0
try:
if hasattr(self, '_session_context') and self._session_context:
if hasattr(self, "_session_context") and self._session_context:
await self._session_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
try:
if hasattr(self, '_stdio_context') and self._stdio_context:
if hasattr(self, "_stdio_context") and self._stdio_context:
await self._stdio_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
try:
if hasattr(self, '_http_context') and self._http_context:
if hasattr(self, "_http_context") and self._http_context:
await self._http_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
try:
if hasattr(self, '_sse_context') and self._sse_context:
if hasattr(self, "_sse_context") and self._sse_context:
await self._sse_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
@ -1082,7 +1040,9 @@ class MCPClientManager:
return True
if attempt < retry_attempts:
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
logger.warning(
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
)
await asyncio.sleep(retry_interval)
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
@ -1213,11 +1173,7 @@ class MCPClientManager:
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
"""调用 MCP 工具"""
if tool_key not in self._all_tools:
return MCPCallResult(
success=False,
content=None,
error=f"工具 {tool_key} 不存在"
)
return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
tool_info, client = self._all_tools[tool_key]
@ -1273,11 +1229,7 @@ class MCPClientManager:
# 如果指定了服务器
if server_name:
if server_name not in self._clients:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
client = self._clients[server_name]
return await client.read_resource(uri)
@ -1293,14 +1245,11 @@ class MCPClientManager:
if result.success:
return result
return MCPCallResult(
success=False,
content=None,
error=f"未找到资源: {uri}"
)
return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None,
server_name: Optional[str] = None) -> MCPCallResult:
async def get_prompt(
self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
) -> MCPCallResult:
"""v1.2.0: 获取提示模板内容
Args:
@ -1311,11 +1260,7 @@ class MCPClientManager:
# 如果指定了服务器
if server_name:
if server_name not in self._clients:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
client = self._clients[server_name]
return await client.get_prompt(name, arguments)
@ -1324,11 +1269,7 @@ class MCPClientManager:
if prompt_info.name == name:
return await client.get_prompt(name, arguments)
return MCPCallResult(
success=False,
content=None,
error=f"未找到提示模板: {name}"
)
return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
# ==================== 心跳检测 ====================
@ -1489,7 +1430,9 @@ class MCPClientManager:
"global": {
**self._global_stats,
"uptime_seconds": round(uptime, 2),
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0,
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
if uptime > 0
else 0,
},
"servers": server_stats,
"tools": tool_stats,

View File

@ -123,9 +123,11 @@ logger = get_logger("mcp_bridge_plugin")
# v1.4.0: 调用链路追踪
# ============================================================================
@dataclass
class ToolCallRecord:
"""工具调用记录"""
call_id: str
timestamp: float
tool_name: str
@ -208,9 +210,11 @@ tool_call_tracer = ToolCallTracer()
# v1.4.0: 工具调用缓存
# ============================================================================
@dataclass
class CacheEntry:
"""缓存条目"""
tool_name: str
args_hash: str
result: str
@ -347,6 +351,7 @@ tool_call_cache = ToolCallCache()
# v1.4.0: 工具权限控制
# ============================================================================
class PermissionChecker:
"""工具权限检查器"""
@ -479,6 +484,7 @@ permission_checker = PermissionChecker()
# 工具类型转换
# ============================================================================
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
type_mapping = {
@ -492,7 +498,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
return type_mapping.get(json_type, ToolParamType.STRING)
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
def parse_mcp_parameters(
input_schema: Dict[str, Any],
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
"""解析 MCP 工具的参数 schema转换为 MaiBot 的参数格式"""
parameters = []
@ -534,6 +542,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
# MCP 工具代理
# ============================================================================
class MCPToolProxy(BaseTool):
"""MCP 工具代理基类"""
@ -576,10 +585,7 @@ class MCPToolProxy(BaseTool):
# v1.4.0: 权限检查
if not permission_checker.check(self.name, chat_id, user_id, is_group):
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
return {
"name": self.name,
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
}
return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
@ -749,11 +755,7 @@ class MCPToolProxy(BaseTool):
return None
async def _call_post_process_llm(
self,
prompt: str,
max_tokens: int,
settings: Dict[str, Any],
server_config: Optional[Dict[str, Any]]
self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
) -> Optional[str]:
"""调用 LLM 进行后处理"""
from src.config.config import model_config
@ -811,10 +813,7 @@ class MCPToolProxy(BaseTool):
def create_mcp_tool_class(
tool_key: str,
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
) -> Type[MCPToolProxy]:
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
parameters = parse_mcp_parameters(tool_info.input_schema)
@ -837,7 +836,7 @@ def create_mcp_tool_class(
"_mcp_tool_key": tool_key,
"_mcp_original_name": tool_info.name,
"_mcp_server_name": tool_info.server_name,
}
},
)
return tool_class
@ -851,11 +850,7 @@ class MCPToolRegistry:
self._tool_infos: Dict[str, ToolInfo] = {}
def register_tool(
self,
tool_key: str,
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
"""注册 MCP 工具"""
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
@ -902,6 +897,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
# 内置工具
# ============================================================================
class MCPReadResourceTool(BaseTool):
"""v1.2.0: MCP 资源读取工具"""
@ -973,6 +969,7 @@ class MCPGetPromptTool(BaseTool):
# v1.8.0: 工具链代理工具
# ============================================================================
class ToolChainProxyBase(BaseTool):
"""工具链代理基类"""
@ -1037,7 +1034,7 @@ def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBa
"parameters": parameters,
"available_for_llm": True,
"_chain_name": chain.name,
}
},
)
return tool_class
@ -1095,7 +1092,13 @@ class MCPStatusTool(BaseTool):
name = "mcp_status"
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
parameters = [
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"]),
(
"query_type",
ToolParamType.STRING,
"查询类型",
False,
["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"],
),
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
]
available_for_llm = True
@ -1132,10 +1135,7 @@ class MCPStatusTool(BaseTool):
if query_type in ("cache",):
result_parts.append(self._format_cache())
return {
"name": self.name,
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
}
return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
def _format_status(self, server_name: Optional[str] = None) -> str:
status = mcp_manager.get_status()
@ -1147,14 +1147,14 @@ class MCPStatusTool(BaseTool):
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
lines.append("\n🔌 服务器详情:")
for name, info in status['servers'].items():
for name, info in status["servers"].items():
if server_name and name != server_name:
continue
status_icon = "" if info['connected'] else ""
enabled_text = "" if info['enabled'] else " (已禁用)"
status_icon = "" if info["connected"] else ""
enabled_text = "" if info["enabled"] else " (已禁用)"
lines.append(f" {status_icon} {name}{enabled_text}")
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
if info['consecutive_failures'] > 0:
if info["consecutive_failures"] > 0:
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']}")
return "\n".join(lines)
@ -1184,11 +1184,11 @@ class MCPStatusTool(BaseTool):
stats = mcp_manager.get_all_stats()
lines = ["📈 调用统计"]
g = stats['global']
g = stats["global"]
lines.append(f" 总调用次数: {g['total_tool_calls']}")
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
if g['total_tool_calls'] > 0:
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100
if g["total_tool_calls"] > 0:
success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
lines.append(f" 成功率: {success_rate:.1f}%")
lines.append(f" 运行时间: {g['uptime_seconds']:.0f}")
@ -1277,7 +1277,7 @@ class MCPStatusTool(BaseTool):
lines.append(f" 描述: {chain.description[:50]}...")
lines.append(f" 步骤: {len(chain.steps)}")
for i, step in enumerate(chain.steps[:3]):
lines.append(f" {i+1}. {step.tool_name}")
lines.append(f" {i + 1}. {step.tool_name}")
if len(chain.steps) > 3:
lines.append(f" ... 还有 {len(chain.steps) - 3} 个步骤")
if chain.input_params:
@ -1294,6 +1294,7 @@ class MCPStatusTool(BaseTool):
# 命令处理
# ============================================================================
class MCPStatusCommand(BaseCommand):
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
@ -1644,6 +1645,7 @@ class MCPStatusCommand(BaseCommand):
_plugin_instance._load_tool_chains()
chains = tool_chain_manager.get_all_chains()
from src.plugin_system.core.component_registry import component_registry
registered = 0
for name, chain in tool_chain_manager.get_enabled_chains().items():
tool_name = f"chain_{name}".replace("-", "_").replace(".", "_")
@ -1735,7 +1737,7 @@ class MCPStatusCommand(BaseCommand):
lines.append(f"📋 执行步骤 ({len(chain.steps)} 个):")
for i, step in enumerate(chain.steps):
optional_tag = " (可选)" if step.optional else ""
lines.append(f" {i+1}. {step.tool_name}{optional_tag}")
lines.append(f" {i + 1}. {step.tool_name}{optional_tag}")
if step.description:
lines.append(f" {step.description}")
if step.output_key:
@ -1983,6 +1985,7 @@ class MCPImportCommand(BaseCommand):
# 事件处理器
# ============================================================================
class MCPStartupHandler(BaseEventHandler):
"""MCP 启动事件处理器"""
@ -2037,6 +2040,7 @@ class MCPStopHandler(BaseEventHandler):
# 主插件类
# ============================================================================
@register_plugin
class MCPBridgePlugin(BasePlugin):
"""MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
@ -2505,7 +2509,7 @@ class MCPBridgePlugin(BasePlugin):
label="📋 工具链列表",
input_type="textarea",
rows=20,
placeholder='''[
placeholder="""[
{
"name": "search_and_detail",
"description": "先搜索再获取详情",
@ -2515,7 +2519,7 @@ class MCPBridgePlugin(BasePlugin):
{"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}}
]
}
]''',
]""",
hint="每个工具链包含 name、description、input_params、steps",
order=30,
),
@ -2653,9 +2657,9 @@ mcp_bing_*""",
label="📜 高级权限规则(可选)",
input_type="textarea",
rows=10,
placeholder='''[
placeholder="""[
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
]''',
]""",
hint="格式: qq:ID:group/private/user工具名支持通配符 *",
order=10,
),
@ -2754,7 +2758,9 @@ mcp_bing_*""",
value = match1.group(2)
suffix = match1.group(3)
# 将转义的换行符还原为实际换行符
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
unescaped = (
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
)
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
fixed_lines.append(fixed_line)
modified = True
@ -2948,11 +2954,13 @@ mcp_bing_*""",
logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}")
args_template = {}
steps.append({
"tool_name": tool_name,
"args_template": args_template,
"output_key": output_key,
})
steps.append(
{
"tool_name": tool_name,
"args_template": args_template,
"output_key": output_key,
}
)
if not steps:
logger.warning("快速添加工具链: 没有有效的步骤")
@ -3056,7 +3064,9 @@ mcp_bing_*""",
if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict):
self.config["tool_chains"] = {}
self.config["tool_chains"]["chains_list"] = chains_json
logger.info("检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list请在 WebUI 保存一次以固化)")
logger.info(
"检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list请在 WebUI 保存一次以固化)"
)
chains_config = self.config.get("tool_chains", {})
if not isinstance(chains_config, dict):
@ -3153,10 +3163,7 @@ mcp_bing_*""",
# 应用过滤器
if filter_patterns:
matched = any(
fnmatch.fnmatch(tool_name, p) or tool_name == p
for p in filter_patterns
)
matched = any(fnmatch.fnmatch(tool_name, p) or tool_name == p for p in filter_patterns)
if filter_mode == "whitelist":
# 白名单模式:只注册匹配的
@ -3179,6 +3186,7 @@ mcp_bing_*""",
return result.content or "(无返回内容)"
else:
return f"工具调用失败: {result.error}"
return execute_func
execute_func = make_execute_func(tool_key)
@ -3207,7 +3215,9 @@ mcp_bing_*""",
return registered_count
def _update_react_status_display(self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]) -> None:
def _update_react_status_display(
self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str]
) -> None:
"""更新 ReAct 工具状态显示"""
if not registered_tools:
status_text = "(未注册任何工具)"
@ -3243,12 +3253,14 @@ mcp_bing_*""",
description = param_info.get("description", f"参数 {param_name}")
is_required = param_name in required
parameters.append({
"name": param_name,
"type": param_type,
"description": description,
"required": is_required,
})
parameters.append(
{
"name": param_name,
"type": param_type,
"description": description,
"required": is_required,
}
)
return parameters
@ -3276,7 +3288,7 @@ mcp_bing_*""",
for i, step in enumerate(chain.steps):
opt = " (可选)" if step.optional else ""
out = f"{step.output_key}" if step.output_key else ""
lines.append(f" {i+1}. {step.tool_name}{out}{opt}")
lines.append(f" {i + 1}. {step.tool_name}{out}{opt}")
lines.append("")
status_text = "\n".join(lines)
@ -3295,6 +3307,7 @@ mcp_bing_*""",
async def _async_connect_servers(self) -> None:
"""异步连接所有配置的 MCP 服务器v1.5.0: 并行连接优化)"""
import asyncio
settings = self.config.get("settings", {})
servers_config = self._load_mcp_servers_config()
@ -3380,10 +3393,7 @@ mcp_bing_*""",
# 并行执行所有连接
start_time = time.time()
results = await asyncio.gather(
*[connect_single_server(cfg) for cfg in enabled_configs],
return_exceptions=True
)
results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
connect_duration = time.time() - start_time
# 统计连接结果
@ -3404,15 +3414,14 @@ mcp_bing_*""",
# 注册所有工具
from src.plugin_system.core.component_registry import component_registry
registered_count = 0
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
tool_name = tool_key.replace("-", "_").replace(".", "_")
is_disabled = tool_name in disabled_tools
info, tool_class = mcp_tool_registry.register_tool(
tool_key, tool_info, tool_prefix, disabled=is_disabled
)
info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
info.plugin_name = self.plugin_name
if component_registry.register_component(info, tool_class):
@ -3433,7 +3442,9 @@ mcp_bing_*""",
react_count = self._register_tools_to_react()
self._initialized = True
logger.info(f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具")
logger.info(
f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具"
)
# 更新状态显示
self._update_status_display()
@ -3508,7 +3519,9 @@ mcp_bing_*""",
logger.info("检测到旧版 servers.list已自动迁移为 Claude mcpServers请在 WebUI 保存一次以固化)")
if not claude_json.strip():
self._last_servers_config_error = "未配置任何 MCP 服务器(请在 WebUI 的「MCP ServersClaude」粘贴 mcpServers JSON"
self._last_servers_config_error = (
"未配置任何 MCP 服务器(请在 WebUI 的「MCP ServersClaude」粘贴 mcpServers JSON"
)
return []
try:

View File

@ -22,15 +22,18 @@ from typing import Any, Dict, List, Optional, Tuple
try:
from src.common.logger import get_logger
logger = get_logger("mcp_tool_chain")
except ImportError:
import logging
logger = logging.getLogger("mcp_tool_chain")
@dataclass
class ToolChainStep:
"""工具链步骤"""
tool_name: str # 要调用的工具名(如 mcp_server_tool
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
output_key: str = "" # 输出存储的键名,供后续步骤引用
@ -60,6 +63,7 @@ class ToolChainStep:
@dataclass
class ToolChainDefinition:
"""工具链定义"""
name: str # 工具链名称(将作为组合工具的名称)
description: str # 工具链描述(供 LLM 理解)
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
@ -90,6 +94,7 @@ class ToolChainDefinition:
@dataclass
class ChainExecutionResult:
"""工具链执行结果"""
success: bool
final_output: str # 最终输出(最后一个步骤的结果)
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
@ -103,7 +108,7 @@ class ChainExecutionResult:
status = "" if step.get("success") else ""
tool = step.get("tool_name", "unknown")
duration = step.get("duration_ms", 0)
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)")
lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
if not step.get("success") and step.get("error"):
lines.append(f" 错误: {step['error'][:50]}")
return "\n".join(lines)
@ -113,7 +118,7 @@ class ToolChainExecutor:
"""工具链执行器"""
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
def __init__(self, mcp_manager):
self._mcp_manager = mcp_manager
@ -201,7 +206,7 @@ class ToolChainExecutor:
tool_key = self._resolve_tool_key(step.tool_name)
if not tool_key:
step_result["error"] = f"工具 {step.tool_name} 不存在"
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在")
logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
if not step.optional:
step_results.append(step_result)
@ -209,13 +214,13 @@ class ToolChainExecutor:
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在",
error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
total_duration_ms=(time.time() - start_time) * 1000,
)
step_results.append(step_result)
continue
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}")
logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
# 调用工具
result = await self._mcp_manager.call_tool(tool_key, resolved_args)
@ -236,10 +241,10 @@ class ToolChainExecutor:
final_output = content
content_preview = content[:100] if content else "(空)"
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...")
logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
else:
step_result["error"] = result.error or "未知错误"
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}")
logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
if not step.optional:
step_results.append(step_result)
@ -247,7 +252,7 @@ class ToolChainExecutor:
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}",
error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
total_duration_ms=(time.time() - start_time) * 1000,
)
@ -255,7 +260,7 @@ class ToolChainExecutor:
step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration
step_result["error"] = str(e)
logger.error(f"工具链步骤 {i+1} 异常: {e}")
logger.error(f"工具链步骤 {i + 1} 异常: {e}")
if not step.optional:
step_results.append(step_result)
@ -263,7 +268,7 @@ class ToolChainExecutor:
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}",
error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
total_duration_ms=(time.time() - start_time) * 1000,
)
@ -295,10 +300,7 @@ class ToolChainExecutor:
elif isinstance(value, dict):
resolved[key] = self._resolve_args(value, context)
elif isinstance(value, list):
resolved[key] = [
self._substitute_vars(v, context) if isinstance(v, str) else v
for v in value
]
resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
else:
resolved[key] = value
@ -306,6 +308,7 @@ class ToolChainExecutor:
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
"""替换字符串中的变量"""
def replacer(match):
var_path = match.group(1)
return self._get_var_value(var_path, context)
@ -552,7 +555,7 @@ class ToolChainManager:
try:
chain = ToolChainDefinition.from_dict(item)
if not chain.name:
errors.append(f"{i+1} 个工具链缺少名称")
errors.append(f"{i + 1} 个工具链缺少名称")
continue
if not chain.steps:
errors.append(f"工具链 {chain.name} 没有步骤")
@ -561,7 +564,7 @@ class ToolChainManager:
self.add_chain(chain)
loaded += 1
except Exception as e:
errors.append(f"{i+1} 个工具链解析失败: {e}")
errors.append(f"{i + 1} 个工具链解析失败: {e}")
return loaded, errors

View File

@ -238,7 +238,7 @@ class TestCommand(BaseCommand):
chat_stream=self.message.chat_stream,
reply_reason=reply_reason,
enable_chinese_typo=False,
extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"",
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
)
if result_status:
# 发送生成的回复

View File

@ -46,6 +46,7 @@ def patch_attrdoc_post_init():
config_base_module.logger = logging.getLogger("config_base_test_logger")
class SimpleClass(ConfigBase):
a: int = 1
b: str = "test"
@ -282,7 +283,7 @@ class TestConfigBase:
True,
"ConfigBase is not Hashable",
id="listset-validation-set-configbase-element_reject",
)
),
],
)
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
@ -340,7 +341,7 @@ class TestConfigBase:
False,
None,
id="dict-validation-happy-configbase-value",
)
),
],
)
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
@ -353,13 +354,11 @@ class TestConfigBase:
field_name = "mapping"
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_dict_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
dummy._validate_dict_type(annotation, field_name)

View File

@ -4,7 +4,6 @@ import importlib
import pytest
from pathlib import Path
import importlib.util
import asyncio
class DummyLogger:
@ -71,6 +70,7 @@ class DummyLLMRequest:
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
return ("dummy description", {})
class DummySelect:
def __init__(self, *a, **k):
pass
@ -81,6 +81,7 @@ class DummySelect:
def limit(self, n):
return self
@pytest.fixture(autouse=True)
def patch_external_dependencies(monkeypatch):
# Provide dummy implementations as modules so that importing image_manager is safe
@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None):
if tmp_path is not None:
tmpdir = Path(tmp_path)
tmpdir.mkdir(parents=True, exist_ok=True)
setattr(mod, "IMAGE_DIR", tmpdir)
mod.IMAGE_DIR = tmpdir
except Exception:
pass
return mod
@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
# cleanup should run without error
mgr.cleanup_invalid_descriptions_in_db()

View File

@ -1,5 +1,3 @@
import pytest
from src.config.official_configs import ChatConfig
from src.config.config import Config
from src.webui.config_schema import ConfigSchemaGenerator

View File

@ -1,6 +1,5 @@
"""Expression routes pytest tests"""
from datetime import datetime
from typing import Generator
from unittest.mock import MagicMock
@ -12,7 +11,6 @@ from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, select
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
def create_test_app() -> FastAPI:

View File

@ -115,7 +115,7 @@ def analyze_single_file(file_path: str) -> Dict:
stats["date_range"] = {
"start": min_date.isoformat(),
"end": max_date.isoformat(),
"duration_days": (max_date - min_date).days + 1
"duration_days": (max_date - min_date).days + 1,
}
# 检查字段存在性
@ -151,8 +151,8 @@ def print_file_stats(stats: Dict, index: int = None):
print(f" 文件中的 total_count: {stats['total_count']}")
print(f" 实际记录数: {stats['actual_count']}")
if stats['total_count'] != stats['actual_count']:
diff = stats['total_count'] - stats['actual_count']
if stats["total_count"] != stats["actual_count"]:
diff = stats["total_count"] - stats["actual_count"]
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
print("\n【评估结果统计】")
@ -161,21 +161,21 @@ def print_file_stats(stats: Dict, index: int = None):
print("\n【唯一性统计】")
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}")
if stats['actual_count'] > 0:
duplicate_count = stats['actual_count'] - stats['unique_pairs']
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["actual_count"] > 0:
duplicate_count = stats["actual_count"] - stats["unique_pairs"]
duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
print("\n【评估者统计】")
if stats['evaluators']:
for evaluator, count in stats['evaluators'].most_common():
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["evaluators"]:
for evaluator, count in stats["evaluators"].most_common():
rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
else:
print(" 无评估者信息")
print("\n【时间统计】")
if stats['date_range']:
if stats["date_range"]:
print(f" 最早评估时间: {stats['date_range']['start']}")
print(f" 最晚评估时间: {stats['date_range']['end']}")
print(f" 评估时间跨度: {stats['date_range']['duration_days']}")
@ -185,8 +185,8 @@ def print_file_stats(stats: Dict, index: int = None):
print("\n【字段统计】")
print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}")
print(f" 包含 reason: {'' if stats['has_reason'] else ''}")
if stats['has_reason']:
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["has_reason"]:
rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
@ -215,15 +215,15 @@ def print_summary(all_stats: List[Dict]):
return
# 汇总记录统计
total_records = sum(s['actual_count'] for s in valid_files)
total_suitable = sum(s['suitable_count'] for s in valid_files)
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
total_records = sum(s["actual_count"] for s in valid_files)
total_suitable = sum(s["suitable_count"] for s in valid_files)
total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
total_unique_pairs = set()
# 收集所有唯一的(situation, style)对
for stats in valid_files:
try:
with open(stats['file_path'], "r", encoding="utf-8") as f:
with open(stats["file_path"], "r", encoding="utf-8") as f:
data = json.load(f)
results = data.get("manual_results", [])
for r in results:
@ -234,8 +234,16 @@ def print_summary(all_stats: List[Dict]):
print("\n【记录汇总】")
print(f" 总记录数: {total_records:,}")
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
print(
f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
if total_records > 0
else " 通过: 0 条"
)
print(
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
if total_records > 0
else " 不通过: 0 条"
)
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}")
if total_records > 0:
@ -246,7 +254,7 @@ def print_summary(all_stats: List[Dict]):
# 汇总评估者统计
all_evaluators = Counter()
for stats in valid_files:
all_evaluators.update(stats['evaluators'])
all_evaluators.update(stats["evaluators"])
print("\n【评估者汇总】")
if all_evaluators:
@ -259,7 +267,7 @@ def print_summary(all_stats: List[Dict]):
# 汇总时间范围
all_dates = []
for stats in valid_files:
all_dates.extend(stats['evaluation_dates'])
all_dates.extend(stats["evaluation_dates"])
if all_dates:
min_date = min(all_dates)
@ -270,7 +278,7 @@ def print_summary(all_stats: List[Dict]):
print(f" 总时间跨度: {(max_date - min_date).days + 1}")
# 文件大小汇总
total_size = sum(s['file_size'] for s in valid_files)
total_size = sum(s["file_size"] for s in valid_files)
avg_size = total_size / len(valid_files) if valid_files else 0
print("\n【文件大小汇总】")
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
@ -318,5 +326,3 @@ def main():
if __name__ == "__main__":
main()

View File

@ -171,7 +171,9 @@ def main():
sys.exit(1)
if not args.raw_index:
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
logger.info(
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
)
sys.exit(1)
# 解析索引列表1-based

View File

@ -71,7 +71,7 @@ def save_results(evaluation_results: List[Dict]):
data = {
"last_updated": datetime.now().isoformat(),
"total_count": len(evaluation_results),
"evaluation_results": evaluation_results
"evaluation_results": evaluation_results,
}
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
@ -84,9 +84,7 @@ def save_results(evaluation_results: List[Dict]):
print(f"\n✗ 保存评估结果失败: {e}")
def select_expressions_for_evaluation(
evaluated_pairs: Set[Tuple[str, str]] = None
) -> List[Expression]:
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
"""
选择用于评估的表达方式
选择所有count>1的项目然后选择两倍数量的count=1的项目
@ -109,10 +107,7 @@ def select_expressions_for_evaluation(
return []
# 过滤出未评估的项目
unevaluated = [
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
if not unevaluated:
logger.warning("所有项目都已评估完成")
@ -132,7 +127,9 @@ def select_expressions_for_evaluation(
count_eq1_needed = count_gt1_count * 2
if len(count_eq1) < count_eq1_needed:
logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}")
logger.warning(
f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}"
)
count_eq1_needed = len(count_eq1)
# 随机选择count=1的项目
@ -141,13 +138,16 @@ def select_expressions_for_evaluation(
selected = selected_count_gt1 + selected_count_eq1
random.shuffle(selected) # 打乱顺序
logger.info(f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍")
logger.info(
f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍"
)
return selected
except Exception as e:
logger.error(f"选择表达方式失败: {e}")
import traceback
logger.error(traceback.format_exc())
return []
@ -202,9 +202,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt,
temperature=0.6,
max_tokens=1024
prompt=prompt, temperature=0.6, max_tokens=1024
)
logger.debug(f"LLM响应: {response}")
@ -241,7 +239,9 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
Returns:
评估结果字典
"""
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
logger.info(
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
)
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
@ -258,7 +258,7 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
"reason": reason,
"error": error,
"evaluator": "llm",
"evaluated_at": datetime.now().isoformat()
"evaluated_at": datetime.now().isoformat(),
}
@ -302,7 +302,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
print(f"Count = {count}:")
print(f" 总数: {total}")
print(f" 通过: {suitable} ({pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
print()
# 比较count=1和count>1
@ -340,12 +340,12 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
print("Count = 1:")
print(f" 总数: {eq1_total}")
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
print()
print("Count > 1:")
print(f" 总数: {gt1_total}")
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
print()
# 进行卡方检验简化版使用2x2列联表
@ -427,7 +427,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
"count_groups": {str(k): v for k, v in count_groups.items()},
"count_eq1": count_eq1_group,
"count_gt1": count_gt1_group,
"total_evaluated": len(evaluation_results)
"total_evaluated": len(evaluation_results),
}
try:
@ -466,8 +466,7 @@ async def main():
try:
all_expressions = list(Expression.select())
unevaluated_count_gt1 = [
expr for expr in all_expressions
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
]
has_unevaluated = len(unevaluated_count_gt1) > 0
except Exception as e:
@ -485,21 +484,20 @@ async def main():
try:
llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_count_analysis_llm"
request_type="expression_evaluator_count_analysis_llm",
)
print("✓ LLM实例创建成功\n")
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
import traceback
logger.error(traceback.format_exc())
print(f"\n✗ 创建LLM实例失败: {e}")
db.close()
return
# 选择需要评估的表达方式选择所有count>1的项目然后选择两倍数量的count=1的项目
expressions = select_expressions_for_evaluation(
evaluated_pairs=evaluated_pairs
)
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
if not expressions:
print("\n没有可评估的项目")
@ -518,7 +516,7 @@ async def main():
llm_result = await llm_evaluate_expression(expression, llm)
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
if llm_result.get('error'):
if llm_result.get("error"):
print(f" 错误: {llm_result['error']}")
print()
@ -553,4 +551,3 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())

View File

@ -140,9 +140,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt,
temperature=0.6,
max_tokens=1024
prompt=prompt, temperature=0.6, max_tokens=1024
)
logger.debug(f"LLM响应: {response}")
@ -152,6 +150,7 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
evaluation = json.loads(response)
except json.JSONDecodeError as e:
import re
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
if json_match:
evaluation = json.loads(json_match.group())
@ -196,7 +195,7 @@ async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -
"suitable": suitable,
"reason": reason,
"error": error,
"evaluator": "llm"
"evaluator": "llm",
}
@ -244,10 +243,16 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
false_negatives += 1
accuracy = (matched / total * 100) if total > 0 else 0
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
precision = (
(true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
)
recall = (
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
)
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
specificity = (
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
)
# 计算人工效标的不合适率
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
@ -283,7 +288,7 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
"specificity": specificity,
"manual_unsuitable_rate": manual_unsuitable_rate,
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
"rate_difference": rate_difference
"rate_difference": rate_difference,
}
@ -334,13 +339,11 @@ async def main(count: int | None = None):
# 2. 创建LLM实例并评估
print("\n步骤2: 创建LLM实例")
try:
llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_llm"
)
llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
import traceback
logger.error(traceback.format_exc())
return
@ -348,11 +351,7 @@ async def main(count: int | None = None):
llm_results = []
for i, manual_result in enumerate(valid_manual_results, 1):
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
llm_results.append(await evaluate_expression_llm(
manual_result["situation"],
manual_result["style"],
llm
))
llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
await asyncio.sleep(0.3)
# 5. 输出FP和FN项目在评估结果之前
@ -372,14 +371,16 @@ async def main(count: int | None = None):
# 人工评估不通过但LLM评估通过FP情况
if not manual_result["suitable"] and llm_result["suitable"]:
fp_items.append({
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error")
})
fp_items.append(
{
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error"),
}
)
if fp_items:
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
@ -389,7 +390,7 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}")
print("人工评估: 不通过 ❌")
print("LLM评估: 通过 ✅ (误判)")
if item.get('llm_error'):
if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}")
print()
@ -410,14 +411,16 @@ async def main(count: int | None = None):
# 人工评估通过但LLM评估不通过FN情况
if manual_result["suitable"] and not llm_result["suitable"]:
fn_items.append({
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error")
})
fn_items.append(
{
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error"),
}
)
if fn_items:
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
@ -427,7 +430,7 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}")
print("人工评估: 通过 ✅")
print("LLM评估: 不通过 ❌ (误删)")
if item.get('llm_error'):
if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}")
print()
@ -447,13 +450,21 @@ async def main(count: int | None = None):
print()
# print(" 【核心能力指标】")
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}")
print(
f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
print()
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}")
print(
f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
)
print(
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
print()
print(" 【其他指标】")
@ -464,12 +475,18 @@ async def main(count: int | None = None):
print()
print(" 【不合适率分析】")
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
print(
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
)
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
print()
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
print(f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(
f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
)
print()
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
@ -486,11 +503,12 @@ async def main(count: int | None = None):
try:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
json.dump({
"manual_results": valid_manual_results,
"llm_results": llm_results,
"comparison": comparison
}, f, ensure_ascii=False, indent=2)
json.dump(
{"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
f,
ensure_ascii=False,
indent=2,
)
logger.info(f"\n评估结果已保存到: {output_file}")
except Exception as e:
logger.warning(f"保存结果到文件失败: {e}")
@ -509,15 +527,9 @@ if __name__ == "__main__":
python evaluate_expressions_llm_v6.py # 使用全部数据
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
"""
)
parser.add_argument(
"-n", "--count",
type=int,
default=None,
help="随机选取的数据条数(默认:使用全部数据)"
""",
)
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
args = parser.parse_args()
asyncio.run(main(count=args.count))

View File

@ -65,7 +65,7 @@ def save_results(manual_results: List[Dict]):
data = {
"last_updated": datetime.now().isoformat(),
"total_count": len(manual_results),
"manual_results": manual_results
"manual_results": manual_results,
}
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
@ -98,10 +98,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
return []
# 过滤出未评估的项目:匹配 situation 和 style 均一致
unevaluated = [
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
if not unevaluated:
logger.info("所有项目都已评估完成")
@ -120,6 +117,7 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
except Exception as e:
logger.error(f"获取未评估表达方式失败: {e}")
import traceback
logger.error(traceback.format_exc())
return []
@ -150,18 +148,18 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
while True:
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
if user_input in ['q', 'quit']:
if user_input in ["q", "quit"]:
print("退出评估")
return None
if user_input in ['s', 'skip']:
if user_input in ["s", "skip"]:
print("跳过当前项目")
return "skip"
if user_input in ['y', 'yes', '1', '', '通过']:
if user_input in ["y", "yes", "1", "", "通过"]:
suitable = True
break
elif user_input in ['n', 'no', '0', '', '不通过']:
elif user_input in ["n", "no", "0", "", "不通过"]:
suitable = False
break
else:
@ -173,7 +171,7 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
"suitable": suitable,
"reason": None,
"evaluator": "manual",
"evaluated_at": datetime.now().isoformat()
"evaluated_at": datetime.now().isoformat(),
}
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
@ -257,9 +255,9 @@ def main():
# 询问是否继续
while True:
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
if continue_input in ['y', 'yes', '1', '', '继续']:
if continue_input in ["y", "yes", "1", "", "继续"]:
break
elif continue_input in ['n', 'no', '0', '', '退出']:
elif continue_input in ["n", "no", "0", "", "退出"]:
print("\n评估结束")
return
else:
@ -275,4 +273,3 @@ def main():
if __name__ == "__main__":
main()

View File

@ -134,9 +134,7 @@ def handle_import_openie(
# 在非交互模式下,不再询问用户,而是直接报错终止
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
if non_interactive:
logger.error(
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
)
logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
sys.exit(1)
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
user_choice = input().strip().lower()
@ -189,9 +187,7 @@ def handle_import_openie(
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
# 新增确认提示
if non_interactive:
logger.warning(
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
)
logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
else:
print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
def main(argv: Optional[list[str]] = None) -> None:
"""主函数 - 解析参数并运行异步主流程。"""
parser = argparse.ArgumentParser(
description=(
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
"将其导入到 LPMM 的向量库与知识图中。"
)
description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
)
parser.add_argument(
"--non-interactive",

View File

@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
ensure_dirs() # 确保目录存在
# 新增用户确认提示
if non_interactive:
logger.warning(
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
)
logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
else:
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

View File

@ -68,4 +68,3 @@ def main() -> None:
if __name__ == "__main__":
main()

View File

@ -29,6 +29,7 @@ except ImportError as e:
logger = get_logger("lpmm_interactive_manager")
async def interactive_add():
"""交互式导入知识"""
print("\n" + "=" * 40)
@ -68,6 +69,7 @@ async def interactive_add():
print(f"\n[×] 发生异常: {e}")
logger.error(f"add_content 异常: {e}", exc_info=True)
async def interactive_delete():
"""交互式删除知识"""
print("\n" + "=" * 40)
@ -108,8 +110,12 @@ async def interactive_delete():
return
print("-" * 40)
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower()
if confirm != 'y':
confirm = (
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
.strip()
.lower()
)
if confirm != "y":
print("\n[!] 已取消删除操作。")
return
@ -129,6 +135,7 @@ async def interactive_delete():
print(f"\n[×] 发生异常: {e}")
logger.error(f"delete 异常: {e}", exc_info=True)
async def interactive_clear():
"""交互式清空知识库"""
print("\n" + "=" * 40)
@ -165,16 +172,21 @@ async def interactive_clear():
before = stats.get("before", {})
after = stats.get("after", {})
print("\n[统计信息]")
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}")
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}")
print(
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
)
print(
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
)
else:
print(f"\n[×] 失败:{result['message']}")
except Exception as e:
print(f"\n[×] 发生异常: {e}")
logger.error(f"clear_all 异常: {e}", exc_info=True)
async def interactive_search():
"""交互式查询知识"""
print("\n" + "=" * 40)
@ -224,6 +236,7 @@ async def interactive_search():
print(f"\n[×] 查询失败: {e}")
logger.error(f"查询异常: {e}", exc_info=True)
async def main():
"""主循环"""
while True:
@ -253,6 +266,7 @@ async def main():
else:
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
if __name__ == "__main__":
try:
# 运行主循环
@ -262,4 +276,3 @@ if __name__ == "__main__":
except Exception as e:
print(f"\n[!] 程序运行出错: {e}")
logger.error(f"Main loop 异常: {e}", exc_info=True)

View File

@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
txt_files = list(raw_dir.glob("*.txt"))
if not txt_files:
msg = (
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
"info_extraction 可能立即退出或无数据可处理。"
)
msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件info_extraction 可能立即退出或无数据可处理。"
print(msg)
if non_interactive:
logger.error(
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
)
logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
return False
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
return cont == "y"
@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
json_files = list(openie_dir.glob("*.json"))
if not json_files:
msg = (
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
"import_openie 可能会因为找不到批次而失败。"
)
msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件import_openie 可能会因为找不到批次而失败。"
print(msg)
if non_interactive:
logger.error(
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
)
logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
return False
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
return cont == "y"
@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
try:
if not getattr(global_config.lpmm_knowledge, "enable", False):
print(
"[WARN] 当前配置 lpmm_knowledge.enable = false"
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
)
print("[WARN] 当前配置 lpmm_knowledge.enable = false刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
except Exception:
# 配置异常时不阻断主流程,仅忽略提示
pass
@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
if action == "prepare_raw":
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
sha_list, raw_data = load_raw_data()
print(
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
elif action == "info_extract":
if not _check_before_info_extract("--non-interactive" in extra_args):
print("已根据用户选择,取消执行信息提取。")
@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
logger.info("开始 full_import预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
sha_list, raw_data = load_raw_data()
print(
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
non_interactive = "--non-interactive" in extra_args
if not _check_before_info_extract(non_interactive):
print("已根据用户选择,取消 full_import信息提取阶段被取消")
@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
)
# 快速选项:按推荐方式清理所有相关实体/关系
quick_all = input(
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
).strip().lower()
quick_all = (
input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
)
if quick_all in ("", "y", "yes"):
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
else:
@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
def _interactive_build_batch_inspect_args() -> List[str]:
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
path = _interactive_choose_openie_file(
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
)
path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
if not path:
return []
return ["--openie-file", path]
@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
def _interactive_build_test_args() -> List[str]:
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
print(
"\n[TEST] 你可以:\n"
"- 直接回车使用内置的默认测试用例;\n"
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
)
print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
if not query:
return []
@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
new_dim = input(
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
).strip()
new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
if new_dim and not new_dim.isdigit():
print("输入的维度不是纯数字,已取消操作。")
return
@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
if __name__ == "__main__":
main()

View File

@ -28,6 +28,7 @@ from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
@ -44,7 +45,7 @@ def _import_memory_retrieval():
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name]
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
if hasattr(existing_module, "init_memory_retrieval_prompt"):
return (
existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question,
@ -54,14 +55,14 @@ def _import_memory_retrieval():
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules:
existing_module = sys.modules[module_name]
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
# 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name]
# 清理可能相关的部分初始化模块
keys_to_remove = []
for key in sys.modules.keys():
if key.startswith('src.memory_system.') and key != 'src.memory_system':
if key.startswith("src.memory_system.") and key != "src.memory_system":
keys_to_remove.append(key)
for key in keys_to_remove:
try:
@ -75,6 +76,7 @@ def _import_memory_retrieval():
# 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config
import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
@ -253,7 +255,7 @@ async def test_memory_retrieval(
包含测试结果的字典
"""
print("\n" + "=" * 80)
print(f"[测试] 记忆检索测试")
print("[测试] 记忆检索测试")
print(f"[问题] {question}")
print("=" * 80)
@ -267,6 +269,7 @@ async def test_memory_retrieval(
# 检查 prompt 是否已经初始化,避免重复初始化
from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt()
else:
@ -284,7 +287,7 @@ async def test_memory_retrieval(
timeout = global_config.memory.agent_timeout_seconds
print(f"\n[配置]")
print("\n[配置]")
print(f" 最大迭代次数: {max_iterations}")
print(f" 超时时间: {timeout}")
print(f" 聊天ID: {chat_id}")
@ -321,26 +324,26 @@ async def test_memory_retrieval(
# 输出结果
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
print(f"\n[结果]")
print("\n[结果]")
print(f" 是否找到答案: {'' if found_answer else ''}")
if found_answer and answer:
print(f" 答案: {answer}")
else:
print(f" 答案: (未找到答案)")
print(" 答案: (未找到答案)")
print(f" 是否超时: {'' if is_timeout else ''}")
print(f" 迭代次数: {len(thinking_steps)}")
print(f" 总耗时: {elapsed_time:.2f}")
print(f"\n[Token使用情况]")
print("\n[Token使用情况]")
print(f" 总请求数: {token_usage['request_count']}")
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
print(f" 总Tokens: {token_usage['total_tokens']:,}")
print(f" 总成本: ${token_usage['total_cost']:.6f}")
if token_usage['model_usage']:
print(f"\n[按模型统计]")
for model_name, usage in token_usage['model_usage'].items():
if token_usage["model_usage"]:
print("\n[按模型统计]")
for model_name, usage in token_usage["model_usage"].items():
print(f" {model_name}:")
print(f" 请求数: {usage['request_count']}")
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
@ -348,7 +351,7 @@ async def test_memory_retrieval(
print(f" 总Tokens: {usage['total_tokens']:,}")
print(f" 成本: ${usage['cost']:.6f}")
print(f"\n[迭代详情]")
print("\n[迭代详情]")
print(format_thinking_steps(thinking_steps))
print("\n" + "=" * 80)
@ -444,4 +447,3 @@ def main() -> None:
if __name__ == "__main__":
main()

View File

@ -455,6 +455,7 @@ class ExpressionSelector:
expr_obj.save()
logger.debug("表达方式激活: 更新last_active_time in db")
try:
expression_selector = ExpressionSelector()
except Exception as e:

View File

@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
logger = get_logger("jargon")
class JargonExplainer:
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""

View File

@ -2,7 +2,6 @@ import time
import asyncio
from typing import List, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression
@ -119,9 +118,7 @@ class MessageRecorder:
# 触发 expression_learner 和 jargon_miner 的处理
if self.enable_expression_learning:
asyncio.create_task(
self._trigger_expression_learning(messages)
)
asyncio.create_task(self._trigger_expression_learning(messages))
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
@ -130,9 +127,7 @@ class MessageRecorder:
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self, messages: List[Any]
) -> None:
async def _trigger_expression_learning(self, messages: List[Any]) -> None:
"""
触发 expression 学习使用指定的消息列表

View File

@ -1,5 +1,5 @@
import time
from typing import Tuple, Optional, Dict, Any # 增加了 Optional
from typing import Tuple, Optional # 增加了 Optional
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@ -170,13 +170,10 @@ class ActionPlanner:
)
break
else:
logger.debug(
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
)
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
except Exception as e:
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
# --- 获取超时提示信息 ---
# (这部分逻辑不变)
timeout_context = ""

View File

@ -112,7 +112,7 @@ class Conversation:
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
"platform": msg.user_info.platform if msg.user_info else "",
}
},
}
initial_messages_dict.append(msg_dict)

View File

@ -5,7 +5,7 @@ from src.common.logger import get_logger
from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType, Notification
from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
from src.common.data_models.database_data_model import DatabaseMessages
import traceback # 导入 traceback 用于调试
logger = get_logger("observation_info")

View File

@ -42,9 +42,7 @@ class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
)
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname

View File

@ -1,10 +1,10 @@
from typing import List, Tuple, Dict, Any
from src.common.logger import get_logger
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
# from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.message_receive.message import Message
from src.config.config import model_config
from src.chat.knowledge import qa_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
@ -16,9 +16,7 @@ class KnowledgeFetcher:
"""知识调取器"""
def __init__(self, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils
)
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
self.private_name = private_name
def _lpmm_get_knowledge(self, query: str) -> str:

View File

@ -14,10 +14,7 @@ class ReplyChecker:
"""回复检查器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="reply_check"
)
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname
self.private_name = private_name

View File

@ -704,10 +704,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断
try:
await asyncio.wait_for(
self._new_message_event.wait(),
timeout=wait_seconds
)
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
# 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError:
@ -731,7 +728,9 @@ class BrainChatting:
# 使用默认等待时间
wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)")
logger.info(
f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)"
)
# 清除事件状态,准备等待新消息
self._new_message_event.clear()
@ -749,10 +748,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断
try:
await asyncio.wait_for(
self._new_message_event.wait(),
timeout=wait_seconds
)
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
# 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError:

View File

@ -431,15 +431,21 @@ class BrainPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
return extracted_reasoning, [
ActionPlannerInfo(
action_type="complete_talk",
reasoning=extracted_reasoning,
action_data={},
action_message=None,
available_actions=available_actions,
)
], llm_content, llm_reasoning, llm_duration_ms
return (
extracted_reasoning,
[
ActionPlannerInfo(
action_type="complete_talk",
reasoning=extracted_reasoning,
action_data={},
action_message=None,
available_actions=available_actions,
)
],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应
if llm_content:

View File

@ -1,7 +1,7 @@
import asyncio
import json
import time
from typing import List, Union, Dict, Any
from typing import List, Union
from .global_logger import logger
from . import prompt_template
@ -200,9 +200,7 @@ class IEProcess:
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
try:
entities, triples = await asyncio.to_thread(
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
)
entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
if entities is not None:
results.append(

View File

@ -395,8 +395,7 @@ class KGManager:
appear_cnt = self.ent_appear_cnt.get(ent_hash)
if not appear_cnt or appear_cnt <= 0:
logger.debug(
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0"
f"将使用 1.0 作为默认出现次数参与权重计算"
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0将使用 1.0 作为默认出现次数参与权重计算"
)
appear_cnt = 1.0
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)

View File

@ -11,6 +11,7 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
logger = get_logger("LPMM-Plugin-API")
class LPMMOperations:
"""
LPMM 内部操作接口
@ -20,9 +21,7 @@ class LPMMOperations:
def __init__(self):
self._initialized = False
async def _run_cancellable_executor(
self, func: Callable, *args, **kwargs
) -> Any:
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
"""
在线程池中执行可取消的同步操作
当任务被取消时 Ctrl+C会立即响应并抛出 CancelledError
@ -79,7 +78,7 @@ class LPMMOperations:
# 1. 分段处理
if auto_split:
# 自动按双换行符分割
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
else:
# 不分割,作为完整一段
text_stripped = text.strip()
@ -95,7 +94,9 @@ class LPMMOperations:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
llm_ner = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
ie_process = IEProcess(llm_ner, llm_rdf)
@ -128,25 +129,21 @@ class LPMMOperations:
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
# store_new_data_set 会自动处理嵌入生成和存储
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
await self._run_cancellable_executor(
embed_mgr.store_new_data_set,
new_raw_paragraphs,
new_triple_list_data
)
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
# 3. 构建知识图谱只需要三元组数据和embedding_manager
await self._run_cancellable_executor(
kg_mgr.build_kg,
new_triple_list_data,
embed_mgr
)
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
# 4. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
return {
"status": "success",
"count": len(new_raw_paragraphs),
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 导入操作被用户中断")
@ -215,8 +212,7 @@ class LPMMOperations:
# a. 从向量库删除
deleted_count, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items,
to_delete_keys
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
)
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
@ -224,10 +220,7 @@ class LPMMOperations:
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs,
to_delete_hashes,
ent_hashes=None,
remove_orphan_entities=True
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
@ -237,7 +230,11 @@ class LPMMOperations:
await self._run_cancellable_executor(kg_mgr.save_to_file)
match_type = "完整文段" if exact_match else "关键词"
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
return {
"status": "success",
"deleted_count": deleted_count,
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 删除操作被用户中断")
@ -275,16 +272,14 @@ class LPMMOperations:
# 删除所有段落向量
para_deleted, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items,
para_keys
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
)
embed_mgr.stored_pg_hashes.clear()
# 删除所有实体向量
if ent_keys:
ent_deleted, _ = await self._run_cancellable_executor(
embed_mgr.entities_embedding_store.delete_items,
ent_keys
embed_mgr.entities_embedding_store.delete_items, ent_keys
)
else:
ent_deleted = 0
@ -292,8 +287,7 @@ class LPMMOperations:
# 删除所有关系向量
if rel_keys:
rel_deleted, _ = await self._run_cancellable_executor(
embed_mgr.relation_embedding_store.delete_items,
rel_keys
embed_mgr.relation_embedding_store.delete_items, rel_keys
)
else:
rel_deleted = 0
@ -341,15 +335,13 @@ class LPMMOperations:
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs,
all_pg_hashes,
ent_hashes=None,
remove_orphan_entities=True
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
# 完全清空KG创建新的空图无论是否有段落hash都要执行
from quick_algo import di_graph
kg_mgr.graph = di_graph.DiGraph()
kg_mgr.stored_paragraph_hashes.clear()
kg_mgr.ent_appear_cnt.clear()
@ -373,7 +365,7 @@ class LPMMOperations:
"stats": {
"before": before_stats,
"after": after_stats,
}
},
}
except asyncio.CancelledError:
@ -383,6 +375,6 @@ class LPMMOperations:
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
# 内部使用的单例
lpmm_ops = LPMMOperations()

View File

@ -136,4 +136,3 @@ class PlanReplyLogger:
return str(value)
# Fallback to string for other complex types
return str(value)

View File

@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 如果未开启 API Server直接跳过 Fallback
if not global_config.maim_message.enable_api_server:
logger.debug(f"[API Server Fallback] API Server未开启跳过fallback")
logger.debug("[API Server Fallback] API Server未开启跳过fallback")
if legacy_exception:
raise legacy_exception
return False
@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
extra_server = getattr(global_api, "extra_server", None)
if not extra_server:
logger.warning(f"[API Server Fallback] extra_server不存在")
logger.warning("[API Server Fallback] extra_server不存在")
if legacy_exception:
raise legacy_exception
return False
if not extra_server.is_running():
logger.warning(f"[API Server Fallback] extra_server未运行")
logger.warning("[API Server Fallback] extra_server未运行")
if legacy_exception:
raise legacy_exception
return False
@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
)
# 直接调用 Server 的 send_message 接口,它会自动处理路由
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...")
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
results = await extra_server.send_message(api_message)
logger.debug(f"[API Server Fallback] 发送结果: {results}")

View File

@ -35,6 +35,7 @@ logger = get_logger("planner")
install(extra_lines=3)
class ActionPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id
@ -111,20 +112,29 @@ class ActionPlanner:
# 替换 [picid:xxx] 为 [图片:描述]
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(pic_match: re.Match) -> str:
pic_id = pic_match.group(1)
description = translate_pid_to_description(pic_id)
return f"[图片:{description}]"
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
platform = (
getattr(message, "user_info", None)
and message.user_info.platform
or getattr(message, "chat_info", None)
and message.chat_info.platform
or "qq"
)
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
# 替换单独的 <用户名:用户ID> 格式replace_user_references 已处理回复<和@<格式)
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
# 这里匹配到的应该都是单独的格式
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
def replace_user_ref(user_match: re.Match) -> str:
user_name = user_match.group(1)
user_id = user_match.group(2)
@ -137,6 +147,7 @@ class ActionPlanner:
except Exception:
# 如果解析失败,使用原始昵称
return user_name
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
@ -306,9 +317,7 @@ class ActionPlanner:
return merged_words
def _process_unknown_words_cache(
self, actions: List[ActionPlannerInfo]
) -> None:
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
"""
处理黑话缓存逻辑
1. 检查是否有 reply action 提取了 unknown_words
@ -442,7 +451,11 @@ class ActionPlanner:
# 检查是否已经有回复该消息的 action
has_reply_to_force_message = False
for action in actions:
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
if (
action.action_type == "reply"
and action.action_message
and action.action_message.message_id == force_reply_message.message_id
):
has_reply_to_force_message = True
break
@ -577,10 +590,11 @@ class ActionPlanner:
if global_config.chat.think_mode == "classic":
reply_action_example = ""
if global_config.chat.llm_quote:
reply_action_example += "5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
"5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
)
reply_action_example += (
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"]'
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
)
if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"'
@ -590,7 +604,9 @@ class ActionPlanner:
"5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n"
)
if global_config.chat.llm_quote:
reply_action_example += "6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
"6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
)
reply_action_example += (
'{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", '
@ -741,15 +757,21 @@ class ActionPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return f"LLM 请求失败,模型出现问题: {req_e}", [
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
action_data={},
action_message=None,
available_actions=available_actions,
)
], llm_content, llm_reasoning, llm_duration_ms
return (
f"LLM 请求失败,模型出现问题: {req_e}",
[
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
action_data={},
action_message=None,
available_actions=available_actions,
)
],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应
extracted_reasoning = ""

View File

@ -1071,7 +1071,6 @@ class DefaultReplyer:
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style
multi_styles = global_config.personality.multiple_reply_style

View File

@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
)
from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.plugin_system.base.component_types import ActionInfo, EventType

View File

@ -5,6 +5,7 @@ from src.common.logger import get_logger
logger = get_logger("common_utils")
class TempMethodsExpression:
"""用于临时存放一些方法的类"""

View File

@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
from . import BaseDatabaseDataModel
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
self.session_id = session_id

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple, Union

View File

@ -96,8 +96,12 @@ class Images(SQLModel, table=True):
no_file_flag: bool = Field(default=False) # 文件不存在标记如果为True表示文件已经不存在仅保留描述字段
record_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间(数据库记录被创建的时间)
register_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 注册时间(被注册为可用表情包的时间)
record_time: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 记录时间(数据库记录被创建的时间)
register_time: Optional[datetime] = Field(
default=None, sa_column=Column(DateTime, nullable=True)
) # 注册时间(被注册为可用表情包的时间)
last_used_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 上次使用时间
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
@ -171,7 +175,9 @@ class Expression(SQLModel, table=True):
content_list: str # 内容列表JSON格式存储
count: int = Field(default=0) # 使用次数
last_active_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 上次使用时间
last_active_time: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 上次使用时间
create_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime)) # 创建时间
session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID区分是否为全局表达方式
@ -232,8 +238,12 @@ class ThinkingQuestion(SQLModel, table=True):
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤JSON格式存储
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
updated_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后更新时间
created_timestamp: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 创建时间
updated_timestamp: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 最后更新时间
class BinaryData(SQLModel, table=True):
@ -272,7 +282,9 @@ class PersonInfo(SQLModel, table=True):
# 认识次数和时间
know_counts: int = Field(default=0) # 认识次数
first_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 首次认识时间
first_known_time: Optional[datetime] = Field(
default=None, sa_column=Column(DateTime, nullable=True)
) # 首次认识时间
last_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 最后认识时间
@ -285,8 +297,12 @@ class ChatSession(SQLModel, table=True):
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
last_active_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后活跃时间
created_timestamp: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 创建时间
last_active_timestamp: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
) # 最后活跃时间
# 身份元数据
user_id: Optional[str] = Field(index=True, max_length=255, nullable=True) # 用户ID

View File

@ -221,5 +221,7 @@ if not supports_truecolor():
CONVERTED_MODULE_COLORS[name] = escape_str
else:
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold)
escape_str = rgb_pair_to_ansi_truecolor(
hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
)
CONVERTED_MODULE_COLORS[name] = escape_str

View File

@ -9,6 +9,7 @@ from .server import get_global_server
global_api = None
def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例"""
global global_api
@ -80,12 +81,12 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
return False
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
# 3. Setup Message Bridge
# Initialize refined route map if not exists
if not hasattr(global_api, "platform_map"):
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
@ -108,7 +109,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
if platform:
global_api.platform_map[platform] = api_key # type: ignore
global_api.platform_map[platform] = api_key # type: ignore
api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
except Exception as e:
api_logger.warning(f"Failed to update platform map: {e}")
@ -117,21 +118,21 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
if "raw_message" not in msg_dict:
msg_dict["raw_message"] = None
await global_api.process_message(msg_dict) # type: ignore
await global_api.process_message(msg_dict) # type: ignore
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
# 3.5. Register custom message handlers (bridge to Legacy handlers)
# message_id_echo: handles message ID echo from adapters
# 兼容新旧两个版本的 maim_message:
# - 旧版: handler(payload)
# - 新版: handler(payload, metadata)
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
# Bridge to the Legacy custom handler registered in main.py
try:
# The Legacy handler expects the payload format directly
if hasattr(global_api, "_custom_message_handlers"):
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
if handler:
await handler(payload)
api_logger.debug(f"Processed message_id_echo: {payload}")
@ -140,7 +141,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
except Exception as e:
api_logger.warning(f"Failed to process message_id_echo: {e}")
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
# 4. Initialize Server
extra_server = WebSocketServer(config=server_config)
@ -167,7 +168,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
global_api.stop = patched_stop
# Attach for reference
global_api.extra_server = extra_server # type: ignore # 这是什么
global_api.extra_server = extra_server # type: ignore # 这是什么
except ImportError:
get_logger("maim_message").error(

View File

@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
logger = get_logger("file_utils")
class FileUtils:
@staticmethod
def save_binary_to_file(file_path: Path, data: bytes):

View File

@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
reason = ",".join(reasons)
return MigrationResult(data=data, migrated=migrated_any, reason=reason)

View File

@ -54,8 +54,6 @@ async def generate_dream_summary(
) -> None:
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
try:
# 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {}
for msg in conversation_messages:

View File

@ -4,4 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数 dream_agent.init_dream_tools 统一注册
"""

View File

@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}"
return create_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}"
return delete_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}"
return delete_jargon

View File

@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg
return finish_maintenance

View File

@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail

View File

@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}"
return search_chat_history

View File

@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}"
return update_chat_history

View File

@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}"
return update_jargon

View File

@ -459,7 +459,7 @@ def _default_normal_response_parser(
# 此时为了调试方便,建议打印出 arguments 的类型
raise RespParseException(
resp,
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
)
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
except json.JSONDecodeError as e:

View File

@ -2,7 +2,7 @@ import time
import json
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
from typing import List, Dict, Any, Optional, Tuple, Callable
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager

View File

@ -113,6 +113,7 @@ async def search_chat_history(
if start_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
start_timestamp = parse_datetime_to_timestamp(start_time)
except ValueError as e:
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
@ -120,6 +121,7 @@ async def search_chat_history(
if end_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
end_timestamp = parse_datetime_to_timestamp(end_time)
except ValueError as e:
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
@ -165,16 +167,13 @@ async def search_chat_history(
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
query = query.where(
(
(ChatHistory.start_time >= start_timestamp)
& (ChatHistory.start_time <= end_timestamp)
(ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
) # 记录开始时间在查询时间段内
| (
(ChatHistory.end_time >= start_timestamp)
& (ChatHistory.end_time <= end_timestamp)
(ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
) # 记录结束时间在查询时间段内
| (
(ChatHistory.start_time <= start_timestamp)
& (ChatHistory.end_time >= end_timestamp)
(ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
) # 记录完全包含查询时间段
)
logger.debug(

View File

@ -76,4 +76,3 @@ def register_tool():
],
execute_func=query_words,
)

View File

@ -558,7 +558,9 @@ class PluginBase(ABC):
if version_spec:
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
if not is_ok:
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
)
return False
if min_version or max_version:

View File

@ -751,9 +751,7 @@ class ComponentRegistry:
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
"workflow_steps": workflow_step_count,
"enabled_workflow_steps": enabled_workflow_step_count,
"workflow_steps_by_stage": {
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
},
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
}

View File

@ -429,7 +429,9 @@ class PluginManager:
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
"""根据依赖图计算加载顺序,并检测循环依赖。"""
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
indegree: Dict[str, int] = {
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
for plugin_name, dependencies in dependency_graph.items():

View File

@ -55,7 +55,9 @@ class PluginServiceRegistry:
full_name = self._resolve_full_name(service_name, plugin_name)
return self._service_handlers.get(full_name) if full_name else None
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
def list_services(
self, plugin_name: Optional[str] = None, enabled_only: bool = False
) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
services = self._services.copy()
if plugin_name:
@ -120,7 +122,12 @@ class PluginServiceRegistry:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}")
if "." not in service_name and plugin_name is None and caller_plugin and service_info.plugin_name != caller_plugin:
if (
"." not in service_name
and plugin_name is None
and caller_plugin
and service_info.plugin_name != caller_plugin
):
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
if not self._is_call_authorized(service_info, caller_plugin):
@ -153,7 +160,9 @@ class PluginServiceRegistry:
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
return "*" in allowed_callers or caller_plugin in allowed_callers
def _validate_input_contract(self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
def _validate_input_contract(
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
"""校验服务入参契约。"""
schema = service_info.params_schema
if not schema:

View File

@ -96,7 +96,9 @@ class WorkflowEngine:
except Exception as e:
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
workflow_context.errors.append(f"{stage_key}: {e}")
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
logger.error(
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
)
self._execution_history[workflow_context.trace_id]["status"] = "failed"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
@ -195,7 +197,9 @@ class WorkflowEngine:
except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}")
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
)
return WorkflowStepResult(
status="failed",
return_message=str(e),

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
class ChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
plan_count: int
latest_timestamp: float
@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
class PlanLogSummary(BaseModel):
"""规划日志摘要"""
chat_id: str
timestamp: float
filename: str
@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
class PlanLogDetail(BaseModel):
"""规划日志详情"""
type: str
chat_id: str
timestamp: float
@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
class PlannerOverview(BaseModel):
"""规划器总览 - 轻量级统计"""
total_chats: int
total_plans: int
chats: List[ChatSummary]
@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
class PaginatedChatLogs(BaseModel):
"""分页的聊天日志列表"""
data: List[PlanLogSummary]
total: int
page: int
@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
@ -106,21 +112,19 @@ async def get_planner_overview():
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ChatSummary(
chat_id=chat_dir.name,
plan_count=plan_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
chats.append(
ChatSummary(
chat_id=chat_dir.name,
plan_count=plan_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name,
)
)
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return PlannerOverview(
total_chats=len(chats),
total_plans=total_plans,
chats=chats
)
return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
@ -128,7 +132,7 @@ async def get_chat_plan_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
):
"""
获取指定聊天的规划日志列表分页
@ -137,9 +141,7 @@ async def get_chat_plan_logs(
"""
chat_dir = PLAN_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedChatLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
@ -151,9 +153,9 @@ async def get_chat_plan_logs(
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
prompt = data.get('prompt', '')
prompt = data.get("prompt", "")
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
@ -164,46 +166,44 @@ async def get_chat_plan_logs(
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
page_files = json_files[offset : offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
actions = data.get('actions', [])
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
logs.append(PlanLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
action_count=len(actions),
action_types=action_types,
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
reasoning_preview=reasoning[:100] if reasoning else ''
))
reasoning = data.get("reasoning", "")
actions = data.get("actions", [])
action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
logs.append(
PlanLogSummary(
chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
action_count=len(actions),
action_types=action_types,
total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
reasoning_preview=reasoning[:100] if reasoning else "",
)
)
except Exception:
# 文件读取失败时使用文件名信息
logs.append(PlanLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
action_count=0,
action_types=[],
total_plan_ms=0,
llm_duration_ms=0,
reasoning_preview='[读取失败]'
))
logs.append(
PlanLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
action_count=0,
action_types=[],
total_plan_ms=0,
llm_duration_ms=0,
reasoning_preview="[读取失败]",
)
)
return PaginatedChatLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
@ -214,7 +214,7 @@ async def get_log_detail(chat_id: str, filename: str):
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
return PlanLogDetail(**data)
except Exception as e:
@ -223,6 +223,7 @@ async def get_log_detail(chat_id: str, filename: str):
# ========== 兼容旧接口 ==========
@router.get("/stats")
async def get_planner_stats():
"""获取规划器统计信息 - 兼容旧接口"""
@ -246,7 +247,7 @@ async def get_planner_stats():
"total_plans": overview.total_plans,
"avg_plan_time_ms": 0,
"avg_llm_time_ms": 0,
"recent_plans": recent_plans
"recent_plans": recent_plans,
}
@ -258,10 +259,7 @@ async def get_chat_list():
@router.get("/all-logs")
async def get_all_logs(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100)
):
async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
"""获取所有规划日志 - 兼容旧接口"""
if not PLAN_LOG_DIR.exists():
return {"data": [], "total": 0, "page": page, "page_size": page_size}
@ -278,23 +276,25 @@ async def get_all_logs(
total = len(all_files)
offset = (page - 1) * page_size
page_files = all_files[offset:offset + page_size]
page_files = all_files[offset : offset + page_size]
logs = []
for chat_id, log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
logs.append({
"chat_id": data.get('chat_id', chat_id),
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name,
"action_count": len(data.get('actions', [])),
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
"reasoning_preview": reasoning[:100] if reasoning else ''
})
reasoning = data.get("reasoning", "")
logs.append(
{
"chat_id": data.get("chat_id", chat_id),
"timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name,
"action_count": len(data.get("actions", [])),
"total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
"llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
"reasoning_preview": reasoning[:100] if reasoning else "",
}
)
except Exception:
continue

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
class ReplierChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
reply_count: int
latest_timestamp: float
@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
class ReplyLogSummary(BaseModel):
"""回复日志摘要"""
chat_id: str
timestamp: float
filename: str
@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
class ReplyLogDetail(BaseModel):
"""回复日志详情"""
type: str
chat_id: str
timestamp: float
@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
class ReplierOverview(BaseModel):
"""回复器总览 - 轻量级统计"""
total_chats: int
total_replies: int
chats: List[ReplierChatSummary]
@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
class PaginatedReplyLogs(BaseModel):
"""分页的回复日志列表"""
data: List[ReplyLogSummary]
total: int
page: int
@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
@ -109,21 +115,19 @@ async def get_replier_overview():
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ReplierChatSummary(
chat_id=chat_dir.name,
reply_count=reply_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
chats.append(
ReplierChatSummary(
chat_id=chat_dir.name,
reply_count=reply_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name,
)
)
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return ReplierOverview(
total_chats=len(chats),
total_replies=total_replies,
chats=chats
)
return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
@ -131,7 +135,7 @@ async def get_chat_reply_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
):
"""
获取指定聊天的回复日志列表分页
@ -140,9 +144,7 @@ async def get_chat_reply_logs(
"""
chat_dir = REPLY_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedReplyLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
@ -154,9 +156,9 @@ async def get_chat_reply_logs(
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
prompt = data.get('prompt', '')
prompt = data.get("prompt", "")
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
@ -167,44 +169,42 @@ async def get_chat_reply_logs(
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
page_files = json_files[offset : offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
output = data.get('output', '')
logs.append(ReplyLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
model=data.get('model', ''),
success=data.get('success', True),
llm_ms=data.get('timing', {}).get('llm_ms', 0),
overall_ms=data.get('timing', {}).get('overall_ms', 0),
output_preview=output[:100] if output else ''
))
output = data.get("output", "")
logs.append(
ReplyLogSummary(
chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
model=data.get("model", ""),
success=data.get("success", True),
llm_ms=data.get("timing", {}).get("llm_ms", 0),
overall_ms=data.get("timing", {}).get("overall_ms", 0),
output_preview=output[:100] if output else "",
)
)
except Exception:
# 文件读取失败时使用文件名信息
logs.append(ReplyLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
model='',
success=False,
llm_ms=0,
overall_ms=0,
output_preview='[读取失败]'
))
logs.append(
ReplyLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
model="",
success=False,
llm_ms=0,
overall_ms=0,
output_preview="[读取失败]",
)
)
return PaginatedReplyLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
@ -215,21 +215,21 @@ async def get_reply_log_detail(chat_id: str, filename: str):
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
return ReplyLogDetail(
type=data.get('type', 'reply'),
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', 0),
prompt=data.get('prompt', ''),
output=data.get('output', ''),
processed_output=data.get('processed_output', []),
model=data.get('model', ''),
reasoning=data.get('reasoning', ''),
think_level=data.get('think_level', 0),
timing=data.get('timing', {}),
error=data.get('error'),
success=data.get('success', True)
type=data.get("type", "reply"),
chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", 0),
prompt=data.get("prompt", ""),
output=data.get("output", ""),
processed_output=data.get("processed_output", []),
model=data.get("model", ""),
reasoning=data.get("reasoning", ""),
think_level=data.get("think_level", 0),
timing=data.get("timing", {}),
error=data.get("error"),
success=data.get("success", True),
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
@ -237,6 +237,7 @@ async def get_reply_log_detail(chat_id: str, filename: str):
# ========== 兼容接口 ==========
@router.get("/stats")
async def get_replier_stats():
"""获取回复器统计信息"""
@ -258,7 +259,7 @@ async def get_replier_stats():
return {
"total_chats": overview.total_chats,
"total_replies": overview.total_replies,
"recent_replies": recent_replies
"recent_replies": recent_replies,
}

View File

@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Depends, Cookie, Header, Request, HTTPException
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
from fastapi import Depends, Cookie, Header, Request
from .core import get_current_token, get_token_manager, check_auth_rate_limit
async def require_auth(

View File

@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
# loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
# IP白名单配置从配置文件读取逗号分隔
# 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100
@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
def _get_anti_crawler_config():
"""获取防爬虫配置"""
from src.config.config import global_config
return {
'mode': global_config.webui.anti_crawler_mode,
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
'trust_xff': global_config.webui.trust_xff
"mode": global_config.webui.anti_crawler_mode,
"allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
"trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
"trust_xff": global_config.webui.trust_xff,
}
# 初始化配置(将在模块加载时执行)
_config = _get_anti_crawler_config()
ANTI_CRAWLER_MODE = _config['mode']
ALLOWED_IPS = _config['allowed_ips']
TRUSTED_PROXIES = _config['trusted_proxies']
TRUST_XFF = _config['trust_xff']
ANTI_CRAWLER_MODE = _config["mode"]
ALLOWED_IPS = _config["allowed_ips"]
TRUSTED_PROXIES = _config["trusted_proxies"]
TRUST_XFF = _config["trust_xff"]
def _get_mode_config(mode: str) -> dict:

View File

@ -43,7 +43,7 @@ def _get_paragraph_store():
namespace="paragraph",
dir_path=embedding_dir,
max_workers=1, # 只读不需要多线程
chunk_size=100
chunk_size=100,
)
paragraph_store.load_from_file()
@ -74,7 +74,7 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
paragraph_item = paragraph_store.store.get(node_id)
if paragraph_item is not None:
# paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本
content: str = getattr(paragraph_item, 'str', '')
content: str = getattr(paragraph_item, "str", "")
if content:
return content, True
return None, True
@ -160,7 +160,11 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else:
content = node_data["content"] if "content" in node_data else node_id
@ -249,7 +253,11 @@ async def get_knowledge_graph(
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type_val == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else:
content = node_data["content"] if "content" in node_data else node_id
@ -372,7 +380,11 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else:
content = node_data["content"] if "content" in node_data else node_id