Merge branch 'r-dev' of github.com:Mai-with-u/MaiBot into r-dev

pull/1496/head
UnCLAS-Prommer 2026-02-21 23:50:53 +08:00
commit 04a5bf3c6d
No known key found for this signature in database
91 changed files with 2110 additions and 1967 deletions

1
bot.py
View File

@ -50,6 +50,7 @@ print("警告Dev进入不稳定开发状态任何插件与WebUI均可能
print("\n\n\n\n\n") print("\n\n\n\n\n")
print("-----------------------------------------") print("-----------------------------------------")
def run_runner_process(): def run_runner_process():
""" """
Runner 进程逻辑作为守护进程运行负责启动和监控 Worker 进程 Runner 进程逻辑作为守护进程运行负责启动和监控 Worker 进程

View File

@ -1,2 +1 @@
"""Core helpers for MCP Bridge Plugin.""" """Core helpers for MCP Bridge Plugin."""

View File

@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
if not mcp_servers: if not mcp_servers:
return "" return ""
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2) return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -22,21 +22,24 @@ from typing import Any, Dict, List, Optional, Tuple
try: try:
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("mcp_tool_chain") logger = get_logger("mcp_tool_chain")
except ImportError: except ImportError:
import logging import logging
logger = logging.getLogger("mcp_tool_chain") logger = logging.getLogger("mcp_tool_chain")
@dataclass @dataclass
class ToolChainStep: class ToolChainStep:
"""工具链步骤""" """工具链步骤"""
tool_name: str # 要调用的工具名(如 mcp_server_tool tool_name: str # 要调用的工具名(如 mcp_server_tool
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换 args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
output_key: str = "" # 输出存储的键名,供后续步骤引用 output_key: str = "" # 输出存储的键名,供后续步骤引用
description: str = "" # 步骤描述 description: str = "" # 步骤描述
optional: bool = False # 是否可选(失败时继续执行) optional: bool = False # 是否可选(失败时继续执行)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"tool_name": self.tool_name, "tool_name": self.tool_name,
@ -45,7 +48,7 @@ class ToolChainStep:
"description": self.description, "description": self.description,
"optional": self.optional, "optional": self.optional,
} }
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep": def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep":
return cls( return cls(
@ -60,12 +63,13 @@ class ToolChainStep:
@dataclass @dataclass
class ToolChainDefinition: class ToolChainDefinition:
"""工具链定义""" """工具链定义"""
name: str # 工具链名称(将作为组合工具的名称) name: str # 工具链名称(将作为组合工具的名称)
description: str # 工具链描述(供 LLM 理解) description: str # 工具链描述(供 LLM 理解)
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤 steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述} input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述}
enabled: bool = True # 是否启用 enabled: bool = True # 是否启用
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"name": self.name, "name": self.name,
@ -74,7 +78,7 @@ class ToolChainDefinition:
"input_params": self.input_params, "input_params": self.input_params,
"enabled": self.enabled, "enabled": self.enabled,
} }
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition": def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition":
steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])] steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])]
@ -90,12 +94,13 @@ class ToolChainDefinition:
@dataclass @dataclass
class ChainExecutionResult: class ChainExecutionResult:
"""工具链执行结果""" """工具链执行结果"""
success: bool success: bool
final_output: str # 最终输出(最后一个步骤的结果) final_output: str # 最终输出(最后一个步骤的结果)
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果 step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
error: str = "" error: str = ""
total_duration_ms: float = 0.0 total_duration_ms: float = 0.0
def to_summary(self) -> str: def to_summary(self) -> str:
"""生成执行摘要""" """生成执行摘要"""
lines = [] lines = []
@ -103,7 +108,7 @@ class ChainExecutionResult:
status = "" if step.get("success") else "" status = "" if step.get("success") else ""
tool = step.get("tool_name", "unknown") tool = step.get("tool_name", "unknown")
duration = step.get("duration_ms", 0) duration = step.get("duration_ms", 0)
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)") lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
if not step.get("success") and step.get("error"): if not step.get("success") and step.get("error"):
lines.append(f" 错误: {step['error'][:50]}") lines.append(f" 错误: {step['error'][:50]}")
return "\n".join(lines) return "\n".join(lines)
@ -111,49 +116,49 @@ class ChainExecutionResult:
class ToolChainExecutor: class ToolChainExecutor:
"""工具链执行器""" """工具链执行器"""
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev} # 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}') VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
def __init__(self, mcp_manager): def __init__(self, mcp_manager):
self._mcp_manager = mcp_manager self._mcp_manager = mcp_manager
def _resolve_tool_key(self, tool_name: str) -> Optional[str]: def _resolve_tool_key(self, tool_name: str) -> Optional[str]:
"""解析工具名,返回有效的 tool_key """解析工具名,返回有效的 tool_key
支持: 支持:
- 直接使用 tool_key mcp_server_tool - 直接使用 tool_key mcp_server_tool
- 使用注册后的工具名会自动转换 - . _ - 使用注册后的工具名会自动转换 - . _
""" """
all_tools = self._mcp_manager.all_tools all_tools = self._mcp_manager.all_tools
# 直接匹配 # 直接匹配
if tool_name in all_tools: if tool_name in all_tools:
return tool_name return tool_name
# 尝试转换后匹配(用户可能使用了注册后的名称) # 尝试转换后匹配(用户可能使用了注册后的名称)
normalized = tool_name.replace("-", "_").replace(".", "_") normalized = tool_name.replace("-", "_").replace(".", "_")
if normalized in all_tools: if normalized in all_tools:
return normalized return normalized
# 尝试查找包含该名称的工具 # 尝试查找包含该名称的工具
for key in all_tools.keys(): for key in all_tools.keys():
if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"): if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"):
return key return key
return None return None
async def execute( async def execute(
self, self,
chain: ToolChainDefinition, chain: ToolChainDefinition,
input_args: Dict[str, Any], input_args: Dict[str, Any],
) -> ChainExecutionResult: ) -> ChainExecutionResult:
"""执行工具链 """执行工具链
Args: Args:
chain: 工具链定义 chain: 工具链定义
input_args: 用户输入的参数 input_args: 用户输入的参数
Returns: Returns:
ChainExecutionResult: 执行结果 ChainExecutionResult: 执行结果
""" """
@ -164,15 +169,15 @@ class ToolChainExecutor:
"step": {}, # 各步骤输出,按 output_key 存储 "step": {}, # 各步骤输出,按 output_key 存储
"prev": "", # 上一步的输出 "prev": "", # 上一步的输出
} }
final_output = "" final_output = ""
# 验证必需的输入参数 # 验证必需的输入参数
missing_params = [] missing_params = []
for param_name in chain.input_params.keys(): for param_name in chain.input_params.keys():
if param_name not in context["input"]: if param_name not in context["input"]:
missing_params.append(param_name) missing_params.append(param_name)
if missing_params: if missing_params:
return ChainExecutionResult( return ChainExecutionResult(
success=False, success=False,
@ -180,7 +185,7 @@ class ToolChainExecutor:
error=f"缺少必需参数: {', '.join(missing_params)}", error=f"缺少必需参数: {', '.join(missing_params)}",
total_duration_ms=(time.time() - start_time) * 1000, total_duration_ms=(time.time() - start_time) * 1000,
) )
for i, step in enumerate(chain.steps): for i, step in enumerate(chain.steps):
step_start = time.time() step_start = time.time()
step_result = { step_result = {
@ -191,96 +196,96 @@ class ToolChainExecutor:
"error": "", "error": "",
"duration_ms": 0, "duration_ms": 0,
} }
try: try:
# 替换参数中的变量 # 替换参数中的变量
resolved_args = self._resolve_args(step.args_template, context) resolved_args = self._resolve_args(step.args_template, context)
step_result["resolved_args"] = resolved_args step_result["resolved_args"] = resolved_args
# 解析工具名 # 解析工具名
tool_key = self._resolve_tool_key(step.tool_name) tool_key = self._resolve_tool_key(step.tool_name)
if not tool_key: if not tool_key:
step_result["error"] = f"工具 {step.tool_name} 不存在" step_result["error"] = f"工具 {step.tool_name} 不存在"
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在") logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
if not step.optional: if not step.optional:
step_results.append(step_result) step_results.append(step_result)
return ChainExecutionResult( return ChainExecutionResult(
success=False, success=False,
final_output="", final_output="",
step_results=step_results, step_results=step_results,
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在", error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
total_duration_ms=(time.time() - start_time) * 1000, total_duration_ms=(time.time() - start_time) * 1000,
) )
step_results.append(step_result) step_results.append(step_result)
continue continue
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}") logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
# 调用工具 # 调用工具
result = await self._mcp_manager.call_tool(tool_key, resolved_args) result = await self._mcp_manager.call_tool(tool_key, resolved_args)
step_duration = (time.time() - step_start) * 1000 step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration step_result["duration_ms"] = step_duration
if result.success: if result.success:
step_result["success"] = True step_result["success"] = True
# 确保 content 不为 None # 确保 content 不为 None
content = result.content if result.content is not None else "" content = result.content if result.content is not None else ""
step_result["output"] = content step_result["output"] = content
# 更新上下文 # 更新上下文
context["prev"] = content context["prev"] = content
if step.output_key: if step.output_key:
context["step"][step.output_key] = content context["step"][step.output_key] = content
final_output = content final_output = content
content_preview = content[:100] if content else "(空)" content_preview = content[:100] if content else "(空)"
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...") logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
else: else:
step_result["error"] = result.error or "未知错误" step_result["error"] = result.error or "未知错误"
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}") logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
if not step.optional: if not step.optional:
step_results.append(step_result) step_results.append(step_result)
return ChainExecutionResult( return ChainExecutionResult(
success=False, success=False,
final_output="", final_output="",
step_results=step_results, step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}", error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
total_duration_ms=(time.time() - start_time) * 1000, total_duration_ms=(time.time() - start_time) * 1000,
) )
except Exception as e: except Exception as e:
step_duration = (time.time() - step_start) * 1000 step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration step_result["duration_ms"] = step_duration
step_result["error"] = str(e) step_result["error"] = str(e)
logger.error(f"工具链步骤 {i+1} 异常: {e}") logger.error(f"工具链步骤 {i + 1} 异常: {e}")
if not step.optional: if not step.optional:
step_results.append(step_result) step_results.append(step_result)
return ChainExecutionResult( return ChainExecutionResult(
success=False, success=False,
final_output="", final_output="",
step_results=step_results, step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}", error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
total_duration_ms=(time.time() - start_time) * 1000, total_duration_ms=(time.time() - start_time) * 1000,
) )
step_results.append(step_result) step_results.append(step_result)
total_duration = (time.time() - start_time) * 1000 total_duration = (time.time() - start_time) * 1000
return ChainExecutionResult( return ChainExecutionResult(
success=True, success=True,
final_output=final_output, final_output=final_output,
step_results=step_results, step_results=step_results,
total_duration_ms=total_duration, total_duration_ms=total_duration,
) )
def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""解析参数模板,替换变量 """解析参数模板,替换变量
支持的变量格式: 支持的变量格式:
- ${input.param_name}: 用户输入的参数 - ${input.param_name}: 用户输入的参数
- ${step.output_key}: 某个步骤的输出 - ${step.output_key}: 某个步骤的输出
@ -288,50 +293,48 @@ class ToolChainExecutor:
- ${prev.field}: 上一步输出JSON的某个字段 - ${prev.field}: 上一步输出JSON的某个字段
""" """
resolved = {} resolved = {}
for key, value in args_template.items(): for key, value in args_template.items():
if isinstance(value, str): if isinstance(value, str):
resolved[key] = self._substitute_vars(value, context) resolved[key] = self._substitute_vars(value, context)
elif isinstance(value, dict): elif isinstance(value, dict):
resolved[key] = self._resolve_args(value, context) resolved[key] = self._resolve_args(value, context)
elif isinstance(value, list): elif isinstance(value, list):
resolved[key] = [ resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
self._substitute_vars(v, context) if isinstance(v, str) else v
for v in value
]
else: else:
resolved[key] = value resolved[key] = value
return resolved return resolved
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str: def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
"""替换字符串中的变量""" """替换字符串中的变量"""
def replacer(match): def replacer(match):
var_path = match.group(1) var_path = match.group(1)
return self._get_var_value(var_path, context) return self._get_var_value(var_path, context)
return self.VAR_PATTERN.sub(replacer, template) return self.VAR_PATTERN.sub(replacer, template)
def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str: def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str:
"""获取变量值 """获取变量值
Args: Args:
var_path: 变量路径 "input.query", "step.search_result", "prev", "prev.id" var_path: 变量路径 "input.query", "step.search_result", "prev", "prev.id"
context: 上下文 context: 上下文
""" """
parts = self._parse_var_path(var_path) parts = self._parse_var_path(var_path)
if not parts: if not parts:
return "" return ""
# 获取根对象 # 获取根对象
root = parts[0] root = parts[0]
if root not in context: if root not in context:
logger.warning(f"变量 {var_path} 的根 '{root}' 不存在") logger.warning(f"变量 {var_path} 的根 '{root}' 不存在")
return "" return ""
value = context[root] value = context[root]
# 遍历路径 # 遍历路径
for part in parts[1:]: for part in parts[1:]:
if isinstance(value, str): if isinstance(value, str):
@ -349,7 +352,7 @@ class ToolChainExecutor:
value = "" value = ""
else: else:
value = "" value = ""
# 确保返回字符串 # 确保返回字符串
if isinstance(value, (dict, list)): if isinstance(value, (dict, list)):
return json.dumps(value, ensure_ascii=False) return json.dumps(value, ensure_ascii=False)
@ -448,39 +451,39 @@ class ToolChainExecutor:
class ToolChainManager: class ToolChainManager:
"""工具链管理器""" """工具链管理器"""
_instance: Optional["ToolChainManager"] = None _instance: Optional["ToolChainManager"] = None
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self): def __init__(self):
if self._initialized: if self._initialized:
return return
self._initialized = True self._initialized = True
self._chains: Dict[str, ToolChainDefinition] = {} self._chains: Dict[str, ToolChainDefinition] = {}
self._executor: Optional[ToolChainExecutor] = None self._executor: Optional[ToolChainExecutor] = None
def set_executor(self, mcp_manager) -> None: def set_executor(self, mcp_manager) -> None:
"""设置执行器""" """设置执行器"""
self._executor = ToolChainExecutor(mcp_manager) self._executor = ToolChainExecutor(mcp_manager)
def add_chain(self, chain: ToolChainDefinition) -> bool: def add_chain(self, chain: ToolChainDefinition) -> bool:
"""添加工具链""" """添加工具链"""
if not chain.name: if not chain.name:
logger.error("工具链名称不能为空") logger.error("工具链名称不能为空")
return False return False
if chain.name in self._chains: if chain.name in self._chains:
logger.warning(f"工具链 {chain.name} 已存在,将被覆盖") logger.warning(f"工具链 {chain.name} 已存在,将被覆盖")
self._chains[chain.name] = chain self._chains[chain.name] = chain
logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)") logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)")
return True return True
def remove_chain(self, name: str) -> bool: def remove_chain(self, name: str) -> bool:
"""移除工具链""" """移除工具链"""
if name in self._chains: if name in self._chains:
@ -488,19 +491,19 @@ class ToolChainManager:
logger.info(f"已移除工具链: {name}") logger.info(f"已移除工具链: {name}")
return True return True
return False return False
def get_chain(self, name: str) -> Optional[ToolChainDefinition]: def get_chain(self, name: str) -> Optional[ToolChainDefinition]:
"""获取工具链""" """获取工具链"""
return self._chains.get(name) return self._chains.get(name)
def get_all_chains(self) -> Dict[str, ToolChainDefinition]: def get_all_chains(self) -> Dict[str, ToolChainDefinition]:
"""获取所有工具链""" """获取所有工具链"""
return self._chains.copy() return self._chains.copy()
def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]: def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]:
"""获取所有启用的工具链""" """获取所有启用的工具链"""
return {name: chain for name, chain in self._chains.items() if chain.enabled} return {name: chain for name, chain in self._chains.items() if chain.enabled}
async def execute_chain( async def execute_chain(
self, self,
chain_name: str, chain_name: str,
@ -514,64 +517,64 @@ class ToolChainManager:
final_output="", final_output="",
error=f"工具链 {chain_name} 不存在", error=f"工具链 {chain_name} 不存在",
) )
if not chain.enabled: if not chain.enabled:
return ChainExecutionResult( return ChainExecutionResult(
success=False, success=False,
final_output="", final_output="",
error=f"工具链 {chain_name} 已禁用", error=f"工具链 {chain_name} 已禁用",
) )
if not self._executor: if not self._executor:
return ChainExecutionResult( return ChainExecutionResult(
success=False, success=False,
final_output="", final_output="",
error="工具链执行器未初始化", error="工具链执行器未初始化",
) )
return await self._executor.execute(chain, input_args) return await self._executor.execute(chain, input_args)
def load_from_json(self, json_str: str) -> Tuple[int, List[str]]: def load_from_json(self, json_str: str) -> Tuple[int, List[str]]:
"""从 JSON 字符串加载工具链配置 """从 JSON 字符串加载工具链配置
Returns: Returns:
(成功加载数量, 错误列表) (成功加载数量, 错误列表)
""" """
errors = [] errors = []
loaded = 0 loaded = 0
try: try:
data = json.loads(json_str) if json_str.strip() else [] data = json.loads(json_str) if json_str.strip() else []
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
return 0, [f"JSON 解析失败: {e}"] return 0, [f"JSON 解析失败: {e}"]
if not isinstance(data, list): if not isinstance(data, list):
data = [data] data = [data]
for i, item in enumerate(data): for i, item in enumerate(data):
try: try:
chain = ToolChainDefinition.from_dict(item) chain = ToolChainDefinition.from_dict(item)
if not chain.name: if not chain.name:
errors.append(f"{i+1} 个工具链缺少名称") errors.append(f"{i + 1} 个工具链缺少名称")
continue continue
if not chain.steps: if not chain.steps:
errors.append(f"工具链 {chain.name} 没有步骤") errors.append(f"工具链 {chain.name} 没有步骤")
continue continue
self.add_chain(chain) self.add_chain(chain)
loaded += 1 loaded += 1
except Exception as e: except Exception as e:
errors.append(f"{i+1} 个工具链解析失败: {e}") errors.append(f"{i + 1} 个工具链解析失败: {e}")
return loaded, errors return loaded, errors
def export_to_json(self, pretty: bool = True) -> str: def export_to_json(self, pretty: bool = True) -> str:
"""导出所有工具链为 JSON""" """导出所有工具链为 JSON"""
chains_data = [chain.to_dict() for chain in self._chains.values()] chains_data = [chain.to_dict() for chain in self._chains.values()]
if pretty: if pretty:
return json.dumps(chains_data, ensure_ascii=False, indent=2) return json.dumps(chains_data, ensure_ascii=False, indent=2)
return json.dumps(chains_data, ensure_ascii=False) return json.dumps(chains_data, ensure_ascii=False)
def clear(self) -> None: def clear(self) -> None:
"""清空所有工具链""" """清空所有工具链"""
self._chains.clear() self._chains.clear()

View File

@ -238,7 +238,7 @@ class TestCommand(BaseCommand):
chat_stream=self.message.chat_stream, chat_stream=self.message.chat_stream,
reply_reason=reply_reason, reply_reason=reply_reason,
enable_chinese_typo=False, enable_chinese_typo=False,
extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"", extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
) )
if result_status: if result_status:
# 发送生成的回复 # 发送生成的回复

View File

@ -46,6 +46,7 @@ def patch_attrdoc_post_init():
config_base_module.logger = logging.getLogger("config_base_test_logger") config_base_module.logger = logging.getLogger("config_base_test_logger")
class SimpleClass(ConfigBase): class SimpleClass(ConfigBase):
a: int = 1 a: int = 1
b: str = "test" b: str = "test"
@ -282,7 +283,7 @@ class TestConfigBase:
True, True,
"ConfigBase is not Hashable", "ConfigBase is not Hashable",
id="listset-validation-set-configbase-element_reject", id="listset-validation-set-configbase-element_reject",
) ),
], ],
) )
def test_validate_list_set_type(self, annotation, expect_error, error_fragment): def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
@ -340,7 +341,7 @@ class TestConfigBase:
False, False,
None, None,
id="dict-validation-happy-configbase-value", id="dict-validation-happy-configbase-value",
) ),
], ],
) )
def test_validate_dict_type(self, annotation, expect_error, error_fragment): def test_validate_dict_type(self, annotation, expect_error, error_fragment):
@ -353,13 +354,11 @@ class TestConfigBase:
field_name = "mapping" field_name = "mapping"
if expect_error: if expect_error:
# Act / Assert # Act / Assert
with pytest.raises(TypeError) as exc_info: with pytest.raises(TypeError) as exc_info:
dummy._validate_dict_type(annotation, field_name) dummy._validate_dict_type(annotation, field_name)
assert error_fragment in str(exc_info.value) assert error_fragment in str(exc_info.value)
else: else:
# Act # Act
dummy._validate_dict_type(annotation, field_name) dummy._validate_dict_type(annotation, field_name)
@ -392,7 +391,7 @@ class TestConfigBase:
# Assert # Assert
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
def test_discourage_any_usage_suppressed_warning(self, caplog): def test_discourage_any_usage_suppressed_warning(self, caplog):
class Sample(ConfigBase): class Sample(ConfigBase):
_validate_any: bool = False _validate_any: bool = False

View File

@ -4,7 +4,6 @@ import importlib
import pytest import pytest
from pathlib import Path from pathlib import Path
import importlib.util import importlib.util
import asyncio
class DummyLogger: class DummyLogger:
@ -71,6 +70,7 @@ class DummyLLMRequest:
async def generate_response_for_image(self, prompt, image_base64, image_format, temp): async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
return ("dummy description", {}) return ("dummy description", {})
class DummySelect: class DummySelect:
def __init__(self, *a, **k): def __init__(self, *a, **k):
pass pass
@ -81,6 +81,7 @@ class DummySelect:
def limit(self, n): def limit(self, n):
return self return self
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def patch_external_dependencies(monkeypatch): def patch_external_dependencies(monkeypatch):
# Provide dummy implementations as modules so that importing image_manager is safe # Provide dummy implementations as modules so that importing image_manager is safe
@ -103,11 +104,11 @@ def patch_external_dependencies(monkeypatch):
# Patch MaiImage data model # Patch MaiImage data model
data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage) data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage)
monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod) monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod)
# Patch SQLModel select function # Patch SQLModel select function
sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect()) sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect())
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod) monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
# Patch config values used at import-time # Patch config values used at import-time
cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style")) cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style"))
model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm")) model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm"))
@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None):
if tmp_path is not None: if tmp_path is not None:
tmpdir = Path(tmp_path) tmpdir = Path(tmp_path)
tmpdir.mkdir(parents=True, exist_ok=True) tmpdir.mkdir(parents=True, exist_ok=True)
setattr(mod, "IMAGE_DIR", tmpdir) mod.IMAGE_DIR = tmpdir
except Exception: except Exception:
pass pass
return mod return mod
@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
# cleanup should run without error # cleanup should run without error
mgr.cleanup_invalid_descriptions_in_db() mgr.cleanup_invalid_descriptions_in_db()

View File

@ -1,5 +1,3 @@
import pytest
from src.config.official_configs import ChatConfig from src.config.official_configs import ChatConfig
from src.config.config import Config from src.config.config import Config
from src.webui.config_schema import ConfigSchemaGenerator from src.webui.config_schema import ConfigSchemaGenerator

View File

@ -387,7 +387,7 @@ def test_auth_required_list(client):
"""测试未认证访问列表端点401""" """测试未认证访问列表端点401"""
# Without mock_token_verify fixture # Without mock_token_verify fixture
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False): with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
response = client.get("/emoji/list") client.get("/emoji/list")
# verify_auth_token 返回 False 会触发 HTTPException # verify_auth_token 返回 False 会触发 HTTPException
# 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现 # 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现
# 这里假设它抛出 401 # 这里假设它抛出 401
@ -397,7 +397,7 @@ def test_auth_required_update(client, sample_emojis):
"""测试未认证访问更新端点401""" """测试未认证访问更新端点401"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False): with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
emoji_id = sample_emojis[0].id emoji_id = sample_emojis[0].id
response = client.patch(f"/emoji/{emoji_id}", json={"description": "test"}) client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
# Should be unauthorized # Should be unauthorized

View File

@ -1,6 +1,5 @@
"""Expression routes pytest tests""" """Expression routes pytest tests"""
from datetime import datetime
from typing import Generator from typing import Generator
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -12,7 +11,6 @@ from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, select from sqlmodel import Session, SQLModel, create_engine, select
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
def create_test_app() -> FastAPI: def create_test_app() -> FastAPI:

View File

@ -19,7 +19,7 @@ from typing import Dict, List, Set, Tuple
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.logger import get_logger from src.common.logger import get_logger # noqa: E402
logger = get_logger("evaluation_stats_analyzer") logger = get_logger("evaluation_stats_analyzer")
@ -38,10 +38,10 @@ def parse_datetime(dt_str: str) -> datetime | None:
def analyze_single_file(file_path: str) -> Dict: def analyze_single_file(file_path: str) -> Dict:
""" """
分析单个JSON文件的统计信息 分析单个JSON文件的统计信息
Args: Args:
file_path: JSON文件路径 file_path: JSON文件路径
Returns: Returns:
统计信息字典 统计信息字典
""" """
@ -65,40 +65,40 @@ def analyze_single_file(file_path: str) -> Dict:
"has_reason": False, "has_reason": False,
"reason_count": 0, "reason_count": 0,
} }
try: try:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# 基本信息 # 基本信息
stats["last_updated"] = data.get("last_updated") stats["last_updated"] = data.get("last_updated")
stats["total_count"] = data.get("total_count", 0) stats["total_count"] = data.get("total_count", 0)
results = data.get("manual_results", []) results = data.get("manual_results", [])
stats["actual_count"] = len(results) stats["actual_count"] = len(results)
if not results: if not results:
return stats return stats
# 统计通过/不通过 # 统计通过/不通过
suitable_count = sum(1 for r in results if r.get("suitable") is True) suitable_count = sum(1 for r in results if r.get("suitable") is True)
unsuitable_count = sum(1 for r in results if r.get("suitable") is False) unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
stats["suitable_count"] = suitable_count stats["suitable_count"] = suitable_count
stats["unsuitable_count"] = unsuitable_count stats["unsuitable_count"] = unsuitable_count
stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0 stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0
# 统计唯一的(situation, style)对 # 统计唯一的(situation, style)对
pairs: Set[Tuple[str, str]] = set() pairs: Set[Tuple[str, str]] = set()
for r in results: for r in results:
if "situation" in r and "style" in r: if "situation" in r and "style" in r:
pairs.add((r["situation"], r["style"])) pairs.add((r["situation"], r["style"]))
stats["unique_pairs"] = len(pairs) stats["unique_pairs"] = len(pairs)
# 统计评估者 # 统计评估者
for r in results: for r in results:
evaluator = r.get("evaluator", "unknown") evaluator = r.get("evaluator", "unknown")
stats["evaluators"][evaluator] += 1 stats["evaluators"][evaluator] += 1
# 统计评估时间 # 统计评估时间
evaluation_dates = [] evaluation_dates = []
for r in results: for r in results:
@ -107,7 +107,7 @@ def analyze_single_file(file_path: str) -> Dict:
dt = parse_datetime(evaluated_at) dt = parse_datetime(evaluated_at)
if dt: if dt:
evaluation_dates.append(dt) evaluation_dates.append(dt)
stats["evaluation_dates"] = evaluation_dates stats["evaluation_dates"] = evaluation_dates
if evaluation_dates: if evaluation_dates:
min_date = min(evaluation_dates) min_date = min(evaluation_dates)
@ -115,18 +115,18 @@ def analyze_single_file(file_path: str) -> Dict:
stats["date_range"] = { stats["date_range"] = {
"start": min_date.isoformat(), "start": min_date.isoformat(),
"end": max_date.isoformat(), "end": max_date.isoformat(),
"duration_days": (max_date - min_date).days + 1 "duration_days": (max_date - min_date).days + 1,
} }
# 检查字段存在性 # 检查字段存在性
stats["has_expression_id"] = any("expression_id" in r for r in results) stats["has_expression_id"] = any("expression_id" in r for r in results)
stats["has_reason"] = any(r.get("reason") for r in results) stats["has_reason"] = any(r.get("reason") for r in results)
stats["reason_count"] = sum(1 for r in results if r.get("reason")) stats["reason_count"] = sum(1 for r in results if r.get("reason"))
except Exception as e: except Exception as e:
stats["error"] = str(e) stats["error"] = str(e)
logger.error(f"分析文件 {file_name} 时出错: {e}") logger.error(f"分析文件 {file_name} 时出错: {e}")
return stats return stats
@ -136,57 +136,57 @@ def print_file_stats(stats: Dict, index: int = None):
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")
print(f"{prefix}文件: {stats['file_name']}") print(f"{prefix}文件: {stats['file_name']}")
print(f"{'=' * 80}") print(f"{'=' * 80}")
if stats["error"]: if stats["error"]:
print(f"✗ 错误: {stats['error']}") print(f"✗ 错误: {stats['error']}")
return return
print(f"文件路径: {stats['file_path']}") print(f"文件路径: {stats['file_path']}")
print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)") print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")
if stats["last_updated"]: if stats["last_updated"]:
print(f"最后更新: {stats['last_updated']}") print(f"最后更新: {stats['last_updated']}")
print("\n【记录统计】") print("\n【记录统计】")
print(f" 文件中的 total_count: {stats['total_count']}") print(f" 文件中的 total_count: {stats['total_count']}")
print(f" 实际记录数: {stats['actual_count']}") print(f" 实际记录数: {stats['actual_count']}")
if stats['total_count'] != stats['actual_count']: if stats["total_count"] != stats["actual_count"]:
diff = stats['total_count'] - stats['actual_count'] diff = stats["total_count"] - stats["actual_count"]
print(f" ⚠️ 数量不一致,差值: {diff:+d}") print(f" ⚠️ 数量不一致,差值: {diff:+d}")
print("\n【评估结果统计】") print("\n【评估结果统计】")
print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)") print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)") print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")
print("\n【唯一性统计】") print("\n【唯一性统计】")
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}") print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}")
if stats['actual_count'] > 0: if stats["actual_count"] > 0:
duplicate_count = stats['actual_count'] - stats['unique_pairs'] duplicate_count = stats["actual_count"] - stats["unique_pairs"]
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)") print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
print("\n【评估者统计】") print("\n【评估者统计】")
if stats['evaluators']: if stats["evaluators"]:
for evaluator, count in stats['evaluators'].most_common(): for evaluator, count in stats["evaluators"].most_common():
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" {evaluator}: {count} 条 ({rate:.2f}%)") print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
else: else:
print(" 无评估者信息") print(" 无评估者信息")
print("\n【时间统计】") print("\n【时间统计】")
if stats['date_range']: if stats["date_range"]:
print(f" 最早评估时间: {stats['date_range']['start']}") print(f" 最早评估时间: {stats['date_range']['start']}")
print(f" 最晚评估时间: {stats['date_range']['end']}") print(f" 最晚评估时间: {stats['date_range']['end']}")
print(f" 评估时间跨度: {stats['date_range']['duration_days']}") print(f" 评估时间跨度: {stats['date_range']['duration_days']}")
else: else:
print(" 无时间信息") print(" 无时间信息")
print("\n【字段统计】") print("\n【字段统计】")
print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}") print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}")
print(f" 包含 reason: {'' if stats['has_reason'] else ''}") print(f" 包含 reason: {'' if stats['has_reason'] else ''}")
if stats['has_reason']: if stats["has_reason"]:
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0 rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)") print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
@ -195,35 +195,35 @@ def print_summary(all_stats: List[Dict]):
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")
print("汇总统计") print("汇总统计")
print(f"{'=' * 80}") print(f"{'=' * 80}")
total_files = len(all_stats) total_files = len(all_stats)
valid_files = [s for s in all_stats if not s.get("error")] valid_files = [s for s in all_stats if not s.get("error")]
error_files = [s for s in all_stats if s.get("error")] error_files = [s for s in all_stats if s.get("error")]
print("\n【文件统计】") print("\n【文件统计】")
print(f" 总文件数: {total_files}") print(f" 总文件数: {total_files}")
print(f" 成功解析: {len(valid_files)}") print(f" 成功解析: {len(valid_files)}")
print(f" 解析失败: {len(error_files)}") print(f" 解析失败: {len(error_files)}")
if error_files: if error_files:
print("\n 失败文件列表:") print("\n 失败文件列表:")
for stats in error_files: for stats in error_files:
print(f" - {stats['file_name']}: {stats['error']}") print(f" - {stats['file_name']}: {stats['error']}")
if not valid_files: if not valid_files:
print("\n没有成功解析的文件") print("\n没有成功解析的文件")
return return
# 汇总记录统计 # 汇总记录统计
total_records = sum(s['actual_count'] for s in valid_files) total_records = sum(s["actual_count"] for s in valid_files)
total_suitable = sum(s['suitable_count'] for s in valid_files) total_suitable = sum(s["suitable_count"] for s in valid_files)
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files) total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
total_unique_pairs = set() total_unique_pairs = set()
# 收集所有唯一的(situation, style)对 # 收集所有唯一的(situation, style)对
for stats in valid_files: for stats in valid_files:
try: try:
with open(stats['file_path'], "r", encoding="utf-8") as f: with open(stats["file_path"], "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
results = data.get("manual_results", []) results = data.get("manual_results", [])
for r in results: for r in results:
@ -231,23 +231,31 @@ def print_summary(all_stats: List[Dict]):
total_unique_pairs.add((r["situation"], r["style"])) total_unique_pairs.add((r["situation"], r["style"]))
except Exception: except Exception:
pass pass
print("\n【记录汇总】") print("\n【记录汇总】")
print(f" 总记录数: {total_records:,}") print(f" 总记录数: {total_records:,}")
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条") print(
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条") f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
if total_records > 0
else " 通过: 0 条"
)
print(
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
if total_records > 0
else " 不通过: 0 条"
)
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}") print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}")
if total_records > 0: if total_records > 0:
duplicate_count = total_records - len(total_unique_pairs) duplicate_count = total_records - len(total_unique_pairs)
duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0 duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0
print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)") print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")
# 汇总评估者统计 # 汇总评估者统计
all_evaluators = Counter() all_evaluators = Counter()
for stats in valid_files: for stats in valid_files:
all_evaluators.update(stats['evaluators']) all_evaluators.update(stats["evaluators"])
print("\n【评估者汇总】") print("\n【评估者汇总】")
if all_evaluators: if all_evaluators:
for evaluator, count in all_evaluators.most_common(): for evaluator, count in all_evaluators.most_common():
@ -255,12 +263,12 @@ def print_summary(all_stats: List[Dict]):
print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)") print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)")
else: else:
print(" 无评估者信息") print(" 无评估者信息")
# 汇总时间范围 # 汇总时间范围
all_dates = [] all_dates = []
for stats in valid_files: for stats in valid_files:
all_dates.extend(stats['evaluation_dates']) all_dates.extend(stats["evaluation_dates"])
if all_dates: if all_dates:
min_date = min(all_dates) min_date = min(all_dates)
max_date = max(all_dates) max_date = max(all_dates)
@ -268,9 +276,9 @@ def print_summary(all_stats: List[Dict]):
print(f" 最早评估时间: {min_date.isoformat()}") print(f" 最早评估时间: {min_date.isoformat()}")
print(f" 最晚评估时间: {max_date.isoformat()}") print(f" 最晚评估时间: {max_date.isoformat()}")
print(f" 总时间跨度: {(max_date - min_date).days + 1}") print(f" 总时间跨度: {(max_date - min_date).days + 1}")
# 文件大小汇总 # 文件大小汇总
total_size = sum(s['file_size'] for s in valid_files) total_size = sum(s["file_size"] for s in valid_files)
avg_size = total_size / len(valid_files) if valid_files else 0 avg_size = total_size / len(valid_files) if valid_files else 0
print("\n【文件大小汇总】") print("\n【文件大小汇总】")
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)") print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
@ -282,35 +290,35 @@ def main():
logger.info("=" * 80) logger.info("=" * 80)
logger.info("开始分析评估结果统计信息") logger.info("开始分析评估结果统计信息")
logger.info("=" * 80) logger.info("=" * 80)
if not os.path.exists(TEMP_DIR): if not os.path.exists(TEMP_DIR):
print(f"\n✗ 错误未找到temp目录: {TEMP_DIR}") print(f"\n✗ 错误未找到temp目录: {TEMP_DIR}")
logger.error(f"未找到temp目录: {TEMP_DIR}") logger.error(f"未找到temp目录: {TEMP_DIR}")
return return
# 查找所有JSON文件 # 查找所有JSON文件
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json")) json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
if not json_files: if not json_files:
print(f"\n✗ 错误temp目录下未找到JSON文件: {TEMP_DIR}") print(f"\n✗ 错误temp目录下未找到JSON文件: {TEMP_DIR}")
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}") logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
return return
json_files.sort() # 按文件名排序 json_files.sort() # 按文件名排序
print(f"\n找到 {len(json_files)} 个JSON文件") print(f"\n找到 {len(json_files)} 个JSON文件")
print("=" * 80) print("=" * 80)
# 分析每个文件 # 分析每个文件
all_stats = [] all_stats = []
for i, json_file in enumerate(json_files, 1): for i, json_file in enumerate(json_files, 1):
stats = analyze_single_file(json_file) stats = analyze_single_file(json_file)
all_stats.append(stats) all_stats.append(stats)
print_file_stats(stats, index=i) print_file_stats(stats, index=i)
# 打印汇总统计 # 打印汇总统计
print_summary(all_stats) print_summary(all_stats)
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")
print("分析完成") print("分析完成")
print(f"{'=' * 80}") print(f"{'=' * 80}")
@ -318,5 +326,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -171,7 +171,9 @@ def main():
sys.exit(1) sys.exit(1)
if not args.raw_index: if not args.raw_index:
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3") logger.info(
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
)
sys.exit(1) sys.exit(1)
# 解析索引列表1-based # 解析索引列表1-based

View File

@ -22,11 +22,11 @@ from collections import defaultdict
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Expression from src.common.database.database_model import Expression # noqa: E402
from src.common.database.database import db from src.common.database.database import db # noqa: E402
from src.common.logger import get_logger from src.common.logger import get_logger # noqa: E402
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config from src.config.config import model_config # noqa: E402
logger = get_logger("expression_evaluator_count_analysis_llm") logger = get_logger("expression_evaluator_count_analysis_llm")
@ -38,13 +38,13 @@ COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]: def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
""" """
加载已有的评估结果 加载已有的评估结果
Returns: Returns:
(已有结果列表, 已评估的项目(situation, style)元组集合) (已有结果列表, 已评估的项目(situation, style)元组集合)
""" """
if not os.path.exists(COUNT_ANALYSIS_FILE): if not os.path.exists(COUNT_ANALYSIS_FILE):
return [], set() return [], set()
try: try:
with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f: with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
@ -61,22 +61,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
def save_results(evaluation_results: List[Dict]): def save_results(evaluation_results: List[Dict]):
""" """
保存评估结果到文件 保存评估结果到文件
Args: Args:
evaluation_results: 评估结果列表 evaluation_results: 评估结果列表
""" """
try: try:
os.makedirs(TEMP_DIR, exist_ok=True) os.makedirs(TEMP_DIR, exist_ok=True)
data = { data = {
"last_updated": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(),
"total_count": len(evaluation_results), "total_count": len(evaluation_results),
"evaluation_results": evaluation_results "evaluation_results": evaluation_results,
} }
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f: with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}") logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)") print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
except Exception as e: except Exception as e:
@ -84,70 +84,70 @@ def save_results(evaluation_results: List[Dict]):
print(f"\n✗ 保存评估结果失败: {e}") print(f"\n✗ 保存评估结果失败: {e}")
def select_expressions_for_evaluation( def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
evaluated_pairs: Set[Tuple[str, str]] = None
) -> List[Expression]:
""" """
选择用于评估的表达方式 选择用于评估的表达方式
选择所有count>1的项目然后选择两倍数量的count=1的项目 选择所有count>1的项目然后选择两倍数量的count=1的项目
Args: Args:
evaluated_pairs: 已评估的项目集合用于避免重复 evaluated_pairs: 已评估的项目集合用于避免重复
Returns: Returns:
选中的表达方式列表 选中的表达方式列表
""" """
if evaluated_pairs is None: if evaluated_pairs is None:
evaluated_pairs = set() evaluated_pairs = set()
try: try:
# 查询所有表达方式 # 查询所有表达方式
all_expressions = list(Expression.select()) all_expressions = list(Expression.select())
if not all_expressions: if not all_expressions:
logger.warning("数据库中没有表达方式记录") logger.warning("数据库中没有表达方式记录")
return [] return []
# 过滤出未评估的项目 # 过滤出未评估的项目
unevaluated = [ unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
if not unevaluated: if not unevaluated:
logger.warning("所有项目都已评估完成") logger.warning("所有项目都已评估完成")
return [] return []
# 按count分组 # 按count分组
count_eq1 = [expr for expr in unevaluated if expr.count == 1] count_eq1 = [expr for expr in unevaluated if expr.count == 1]
count_gt1 = [expr for expr in unevaluated if expr.count > 1] count_gt1 = [expr for expr in unevaluated if expr.count > 1]
logger.info(f"未评估项目中count=1的有{len(count_eq1)}count>1的有{len(count_gt1)}") logger.info(f"未评估项目中count=1的有{len(count_eq1)}count>1的有{len(count_gt1)}")
# 选择所有count>1的项目 # 选择所有count>1的项目
selected_count_gt1 = count_gt1.copy() selected_count_gt1 = count_gt1.copy()
# 选择count=1的项目数量为count>1数量的2倍 # 选择count=1的项目数量为count>1数量的2倍
count_gt1_count = len(selected_count_gt1) count_gt1_count = len(selected_count_gt1)
count_eq1_needed = count_gt1_count * 2 count_eq1_needed = count_gt1_count * 2
if len(count_eq1) < count_eq1_needed: if len(count_eq1) < count_eq1_needed:
logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}") logger.warning(
f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}"
)
count_eq1_needed = len(count_eq1) count_eq1_needed = len(count_eq1)
# 随机选择count=1的项目 # 随机选择count=1的项目
selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else [] selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []
selected = selected_count_gt1 + selected_count_eq1 selected = selected_count_gt1 + selected_count_eq1
random.shuffle(selected) # 打乱顺序 random.shuffle(selected) # 打乱顺序
logger.info(f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍") logger.info(
f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍"
)
return selected return selected
except Exception as e: except Exception as e:
logger.error(f"选择表达方式失败: {e}") logger.error(f"选择表达方式失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return [] return []
@ -155,11 +155,11 @@ def select_expressions_for_evaluation(
def create_evaluation_prompt(situation: str, style: str) -> str: def create_evaluation_prompt(situation: str, style: str) -> str:
""" """
创建评估提示词 创建评估提示词
Args: Args:
situation: 情境 situation: 情境
style: 风格 style: 风格
Returns: Returns:
评估提示词 评估提示词
""" """
@ -181,34 +181,32 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
}} }}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因 如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容""" 请严格按照JSON格式输出不要包含其他内容"""
return prompt return prompt
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]: async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
""" """
执行单次LLM评估 执行单次LLM评估
Args: Args:
situation: 情境 situation: 情境
style: 风格 style: 风格
llm: LLM请求实例 llm: LLM请求实例
Returns: Returns:
(suitable, reason, error) 元组如果出错则 suitable Falseerror 包含错误信息 (suitable, reason, error) 元组如果出错则 suitable Falseerror 包含错误信息
""" """
try: try:
prompt = create_evaluation_prompt(situation, style) prompt = create_evaluation_prompt(situation, style)
logger.debug(f"正在评估表达方式: situation={situation}, style={style}") logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async( response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt, prompt=prompt, temperature=0.6, max_tokens=1024
temperature=0.6,
max_tokens=1024
) )
logger.debug(f"LLM响应: {response}") logger.debug(f"LLM响应: {response}")
# 解析JSON响应 # 解析JSON响应
try: try:
evaluation = json.loads(response) evaluation = json.loads(response)
@ -218,13 +216,13 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
evaluation = json.loads(json_match.group()) evaluation = json.loads(json_match.group())
else: else:
raise ValueError("无法从响应中提取JSON格式的评估结果") from e raise ValueError("无法从响应中提取JSON格式的评估结果") from e
suitable = evaluation.get("suitable", False) suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由") reason = evaluation.get("reason", "未提供理由")
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None return suitable, reason, None
except Exception as e: except Exception as e:
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}") logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
return False, f"评估过程出错: {str(e)}", str(e) return False, f"评估过程出错: {str(e)}", str(e)
@ -233,23 +231,25 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict: async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
""" """
使用LLM评估单个表达方式 使用LLM评估单个表达方式
Args: Args:
expression: 表达方式对象 expression: 表达方式对象
llm: LLM请求实例 llm: LLM请求实例
Returns: Returns:
评估结果字典 评估结果字典
""" """
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}") logger.info(
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
)
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm) suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
if error: if error:
suitable = False suitable = False
logger.info(f"评估完成: {'通过' if suitable else '不通过'}") logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
return { return {
"situation": expression.situation, "situation": expression.situation,
"style": expression.style, "style": expression.style,
@ -258,28 +258,28 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
"reason": reason, "reason": reason,
"error": error, "error": error,
"evaluator": "llm", "evaluator": "llm",
"evaluated_at": datetime.now().isoformat() "evaluated_at": datetime.now().isoformat(),
} }
def perform_statistical_analysis(evaluation_results: List[Dict]): def perform_statistical_analysis(evaluation_results: List[Dict]):
""" """
对评估结果进行统计分析 对评估结果进行统计分析
Args: Args:
evaluation_results: 评估结果列表 evaluation_results: 评估结果列表
""" """
if not evaluation_results: if not evaluation_results:
print("\n没有评估结果可供分析") print("\n没有评估结果可供分析")
return return
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("统计分析结果") print("统计分析结果")
print("=" * 60) print("=" * 60)
# 按count分组统计 # 按count分组统计
count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0}) count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
for result in evaluation_results: for result in evaluation_results:
count = result.get("count", 1) count = result.get("count", 1)
suitable = result.get("suitable", False) suitable = result.get("suitable", False)
@ -288,7 +288,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
count_groups[count]["suitable"] += 1 count_groups[count]["suitable"] += 1
else: else:
count_groups[count]["unsuitable"] += 1 count_groups[count]["unsuitable"] += 1
# 显示每个count的统计 # 显示每个count的统计
print("\n【按count分组统计】") print("\n【按count分组统计】")
print("-" * 60) print("-" * 60)
@ -298,21 +298,21 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
suitable = group["suitable"] suitable = group["suitable"]
unsuitable = group["unsuitable"] unsuitable = group["unsuitable"]
pass_rate = (suitable / total * 100) if total > 0 else 0 pass_rate = (suitable / total * 100) if total > 0 else 0
print(f"Count = {count}:") print(f"Count = {count}:")
print(f" 总数: {total}") print(f" 总数: {total}")
print(f" 通过: {suitable} ({pass_rate:.2f}%)") print(f" 通过: {suitable} ({pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)") print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
print() print()
# 比较count=1和count>1 # 比较count=1和count>1
count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0} count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0} count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
for result in evaluation_results: for result in evaluation_results:
count = result.get("count", 1) count = result.get("count", 1)
suitable = result.get("suitable", False) suitable = result.get("suitable", False)
if count == 1: if count == 1:
count_eq1_group["total"] += 1 count_eq1_group["total"] += 1
if suitable: if suitable:
@ -325,34 +325,34 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
count_gt1_group["suitable"] += 1 count_gt1_group["suitable"] += 1
else: else:
count_gt1_group["unsuitable"] += 1 count_gt1_group["unsuitable"] += 1
print("\n【Count=1 vs Count>1 对比】") print("\n【Count=1 vs Count>1 对比】")
print("-" * 60) print("-" * 60)
eq1_total = count_eq1_group["total"] eq1_total = count_eq1_group["total"]
eq1_suitable = count_eq1_group["suitable"] eq1_suitable = count_eq1_group["suitable"]
eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0 eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
gt1_total = count_gt1_group["total"] gt1_total = count_gt1_group["total"]
gt1_suitable = count_gt1_group["suitable"] gt1_suitable = count_gt1_group["suitable"]
gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0 gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
print("Count = 1:") print("Count = 1:")
print(f" 总数: {eq1_total}") print(f" 总数: {eq1_total}")
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)") print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)") print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
print() print()
print("Count > 1:") print("Count > 1:")
print(f" 总数: {gt1_total}") print(f" 总数: {gt1_total}")
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)") print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)") print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
print() print()
# 进行卡方检验简化版使用2x2列联表 # 进行卡方检验简化版使用2x2列联表
if eq1_total > 0 and gt1_total > 0: if eq1_total > 0 and gt1_total > 0:
print("【统计显著性检验】") print("【统计显著性检验】")
print("-" * 60) print("-" * 60)
# 构建2x2列联表 # 构建2x2列联表
# 通过 不通过 # 通过 不通过
# count=1 a b # count=1 a b
@ -361,7 +361,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
b = eq1_total - eq1_suitable b = eq1_total - eq1_suitable
c = gt1_suitable c = gt1_suitable
d = gt1_total - gt1_suitable d = gt1_total - gt1_suitable
# 计算卡方统计量简化版使用Pearson卡方检验 # 计算卡方统计量简化版使用Pearson卡方检验
n = eq1_total + gt1_total n = eq1_total + gt1_total
if n > 0: if n > 0:
@ -370,13 +370,13 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
e_b = (eq1_total * (b + d)) / n e_b = (eq1_total * (b + d)) / n
e_c = (gt1_total * (a + c)) / n e_c = (gt1_total * (a + c)) / n
e_d = (gt1_total * (b + d)) / n e_d = (gt1_total * (b + d)) / n
# 检查期望频数是否足够大(卡方检验要求每个期望频数>=5 # 检查期望频数是否足够大(卡方检验要求每个期望频数>=5
min_expected = min(e_a, e_b, e_c, e_d) min_expected = min(e_a, e_b, e_c, e_d)
if min_expected < 5: if min_expected < 5:
print("警告期望频数小于5卡方检验可能不准确") print("警告期望频数小于5卡方检验可能不准确")
print("建议使用Fisher精确检验") print("建议使用Fisher精确检验")
# 计算卡方值 # 计算卡方值
chi_square = 0 chi_square = 0
if e_a > 0: if e_a > 0:
@ -387,26 +387,26 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
chi_square += ((c - e_c) ** 2) / e_c chi_square += ((c - e_c) ** 2) / e_c
if e_d > 0: if e_d > 0:
chi_square += ((d - e_d) ** 2) / e_d chi_square += ((d - e_d) ** 2) / e_d
# 自由度 = (行数-1) * (列数-1) = 1 # 自由度 = (行数-1) * (列数-1) = 1
df = 1 df = 1
# 临界值(α=0.05 # 临界值(α=0.05
chi_square_critical_005 = 3.841 chi_square_critical_005 = 3.841
chi_square_critical_001 = 6.635 chi_square_critical_001 = 6.635
print(f"卡方统计量: {chi_square:.4f}") print(f"卡方统计量: {chi_square:.4f}")
print(f"自由度: {df}") print(f"自由度: {df}")
print(f"临界值 (α=0.05): {chi_square_critical_005}") print(f"临界值 (α=0.05): {chi_square_critical_005}")
print(f"临界值 (α=0.01): {chi_square_critical_001}") print(f"临界值 (α=0.01): {chi_square_critical_001}")
if chi_square >= chi_square_critical_001: if chi_square >= chi_square_critical_001:
print("结论: 在α=0.01水平下count=1和count>1的合格率存在显著差异p<0.01") print("结论: 在α=0.01水平下count=1和count>1的合格率存在显著差异p<0.01")
elif chi_square >= chi_square_critical_005: elif chi_square >= chi_square_critical_005:
print("结论: 在α=0.05水平下count=1和count>1的合格率存在显著差异p<0.05") print("结论: 在α=0.05水平下count=1和count>1的合格率存在显著差异p<0.05")
else: else:
print("结论: 在α=0.05水平下count=1和count>1的合格率不存在显著差异p≥0.05") print("结论: 在α=0.05水平下count=1和count>1的合格率不存在显著差异p≥0.05")
# 计算差异大小 # 计算差异大小
diff = abs(eq1_pass_rate - gt1_pass_rate) diff = abs(eq1_pass_rate - gt1_pass_rate)
print(f"\n合格率差异: {diff:.2f}%") print(f"\n合格率差异: {diff:.2f}%")
@ -420,16 +420,16 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
print("数据不足,无法进行统计检验") print("数据不足,无法进行统计检验")
else: else:
print("数据不足无法进行count=1和count>1的对比分析") print("数据不足无法进行count=1和count>1的对比分析")
# 保存统计分析结果 # 保存统计分析结果
analysis_result = { analysis_result = {
"analysis_time": datetime.now().isoformat(), "analysis_time": datetime.now().isoformat(),
"count_groups": {str(k): v for k, v in count_groups.items()}, "count_groups": {str(k): v for k, v in count_groups.items()},
"count_eq1": count_eq1_group, "count_eq1": count_eq1_group,
"count_gt1": count_gt1_group, "count_gt1": count_gt1_group,
"total_evaluated": len(evaluation_results) "total_evaluated": len(evaluation_results),
} }
try: try:
analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json") analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
with open(analysis_file, "w", encoding="utf-8") as f: with open(analysis_file, "w", encoding="utf-8") as f:
@ -444,7 +444,7 @@ async def main():
logger.info("=" * 60) logger.info("=" * 60)
logger.info("开始表达方式按count分组的LLM评估和统计分析") logger.info("开始表达方式按count分组的LLM评估和统计分析")
logger.info("=" * 60) logger.info("=" * 60)
# 初始化数据库连接 # 初始化数据库连接
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
@ -452,97 +452,95 @@ async def main():
except Exception as e: except Exception as e:
logger.error(f"数据库连接失败: {e}") logger.error(f"数据库连接失败: {e}")
return return
# 加载已有评估结果 # 加载已有评估结果
existing_results, evaluated_pairs = load_existing_results() existing_results, evaluated_pairs = load_existing_results()
evaluation_results = existing_results.copy() evaluation_results = existing_results.copy()
if evaluated_pairs: if evaluated_pairs:
print(f"\n已加载 {len(existing_results)} 条已有评估结果") print(f"\n已加载 {len(existing_results)} 条已有评估结果")
print(f"已评估项目数: {len(evaluated_pairs)}") print(f"已评估项目数: {len(evaluated_pairs)}")
# 检查是否需要继续评估检查是否还有未评估的count>1项目 # 检查是否需要继续评估检查是否还有未评估的count>1项目
# 先查询未评估的count>1项目数量 # 先查询未评估的count>1项目数量
try: try:
all_expressions = list(Expression.select()) all_expressions = list(Expression.select())
unevaluated_count_gt1 = [ unevaluated_count_gt1 = [
expr for expr in all_expressions expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
] ]
has_unevaluated = len(unevaluated_count_gt1) > 0 has_unevaluated = len(unevaluated_count_gt1) > 0
except Exception as e: except Exception as e:
logger.error(f"查询未评估项目失败: {e}") logger.error(f"查询未评估项目失败: {e}")
has_unevaluated = False has_unevaluated = False
if has_unevaluated: if has_unevaluated:
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("开始LLM评估") print("开始LLM评估")
print("=" * 60) print("=" * 60)
print("评估结果会自动保存到文件\n") print("评估结果会自动保存到文件\n")
# 创建LLM实例 # 创建LLM实例
print("创建LLM实例...") print("创建LLM实例...")
try: try:
llm = LLMRequest( llm = LLMRequest(
model_set=model_config.model_task_config.tool_use, model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_count_analysis_llm" request_type="expression_evaluator_count_analysis_llm",
) )
print("✓ LLM实例创建成功\n") print("✓ LLM实例创建成功\n")
except Exception as e: except Exception as e:
logger.error(f"创建LLM实例失败: {e}") logger.error(f"创建LLM实例失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
print(f"\n✗ 创建LLM实例失败: {e}") print(f"\n✗ 创建LLM实例失败: {e}")
db.close() db.close()
return return
# 选择需要评估的表达方式选择所有count>1的项目然后选择两倍数量的count=1的项目 # 选择需要评估的表达方式选择所有count>1的项目然后选择两倍数量的count=1的项目
expressions = select_expressions_for_evaluation( expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
evaluated_pairs=evaluated_pairs
)
if not expressions: if not expressions:
print("\n没有可评估的项目") print("\n没有可评估的项目")
else: else:
print(f"\n已选择 {len(expressions)} 条表达方式进行评估") print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)}") print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)}")
print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)}\n") print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)}\n")
batch_results = [] batch_results = []
for i, expression in enumerate(expressions, 1): for i, expression in enumerate(expressions, 1):
print(f"LLM评估进度: {i}/{len(expressions)}") print(f"LLM评估进度: {i}/{len(expressions)}")
print(f" Situation: {expression.situation}") print(f" Situation: {expression.situation}")
print(f" Style: {expression.style}") print(f" Style: {expression.style}")
print(f" Count: {expression.count}") print(f" Count: {expression.count}")
llm_result = await llm_evaluate_expression(expression, llm) llm_result = await llm_evaluate_expression(expression, llm)
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}") print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
if llm_result.get('error'): if llm_result.get("error"):
print(f" 错误: {llm_result['error']}") print(f" 错误: {llm_result['error']}")
print() print()
batch_results.append(llm_result) batch_results.append(llm_result)
# 使用 (situation, style) 作为唯一标识 # 使用 (situation, style) 作为唯一标识
evaluated_pairs.add((llm_result["situation"], llm_result["style"])) evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
# 添加延迟以避免API限流 # 添加延迟以避免API限流
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
# 将当前批次结果添加到总结果中 # 将当前批次结果添加到总结果中
evaluation_results.extend(batch_results) evaluation_results.extend(batch_results)
# 保存结果 # 保存结果
save_results(evaluation_results) save_results(evaluation_results)
else: else:
print(f"\n所有count>1的项目都已评估完成已有 {len(evaluation_results)} 条评估结果") print(f"\n所有count>1的项目都已评估完成已有 {len(evaluation_results)} 条评估结果")
# 进行统计分析 # 进行统计分析
if len(evaluation_results) > 0: if len(evaluation_results) > 0:
perform_statistical_analysis(evaluation_results) perform_statistical_analysis(evaluation_results)
else: else:
print("\n没有评估结果可供分析") print("\n没有评估结果可供分析")
# 关闭数据库连接 # 关闭数据库连接
try: try:
db.close() db.close()
@ -553,4 +551,3 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -20,9 +20,9 @@ from typing import List, Dict, Set, Tuple
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config from src.config.config import model_config # noqa: E402
from src.common.logger import get_logger from src.common.logger import get_logger # noqa: E402
logger = get_logger("expression_evaluator_llm") logger = get_logger("expression_evaluator_llm")
@ -33,7 +33,7 @@ TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
def load_manual_results() -> List[Dict]: def load_manual_results() -> List[Dict]:
""" """
加载人工评估结果自动读取temp目录下所有JSON文件并合并 加载人工评估结果自动读取temp目录下所有JSON文件并合并
Returns: Returns:
人工评估结果列表已去重 人工评估结果列表已去重
""" """
@ -42,62 +42,62 @@ def load_manual_results() -> List[Dict]:
print("\n✗ 错误未找到temp目录") print("\n✗ 错误未找到temp目录")
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估") print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
return [] return []
# 查找所有JSON文件 # 查找所有JSON文件
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json")) json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
if not json_files: if not json_files:
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}") logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
print("\n✗ 错误temp目录下未找到JSON文件") print("\n✗ 错误temp目录下未找到JSON文件")
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估") print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
return [] return []
logger.info(f"找到 {len(json_files)} 个JSON文件") logger.info(f"找到 {len(json_files)} 个JSON文件")
print(f"\n找到 {len(json_files)} 个JSON文件:") print(f"\n找到 {len(json_files)} 个JSON文件:")
for json_file in json_files: for json_file in json_files:
print(f" - {os.path.basename(json_file)}") print(f" - {os.path.basename(json_file)}")
# 读取并合并所有JSON文件 # 读取并合并所有JSON文件
all_results = [] all_results = []
seen_pairs: Set[Tuple[str, str]] = set() # 用于去重 seen_pairs: Set[Tuple[str, str]] = set() # 用于去重
for json_file in json_files: for json_file in json_files:
try: try:
with open(json_file, "r", encoding="utf-8") as f: with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
results = data.get("manual_results", []) results = data.get("manual_results", [])
# 去重:使用(situation, style)作为唯一标识 # 去重:使用(situation, style)作为唯一标识
for result in results: for result in results:
if "situation" not in result or "style" not in result: if "situation" not in result or "style" not in result:
logger.warning(f"跳过无效数据(缺少必要字段): {result}") logger.warning(f"跳过无效数据(缺少必要字段): {result}")
continue continue
pair = (result["situation"], result["style"]) pair = (result["situation"], result["style"])
if pair not in seen_pairs: if pair not in seen_pairs:
seen_pairs.add(pair) seen_pairs.add(pair)
all_results.append(result) all_results.append(result)
logger.info(f"{os.path.basename(json_file)} 加载了 {len(results)} 条结果") logger.info(f"{os.path.basename(json_file)} 加载了 {len(results)} 条结果")
except Exception as e: except Exception as e:
logger.error(f"加载文件 {json_file} 失败: {e}") logger.error(f"加载文件 {json_file} 失败: {e}")
print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}") print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
continue continue
logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)") logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)") print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")
return all_results return all_results
def create_evaluation_prompt(situation: str, style: str) -> str: def create_evaluation_prompt(situation: str, style: str) -> str:
""" """
创建评估提示词 创建评估提示词
Args: Args:
situation: 情境 situation: 情境
style: 风格 style: 风格
Returns: Returns:
评估提示词 评估提示词
""" """
@ -119,51 +119,50 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
}} }}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因 如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容""" 请严格按照JSON格式输出不要包含其他内容"""
return prompt return prompt
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]: async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
""" """
执行单次LLM评估 执行单次LLM评估
Args: Args:
situation: 情境 situation: 情境
style: 风格 style: 风格
llm: LLM请求实例 llm: LLM请求实例
Returns: Returns:
(suitable, reason, error) 元组如果出错则 suitable Falseerror 包含错误信息 (suitable, reason, error) 元组如果出错则 suitable Falseerror 包含错误信息
""" """
try: try:
prompt = create_evaluation_prompt(situation, style) prompt = create_evaluation_prompt(situation, style)
logger.debug(f"正在评估表达方式: situation={situation}, style={style}") logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async( response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt, prompt=prompt, temperature=0.6, max_tokens=1024
temperature=0.6,
max_tokens=1024
) )
logger.debug(f"LLM响应: {response}") logger.debug(f"LLM响应: {response}")
# 解析JSON响应 # 解析JSON响应
try: try:
evaluation = json.loads(response) evaluation = json.loads(response)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
import re import re
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL) json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
if json_match: if json_match:
evaluation = json.loads(json_match.group()) evaluation = json.loads(json_match.group())
else: else:
raise ValueError("无法从响应中提取JSON格式的评估结果") from e raise ValueError("无法从响应中提取JSON格式的评估结果") from e
suitable = evaluation.get("suitable", False) suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由") reason = evaluation.get("reason", "未提供理由")
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None return suitable, reason, None
except Exception as e: except Exception as e:
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}") logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
return False, f"评估过程出错: {str(e)}", str(e) return False, f"评估过程出错: {str(e)}", str(e)
@ -172,68 +171,68 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict: async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
""" """
使用LLM评估单个表达方式 使用LLM评估单个表达方式
Args: Args:
situation: 情境 situation: 情境
style: 风格 style: 风格
llm: LLM请求实例 llm: LLM请求实例
Returns: Returns:
评估结果字典 评估结果字典
""" """
logger.info(f"开始评估表达方式: situation={situation}, style={style}") logger.info(f"开始评估表达方式: situation={situation}, style={style}")
suitable, reason, error = await _single_llm_evaluation(situation, style, llm) suitable, reason, error = await _single_llm_evaluation(situation, style, llm)
if error: if error:
suitable = False suitable = False
logger.info(f"评估完成: {'通过' if suitable else '不通过'}") logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
return { return {
"situation": situation, "situation": situation,
"style": style, "style": style,
"suitable": suitable, "suitable": suitable,
"reason": reason, "reason": reason,
"error": error, "error": error,
"evaluator": "llm" "evaluator": "llm",
} }
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict: def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
""" """
对比人工评估和LLM评估的结果 对比人工评估和LLM评估的结果
Args: Args:
manual_results: 人工评估结果列表 manual_results: 人工评估结果列表
llm_results: LLM评估结果列表 llm_results: LLM评估结果列表
method_name: 评估方法名称用于标识 method_name: 评估方法名称用于标识
Returns: Returns:
对比分析结果字典 对比分析结果字典
""" """
# 按(situation, style)建立映射 # 按(situation, style)建立映射
llm_dict = {(r["situation"], r["style"]): r for r in llm_results} llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
total = len(manual_results) total = len(manual_results)
matched = 0 matched = 0
true_positives = 0 true_positives = 0
true_negatives = 0 true_negatives = 0
false_positives = 0 false_positives = 0
false_negatives = 0 false_negatives = 0
for manual_result in manual_results: for manual_result in manual_results:
pair = (manual_result["situation"], manual_result["style"]) pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair) llm_result = llm_dict.get(pair)
if llm_result is None: if llm_result is None:
continue continue
manual_suitable = manual_result["suitable"] manual_suitable = manual_result["suitable"]
llm_suitable = llm_result["suitable"] llm_suitable = llm_result["suitable"]
if manual_suitable == llm_suitable: if manual_suitable == llm_suitable:
matched += 1 matched += 1
if manual_suitable and llm_suitable: if manual_suitable and llm_suitable:
true_positives += 1 true_positives += 1
elif not manual_suitable and not llm_suitable: elif not manual_suitable and not llm_suitable:
@ -242,30 +241,36 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
false_positives += 1 false_positives += 1
elif manual_suitable and not llm_suitable: elif manual_suitable and not llm_suitable:
false_negatives += 1 false_negatives += 1
accuracy = (matched / total * 100) if total > 0 else 0 accuracy = (matched / total * 100) if total > 0 else 0
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0 precision = (
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0 (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
)
recall = (
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
)
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0 f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0 specificity = (
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
)
# 计算人工效标的不合适率 # 计算人工效标的不合适率
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数 manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0 manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0
# 计算经过LLM删除后剩余项目中的不合适率 # 计算经过LLM删除后剩余项目中的不合适率
# 在所有项目中移除LLM判定为不合适的项目后剩下的项目 = TP + FPLLM判定为合适的项目 # 在所有项目中移除LLM判定为不合适的项目后剩下的项目 = TP + FPLLM判定为合适的项目
# 在这些剩下的项目中,按人工评定的不合适项目 = FP人工认为不合适但LLM认为合适 # 在这些剩下的项目中,按人工评定的不合适项目 = FP人工认为不合适但LLM认为合适
llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数保留的项目 llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数保留的项目
llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0 llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0
# 两者百分比相减评估LLM评定修正后的不合适率是否有降低 # 两者百分比相减评估LLM评定修正后的不合适率是否有降低
rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
random_baseline = 50.0 random_baseline = 50.0
accuracy_above_random = accuracy - random_baseline accuracy_above_random = accuracy - random_baseline
accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0 accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0
return { return {
"method": method_name, "method": method_name,
"total": total, "total": total,
@ -283,29 +288,29 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
"specificity": specificity, "specificity": specificity,
"manual_unsuitable_rate": manual_unsuitable_rate, "manual_unsuitable_rate": manual_unsuitable_rate,
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate, "llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
"rate_difference": rate_difference "rate_difference": rate_difference,
} }
async def main(count: int | None = None): async def main(count: int | None = None):
""" """
主函数 主函数
Args: Args:
count: 随机选取的数据条数如果为None则使用全部数据 count: 随机选取的数据条数如果为None则使用全部数据
""" """
logger.info("=" * 60) logger.info("=" * 60)
logger.info("开始表达方式LLM评估") logger.info("开始表达方式LLM评估")
logger.info("=" * 60) logger.info("=" * 60)
# 1. 加载人工评估结果 # 1. 加载人工评估结果
print("\n步骤1: 加载人工评估结果") print("\n步骤1: 加载人工评估结果")
manual_results = load_manual_results() manual_results = load_manual_results()
if not manual_results: if not manual_results:
return return
print(f"成功加载 {len(manual_results)} 条人工评估结果") print(f"成功加载 {len(manual_results)} 条人工评估结果")
# 如果指定了数量,随机选择指定数量的数据 # 如果指定了数量,随机选择指定数量的数据
if count is not None: if count is not None:
if count <= 0: if count <= 0:
@ -317,7 +322,7 @@ async def main(count: int | None = None):
random.seed() # 使用系统时间作为随机种子 random.seed() # 使用系统时间作为随机种子
manual_results = random.sample(manual_results, count) manual_results = random.sample(manual_results, count)
print(f"随机选取 {len(manual_results)} 条数据进行评估") print(f"随机选取 {len(manual_results)} 条数据进行评估")
# 验证数据完整性 # 验证数据完整性
valid_manual_results = [] valid_manual_results = []
for r in manual_results: for r in manual_results:
@ -325,62 +330,58 @@ async def main(count: int | None = None):
valid_manual_results.append(r) valid_manual_results.append(r)
else: else:
logger.warning(f"跳过无效数据: {r}") logger.warning(f"跳过无效数据: {r}")
if len(valid_manual_results) != len(manual_results): if len(valid_manual_results) != len(manual_results):
print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过") print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
print(f"有效数据: {len(valid_manual_results)}") print(f"有效数据: {len(valid_manual_results)}")
# 2. 创建LLM实例并评估 # 2. 创建LLM实例并评估
print("\n步骤2: 创建LLM实例") print("\n步骤2: 创建LLM实例")
try: try:
llm = LLMRequest( llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_llm"
)
except Exception as e: except Exception as e:
logger.error(f"创建LLM实例失败: {e}") logger.error(f"创建LLM实例失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return return
print("\n步骤3: 开始LLM评估") print("\n步骤3: 开始LLM评估")
llm_results = [] llm_results = []
for i, manual_result in enumerate(valid_manual_results, 1): for i, manual_result in enumerate(valid_manual_results, 1):
print(f"LLM评估进度: {i}/{len(valid_manual_results)}") print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
llm_results.append(await evaluate_expression_llm( llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
manual_result["situation"],
manual_result["style"],
llm
))
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
# 5. 输出FP和FN项目在评估结果之前 # 5. 输出FP和FN项目在评估结果之前
llm_dict = {(r["situation"], r["style"]): r for r in llm_results} llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
# 5.1 输出FP项目人工评估不通过但LLM误判为通过 # 5.1 输出FP项目人工评估不通过但LLM误判为通过
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("人工评估不通过但LLM误判为通过的项目FP - False Positive") print("人工评估不通过但LLM误判为通过的项目FP - False Positive")
print("=" * 60) print("=" * 60)
fp_items = [] fp_items = []
for manual_result in valid_manual_results: for manual_result in valid_manual_results:
pair = (manual_result["situation"], manual_result["style"]) pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair) llm_result = llm_dict.get(pair)
if llm_result is None: if llm_result is None:
continue continue
# 人工评估不通过但LLM评估通过FP情况 # 人工评估不通过但LLM评估通过FP情况
if not manual_result["suitable"] and llm_result["suitable"]: if not manual_result["suitable"] and llm_result["suitable"]:
fp_items.append({ fp_items.append(
"situation": manual_result["situation"], {
"style": manual_result["style"], "situation": manual_result["situation"],
"manual_suitable": manual_result["suitable"], "style": manual_result["style"],
"llm_suitable": llm_result["suitable"], "manual_suitable": manual_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"), "llm_suitable": llm_result["suitable"],
"llm_error": llm_result.get("error") "llm_reason": llm_result.get("reason", "未提供理由"),
}) "llm_error": llm_result.get("error"),
}
)
if fp_items: if fp_items:
print(f"\n共找到 {len(fp_items)} 条误判项目:\n") print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
for idx, item in enumerate(fp_items, 1): for idx, item in enumerate(fp_items, 1):
@ -389,36 +390,38 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}") print(f"Style: {item['style']}")
print("人工评估: 不通过 ❌") print("人工评估: 不通过 ❌")
print("LLM评估: 通过 ✅ (误判)") print("LLM评估: 通过 ✅ (误判)")
if item.get('llm_error'): if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}") print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}") print(f"LLM理由: {item['llm_reason']}")
print() print()
else: else:
print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过") print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过")
# 5.2 输出FN项目人工评估通过但LLM误判为不通过 # 5.2 输出FN项目人工评估通过但LLM误判为不通过
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("人工评估通过但LLM误判为不通过的项目FN - False Negative") print("人工评估通过但LLM误判为不通过的项目FN - False Negative")
print("=" * 60) print("=" * 60)
fn_items = [] fn_items = []
for manual_result in valid_manual_results: for manual_result in valid_manual_results:
pair = (manual_result["situation"], manual_result["style"]) pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair) llm_result = llm_dict.get(pair)
if llm_result is None: if llm_result is None:
continue continue
# 人工评估通过但LLM评估不通过FN情况 # 人工评估通过但LLM评估不通过FN情况
if manual_result["suitable"] and not llm_result["suitable"]: if manual_result["suitable"] and not llm_result["suitable"]:
fn_items.append({ fn_items.append(
"situation": manual_result["situation"], {
"style": manual_result["style"], "situation": manual_result["situation"],
"manual_suitable": manual_result["suitable"], "style": manual_result["style"],
"llm_suitable": llm_result["suitable"], "manual_suitable": manual_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"), "llm_suitable": llm_result["suitable"],
"llm_error": llm_result.get("error") "llm_reason": llm_result.get("reason", "未提供理由"),
}) "llm_error": llm_result.get("error"),
}
)
if fn_items: if fn_items:
print(f"\n共找到 {len(fn_items)} 条误删项目:\n") print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
for idx, item in enumerate(fn_items, 1): for idx, item in enumerate(fn_items, 1):
@ -427,33 +430,41 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}") print(f"Style: {item['style']}")
print("人工评估: 通过 ✅") print("人工评估: 通过 ✅")
print("LLM评估: 不通过 ❌ (误删)") print("LLM评估: 不通过 ❌ (误删)")
if item.get('llm_error'): if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}") print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}") print(f"LLM理由: {item['llm_reason']}")
print() print()
else: else:
print("\n✓ 没有误删项目所有人工评估通过的项目都被LLM正确识别为通过") print("\n✓ 没有误删项目所有人工评估通过的项目都被LLM正确识别为通过")
# 6. 对比分析并输出结果 # 6. 对比分析并输出结果
comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估") comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("评估结果(以人工评估为标准)") print("评估结果(以人工评估为标准)")
print("=" * 60) print("=" * 60)
# 详细评估结果(核心指标优先) # 详细评估结果(核心指标优先)
print(f"\n--- {comparison['method']} ---") print(f"\n--- {comparison['method']} ---")
print(f" 总数: {comparison['total']}") print(f" 总数: {comparison['total']}")
print() print()
# print(" 【核心能力指标】") # print(" 【核心能力指标】")
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)") print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})") print(
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}") f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)") # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
print() print()
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)") print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})") print(
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}") f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
)
print(
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)") # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
print() print()
print(" 【其他指标】") print(" 【其他指标】")
@ -464,12 +475,18 @@ async def main(count: int | None = None):
print() print()
print(" 【不合适率分析】") print(" 【不合适率分析】")
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%") print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}") print(
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
)
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适") print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
print() print()
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%") print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})") print(
print(f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%") f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
)
print() print()
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%") # print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%") # print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
@ -480,21 +497,22 @@ async def main(count: int | None = None):
print(f" TN (正确识别为不合适): {comparison['true_negatives']}") print(f" TN (正确识别为不合适): {comparison['true_negatives']}")
print(f" FP (误判为合适): {comparison['false_positives']} ⚠️") print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️") print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
# 7. 保存结果到JSON文件 # 7. 保存结果到JSON文件
output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json") output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
try: try:
os.makedirs(os.path.dirname(output_file), exist_ok=True) os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f: with open(output_file, "w", encoding="utf-8") as f:
json.dump({ json.dump(
"manual_results": valid_manual_results, {"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
"llm_results": llm_results, f,
"comparison": comparison ensure_ascii=False,
}, f, ensure_ascii=False, indent=2) indent=2,
)
logger.info(f"\n评估结果已保存到: {output_file}") logger.info(f"\n评估结果已保存到: {output_file}")
except Exception as e: except Exception as e:
logger.warning(f"保存结果到文件失败: {e}") logger.warning(f"保存结果到文件失败: {e}")
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("评估完成") print("评估完成")
print("=" * 60) print("=" * 60)
@ -509,15 +527,9 @@ if __name__ == "__main__":
python evaluate_expressions_llm_v6.py # 使用全部数据 python evaluate_expressions_llm_v6.py # 使用全部数据
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据 python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据 python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
""" """,
) )
parser.add_argument( parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
"-n", "--count",
type=int,
default=None,
help="随机选取的数据条数(默认:使用全部数据)"
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(main(count=args.count)) asyncio.run(main(count=args.count))

View File

@ -18,9 +18,9 @@ from datetime import datetime
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Expression from src.common.database.database_model import Expression # noqa: E402
from src.common.database.database import db from src.common.database.database import db # noqa: E402
from src.common.logger import get_logger from src.common.logger import get_logger # noqa: E402
logger = get_logger("expression_evaluator_manual") logger = get_logger("expression_evaluator_manual")
@ -32,13 +32,13 @@ MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]: def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
""" """
加载已有的评估结果 加载已有的评估结果
Returns: Returns:
(已有结果列表, 已评估的项目(situation, style)元组集合) (已有结果列表, 已评估的项目(situation, style)元组集合)
""" """
if not os.path.exists(MANUAL_EVAL_FILE): if not os.path.exists(MANUAL_EVAL_FILE):
return [], set() return [], set()
try: try:
with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f: with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
@ -55,22 +55,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
def save_results(manual_results: List[Dict]): def save_results(manual_results: List[Dict]):
""" """
保存评估结果到文件 保存评估结果到文件
Args: Args:
manual_results: 评估结果列表 manual_results: 评估结果列表
""" """
try: try:
os.makedirs(TEMP_DIR, exist_ok=True) os.makedirs(TEMP_DIR, exist_ok=True)
data = { data = {
"last_updated": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(),
"total_count": len(manual_results), "total_count": len(manual_results),
"manual_results": manual_results "manual_results": manual_results,
} }
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f: with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}") logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)") print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
except Exception as e: except Exception as e:
@ -81,45 +81,43 @@ def save_results(manual_results: List[Dict]):
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]: def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
""" """
获取未评估的表达方式 获取未评估的表达方式
Args: Args:
evaluated_pairs: 已评估的项目(situation, style)元组集合 evaluated_pairs: 已评估的项目(situation, style)元组集合
batch_size: 每次获取的数量 batch_size: 每次获取的数量
Returns: Returns:
未评估的表达方式列表 未评估的表达方式列表
""" """
try: try:
# 查询所有表达方式 # 查询所有表达方式
all_expressions = list(Expression.select()) all_expressions = list(Expression.select())
if not all_expressions: if not all_expressions:
logger.warning("数据库中没有表达方式记录") logger.warning("数据库中没有表达方式记录")
return [] return []
# 过滤出未评估的项目:匹配 situation 和 style 均一致 # 过滤出未评估的项目:匹配 situation 和 style 均一致
unevaluated = [ unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
if not unevaluated: if not unevaluated:
logger.info("所有项目都已评估完成") logger.info("所有项目都已评估完成")
return [] return []
# 如果未评估数量少于请求数量,返回所有 # 如果未评估数量少于请求数量,返回所有
if len(unevaluated) <= batch_size: if len(unevaluated) <= batch_size:
logger.info(f"剩余 {len(unevaluated)} 条未评估项目,全部返回") logger.info(f"剩余 {len(unevaluated)} 条未评估项目,全部返回")
return unevaluated return unevaluated
# 随机选择指定数量 # 随机选择指定数量
selected = random.sample(unevaluated, batch_size) selected = random.sample(unevaluated, batch_size)
logger.info(f"{len(unevaluated)} 条未评估项目中随机选择了 {len(selected)}") logger.info(f"{len(unevaluated)} 条未评估项目中随机选择了 {len(selected)}")
return selected return selected
except Exception as e: except Exception as e:
logger.error(f"获取未评估表达方式失败: {e}") logger.error(f"获取未评估表达方式失败: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return [] return []
@ -127,12 +125,12 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict: def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
""" """
人工评估单个表达方式 人工评估单个表达方式
Args: Args:
expression: 表达方式对象 expression: 表达方式对象
index: 当前索引从1开始 index: 当前索引从1开始
total: 总数 total: 总数
Returns: Returns:
评估结果字典如果用户退出则返回 None 评估结果字典如果用户退出则返回 None
""" """
@ -146,38 +144,38 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
print(" 输入 'n''no''0' 表示不合适(不通过)") print(" 输入 'n''no''0' 表示不合适(不通过)")
print(" 输入 'q''quit' 退出评估") print(" 输入 'q''quit' 退出评估")
print(" 输入 's''skip' 跳过当前项目") print(" 输入 's''skip' 跳过当前项目")
while True: while True:
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower() user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
if user_input in ['q', 'quit']: if user_input in ["q", "quit"]:
print("退出评估") print("退出评估")
return None return None
if user_input in ['s', 'skip']: if user_input in ["s", "skip"]:
print("跳过当前项目") print("跳过当前项目")
return "skip" return "skip"
if user_input in ['y', 'yes', '1', '', '通过']: if user_input in ["y", "yes", "1", "", "通过"]:
suitable = True suitable = True
break break
elif user_input in ['n', 'no', '0', '', '不通过']: elif user_input in ["n", "no", "0", "", "不通过"]:
suitable = False suitable = False
break break
else: else:
print("输入无效,请重新输入 (y/n/q/s)") print("输入无效,请重新输入 (y/n/q/s)")
result = { result = {
"situation": expression.situation, "situation": expression.situation,
"style": expression.style, "style": expression.style,
"suitable": suitable, "suitable": suitable,
"reason": None, "reason": None,
"evaluator": "manual", "evaluator": "manual",
"evaluated_at": datetime.now().isoformat() "evaluated_at": datetime.now().isoformat(),
} }
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}") print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
return result return result
@ -186,7 +184,7 @@ def main():
logger.info("=" * 60) logger.info("=" * 60)
logger.info("开始表达方式人工评估") logger.info("开始表达方式人工评估")
logger.info("=" * 60) logger.info("=" * 60)
# 初始化数据库连接 # 初始化数据库连接
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
@ -194,41 +192,41 @@ def main():
except Exception as e: except Exception as e:
logger.error(f"数据库连接失败: {e}") logger.error(f"数据库连接失败: {e}")
return return
# 加载已有评估结果 # 加载已有评估结果
existing_results, evaluated_pairs = load_existing_results() existing_results, evaluated_pairs = load_existing_results()
manual_results = existing_results.copy() manual_results = existing_results.copy()
if evaluated_pairs: if evaluated_pairs:
print(f"\n已加载 {len(existing_results)} 条已有评估结果") print(f"\n已加载 {len(existing_results)} 条已有评估结果")
print(f"已评估项目数: {len(evaluated_pairs)}") print(f"已评估项目数: {len(evaluated_pairs)}")
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("开始人工评估") print("开始人工评估")
print("=" * 60) print("=" * 60)
print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目") print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
print("评估结果会自动保存到文件\n") print("评估结果会自动保存到文件\n")
batch_size = 10 batch_size = 10
batch_count = 0 batch_count = 0
while True: while True:
# 获取未评估的项目 # 获取未评估的项目
expressions = get_unevaluated_expressions(evaluated_pairs, batch_size) expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)
if not expressions: if not expressions:
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("所有项目都已评估完成!") print("所有项目都已评估完成!")
print("=" * 60) print("=" * 60)
break break
batch_count += 1 batch_count += 1
print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---") print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")
batch_results = [] batch_results = []
for i, expression in enumerate(expressions, 1): for i, expression in enumerate(expressions, 1):
manual_result = manual_evaluate_expression(expression, i, len(expressions)) manual_result = manual_evaluate_expression(expression, i, len(expressions))
if manual_result is None: if manual_result is None:
# 用户退出 # 用户退出
print("\n评估已中断") print("\n评估已中断")
@ -237,34 +235,34 @@ def main():
manual_results.extend(batch_results) manual_results.extend(batch_results)
save_results(manual_results) save_results(manual_results)
return return
if manual_result == "skip": if manual_result == "skip":
# 跳过当前项目 # 跳过当前项目
continue continue
batch_results.append(manual_result) batch_results.append(manual_result)
# 使用 (situation, style) 作为唯一标识 # 使用 (situation, style) 作为唯一标识
evaluated_pairs.add((manual_result["situation"], manual_result["style"])) evaluated_pairs.add((manual_result["situation"], manual_result["style"]))
# 将当前批次结果添加到总结果中 # 将当前批次结果添加到总结果中
manual_results.extend(batch_results) manual_results.extend(batch_results)
# 保存结果 # 保存结果
save_results(manual_results) save_results(manual_results)
print(f"\n当前批次完成,已评估总数: {len(manual_results)}") print(f"\n当前批次完成,已评估总数: {len(manual_results)}")
# 询问是否继续 # 询问是否继续
while True: while True:
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower() continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
if continue_input in ['y', 'yes', '1', '', '继续']: if continue_input in ["y", "yes", "1", "", "继续"]:
break break
elif continue_input in ['n', 'no', '0', '', '退出']: elif continue_input in ["n", "no", "0", "", "退出"]:
print("\n评估结束") print("\n评估结束")
return return
else: else:
print("输入无效,请重新输入 (y/n)") print("输入无效,请重新输入 (y/n)")
# 关闭数据库连接 # 关闭数据库连接
try: try:
db.close() db.close()
@ -275,4 +273,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -134,9 +134,7 @@ def handle_import_openie(
# 在非交互模式下,不再询问用户,而是直接报错终止 # 在非交互模式下,不再询问用户,而是直接报错终止
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
if non_interactive: if non_interactive:
logger.error( logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
)
sys.exit(1) sys.exit(1)
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
user_choice = input().strip().lower() user_choice = input().strip().lower()
@ -189,9 +187,7 @@ def handle_import_openie(
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
# 新增确认提示 # 新增确认提示
if non_interactive: if non_interactive:
logger.warning( logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
)
else: else:
print("=== 重要操作确认 ===") print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型") print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
def main(argv: Optional[list[str]] = None) -> None: def main(argv: Optional[list[str]] = None) -> None:
"""主函数 - 解析参数并运行异步主流程。""" """主函数 - 解析参数并运行异步主流程。"""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=( description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
"将其导入到 LPMM 的向量库与知识图中。"
)
) )
parser.add_argument( parser.add_argument(
"--non-interactive", "--non-interactive",

View File

@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
ensure_dirs() # 确保目录存在 ensure_dirs() # 确保目录存在
# 新增用户确认提示 # 新增用户确认提示
if non_interactive: if non_interactive:
logger.warning( logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
)
else: else:
print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。") print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

View File

@ -1,6 +1,5 @@
import os import os
import sys import sys
from typing import Set
# 保证可以导入 src.* # 保证可以导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -32,7 +31,6 @@ def main() -> None:
# KG 统计 # KG 统计
nodes = kg.graph.get_node_list() nodes = kg.graph.get_node_list()
edges = kg.graph.get_edge_list() edges = kg.graph.get_edge_list()
node_set: Set[str] = set(nodes)
para_nodes = [n for n in nodes if n.startswith("paragraph-")] para_nodes = [n for n in nodes if n.startswith("paragraph-")]
ent_nodes = [n for n in nodes if n.startswith("entity-")] ent_nodes = [n for n in nodes if n.startswith("entity-")]
@ -68,4 +66,3 @@ def main() -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -29,6 +29,7 @@ except ImportError as e:
logger = get_logger("lpmm_interactive_manager") logger = get_logger("lpmm_interactive_manager")
async def interactive_add(): async def interactive_add():
"""交互式导入知识""" """交互式导入知识"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -38,7 +39,7 @@ async def interactive_add():
print(" - 支持多段落,段落间请保留空行。") print(" - 支持多段落,段落间请保留空行。")
print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。") print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。")
print("-" * 40) print("-" * 40)
lines = [] lines = []
while True: while True:
try: try:
@ -48,7 +49,7 @@ async def interactive_add():
lines.append(line) lines.append(line)
except EOFError: except EOFError:
break break
text = "\n".join(lines).strip() text = "\n".join(lines).strip()
if not text: if not text:
print("\n[!] 内容为空,操作已取消。") print("\n[!] 内容为空,操作已取消。")
@ -58,7 +59,7 @@ async def interactive_add():
try: try:
# 使用 lpmm_ops.py 中的接口 # 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.add_content(text) result = await lpmm_ops.add_content(text)
if result["status"] == "success": if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}") print(f"\n[√] 成功:{result['message']}")
print(f" 实际新增段落数: {result.get('count', 0)}") print(f" 实际新增段落数: {result.get('count', 0)}")
@ -68,6 +69,7 @@ async def interactive_add():
print(f"\n[×] 发生异常: {e}") print(f"\n[×] 发生异常: {e}")
logger.error(f"add_content 异常: {e}", exc_info=True) logger.error(f"add_content 异常: {e}", exc_info=True)
async def interactive_delete(): async def interactive_delete():
"""交互式删除知识""" """交互式删除知识"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -77,10 +79,10 @@ async def interactive_delete():
print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)") print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)")
print(" 2. 完整文段匹配(删除完全匹配的段落)") print(" 2. 完整文段匹配(删除完全匹配的段落)")
print("-" * 40) print("-" * 40)
mode = input("请选择删除模式 (1/2): ").strip() mode = input("请选择删除模式 (1/2): ").strip()
exact_match = False exact_match = False
if mode == "2": if mode == "2":
exact_match = True exact_match = True
print("\n[完整文段匹配模式]") print("\n[完整文段匹配模式]")
@ -102,14 +104,18 @@ async def interactive_delete():
print("\n[!] 无效选择,默认使用关键词模糊匹配模式。") print("\n[!] 无效选择,默认使用关键词模糊匹配模式。")
print("\n[关键词模糊匹配模式]") print("\n[关键词模糊匹配模式]")
keyword = input("请输入匹配关键词: ").strip() keyword = input("请输入匹配关键词: ").strip()
if not keyword: if not keyword:
print("\n[!] 输入为空,操作已取消。") print("\n[!] 输入为空,操作已取消。")
return return
print("-" * 40) print("-" * 40)
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower() confirm = (
if confirm != 'y': input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
.strip()
.lower()
)
if confirm != "y":
print("\n[!] 已取消删除操作。") print("\n[!] 已取消删除操作。")
return return
@ -117,7 +123,7 @@ async def interactive_delete():
try: try:
# 使用 lpmm_ops.py 中的接口 # 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.delete(keyword, exact_match=exact_match) result = await lpmm_ops.delete(keyword, exact_match=exact_match)
if result["status"] == "success": if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}") print(f"\n[√] 成功:{result['message']}")
print(f" 删除条数: {result.get('deleted_count', 0)}") print(f" 删除条数: {result.get('deleted_count', 0)}")
@ -129,6 +135,7 @@ async def interactive_delete():
print(f"\n[×] 发生异常: {e}") print(f"\n[×] 发生异常: {e}")
logger.error(f"delete 异常: {e}", exc_info=True) logger.error(f"delete 异常: {e}", exc_info=True)
async def interactive_clear(): async def interactive_clear():
"""交互式清空知识库""" """交互式清空知识库"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -141,40 +148,45 @@ async def interactive_clear():
print(" - 整个知识图谱") print(" - 整个知识图谱")
print(" - 此操作不可恢复!") print(" - 此操作不可恢复!")
print("-" * 40) print("-" * 40)
# 双重确认 # 双重确认
confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip() confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip()
if confirm1 != "YES": if confirm1 != "YES":
print("\n[!] 已取消清空操作。") print("\n[!] 已取消清空操作。")
return return
print("\n" + "=" * 40) print("\n" + "=" * 40)
confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip() confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip()
if confirm2 != "CLEAR": if confirm2 != "CLEAR":
print("\n[!] 已取消清空操作。") print("\n[!] 已取消清空操作。")
return return
print("\n[进度] 正在清空知识库...") print("\n[进度] 正在清空知识库...")
try: try:
# 使用 lpmm_ops.py 中的接口 # 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.clear_all() result = await lpmm_ops.clear_all()
if result["status"] == "success": if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}") print(f"\n[√] 成功:{result['message']}")
stats = result.get("stats", {}) stats = result.get("stats", {})
before = stats.get("before", {}) before = stats.get("before", {})
after = stats.get("after", {}) after = stats.get("after", {})
print("\n[统计信息]") print("\n[统计信息]")
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, " print(
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}") f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, " f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}") )
print(
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
)
else: else:
print(f"\n[×] 失败:{result['message']}") print(f"\n[×] 失败:{result['message']}")
except Exception as e: except Exception as e:
print(f"\n[×] 发生异常: {e}") print(f"\n[×] 发生异常: {e}")
logger.error(f"clear_all 异常: {e}", exc_info=True) logger.error(f"clear_all 异常: {e}", exc_info=True)
async def interactive_search(): async def interactive_search():
"""交互式查询知识""" """交互式查询知识"""
print("\n" + "=" * 40) print("\n" + "=" * 40)
@ -182,25 +194,25 @@ async def interactive_search():
print("=" * 40) print("=" * 40)
print("说明:输入查询问题或关键词,系统会返回相关的知识段落。") print("说明:输入查询问题或关键词,系统会返回相关的知识段落。")
print("-" * 40) print("-" * 40)
# 确保 LPMM 已初始化 # 确保 LPMM 已初始化
if not global_config.lpmm_knowledge.enable: if not global_config.lpmm_knowledge.enable:
print("\n[!] 警告LPMM 知识库在配置中未启用。") print("\n[!] 警告LPMM 知识库在配置中未启用。")
return return
try: try:
lpmm_start_up() lpmm_start_up()
except Exception as e: except Exception as e:
print(f"\n[!] LPMM 初始化失败: {e}") print(f"\n[!] LPMM 初始化失败: {e}")
logger.error(f"LPMM 初始化失败: {e}", exc_info=True) logger.error(f"LPMM 初始化失败: {e}", exc_info=True)
return return
query = input("请输入查询问题或关键词: ").strip() query = input("请输入查询问题或关键词: ").strip()
if not query: if not query:
print("\n[!] 查询内容为空,操作已取消。") print("\n[!] 查询内容为空,操作已取消。")
return return
# 询问返回条数 # 询问返回条数
print("-" * 40) print("-" * 40)
limit_str = input("希望返回的相关知识条数默认3直接回车使用默认值: ").strip() limit_str = input("希望返回的相关知识条数默认3直接回车使用默认值: ").strip()
@ -210,11 +222,11 @@ async def interactive_search():
except ValueError: except ValueError:
limit = 3 limit = 3
print("[!] 输入无效,使用默认值 3。") print("[!] 输入无效,使用默认值 3。")
print("\n[进度] 正在查询知识库...") print("\n[进度] 正在查询知识库...")
try: try:
result = await query_lpmm_knowledge(query, limit=limit) result = await query_lpmm_knowledge(query, limit=limit)
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("[查询结果]") print("[查询结果]")
print("=" * 60) print("=" * 60)
@ -224,6 +236,7 @@ async def interactive_search():
print(f"\n[×] 查询失败: {e}") print(f"\n[×] 查询失败: {e}")
logger.error(f"查询异常: {e}", exc_info=True) logger.error(f"查询异常: {e}", exc_info=True)
async def main(): async def main():
"""主循环""" """主循环"""
while True: while True:
@ -236,9 +249,9 @@ async def main():
print("║ 4. 清空知识库 (Clear All) ⚠️ ║") print("║ 4. 清空知识库 (Clear All) ⚠️ ║")
print("║ 0. 退出 (Exit) ║") print("║ 0. 退出 (Exit) ║")
print("" + "" * 38 + "") print("" + "" * 38 + "")
choice = input("请选择操作编号: ").strip() choice = input("请选择操作编号: ").strip()
if choice == "1": if choice == "1":
await interactive_add() await interactive_add()
elif choice == "2": elif choice == "2":
@ -253,6 +266,7 @@ async def main():
else: else:
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。") print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
if __name__ == "__main__": if __name__ == "__main__":
try: try:
# 运行主循环 # 运行主循环
@ -262,4 +276,3 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"\n[!] 程序运行出错: {e}") print(f"\n[!] 程序运行出错: {e}")
logger.error(f"Main loop 异常: {e}", exc_info=True) logger.error(f"Main loop 异常: {e}", exc_info=True)

View File

@ -21,18 +21,18 @@ PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
if PROJECT_ROOT not in sys.path: if PROJECT_ROOT not in sys.path:
sys.path.append(PROJECT_ROOT) sys.path.append(PROJECT_ROOT)
from src.common.logger import get_logger # type: ignore from src.common.logger import get_logger # type: ignore # noqa: E402
from src.config.config import global_config, model_config # type: ignore from src.config.config import global_config, model_config # type: ignore # noqa: E402
# 引入各功能脚本的入口函数 # 引入各功能脚本的入口函数
from import_openie import main as import_openie_main # type: ignore from import_openie import main as import_openie_main # type: ignore # noqa: E402
from info_extraction import main as info_extraction_main # type: ignore from info_extraction import main as info_extraction_main # type: ignore # noqa: E402
from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore # noqa: E402
from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore # noqa: E402
from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore # noqa: E402
from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore # noqa: E402
from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore # noqa: E402
from raw_data_preprocessor import load_raw_data # type: ignore from raw_data_preprocessor import load_raw_data # type: ignore # noqa: E402
logger = get_logger("lpmm_manager") logger = get_logger("lpmm_manager")
@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data" raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
txt_files = list(raw_dir.glob("*.txt")) txt_files = list(raw_dir.glob("*.txt"))
if not txt_files: if not txt_files:
msg = ( msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件info_extraction 可能立即退出或无数据可处理。"
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
"info_extraction 可能立即退出或无数据可处理。"
)
print(msg) print(msg)
if non_interactive: if non_interactive:
logger.error( logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
)
return False return False
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower() cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
return cont == "y" return cont == "y"
@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
openie_dir = Path(PROJECT_ROOT) / "data" / "openie" openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
json_files = list(openie_dir.glob("*.json")) json_files = list(openie_dir.glob("*.json"))
if not json_files: if not json_files:
msg = ( msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件import_openie 可能会因为找不到批次而失败。"
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
"import_openie 可能会因为找不到批次而失败。"
)
print(msg) print(msg)
if non_interactive: if non_interactive:
logger.error( logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
)
return False return False
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower() cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
return cont == "y" return cont == "y"
@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
"""在部分操作前提醒 lpmm_knowledge.enable 状态。""" """在部分操作前提醒 lpmm_knowledge.enable 状态。"""
try: try:
if not getattr(global_config.lpmm_knowledge, "enable", False): if not getattr(global_config.lpmm_knowledge, "enable", False):
print( print("[WARN] 当前配置 lpmm_knowledge.enable = false刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
"[WARN] 当前配置 lpmm_knowledge.enable = false"
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
)
except Exception: except Exception:
# 配置异常时不阻断主流程,仅忽略提示 # 配置异常时不阻断主流程,仅忽略提示
pass pass
@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
if action == "prepare_raw": if action == "prepare_raw":
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...") logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
sha_list, raw_data = load_raw_data() sha_list, raw_data = load_raw_data()
print( print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
elif action == "info_extract": elif action == "info_extract":
if not _check_before_info_extract("--non-interactive" in extra_args): if not _check_before_info_extract("--non-interactive" in extra_args):
print("已根据用户选择,取消执行信息提取。") print("已根据用户选择,取消执行信息提取。")
@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新 # 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
logger.info("开始 full_import预处理原始语料 -> 信息抽取 -> 导入 -> 刷新") logger.info("开始 full_import预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
sha_list, raw_data = load_raw_data() sha_list, raw_data = load_raw_data()
print( print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
non_interactive = "--non-interactive" in extra_args non_interactive = "--non-interactive" in extra_args
if not _check_before_info_extract(non_interactive): if not _check_before_info_extract(non_interactive):
print("已根据用户选择,取消 full_import信息提取阶段被取消") print("已根据用户选择,取消 full_import信息提取阶段被取消")
@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
) )
# 快速选项:按推荐方式清理所有相关实体/关系 # 快速选项:按推荐方式清理所有相关实体/关系
quick_all = input( quick_all = (
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): " input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
).strip().lower() )
if quick_all in ("", "y", "yes"): if quick_all in ("", "y", "yes"):
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"]) args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
else: else:
@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
def _interactive_build_batch_inspect_args() -> List[str]: def _interactive_build_batch_inspect_args() -> List[str]:
"""为 inspect_lpmm_batch 构造 --openie-file 参数。""" """为 inspect_lpmm_batch 构造 --openie-file 参数。"""
path = _interactive_choose_openie_file( path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
)
if not path: if not path:
return [] return []
return ["--openie-file", path] return ["--openie-file", path]
@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
def _interactive_build_test_args() -> List[str]: def _interactive_build_test_args() -> List[str]:
"""为 test_lpmm_retrieval 构造自定义测试用例参数。""" """为 test_lpmm_retrieval 构造自定义测试用例参数。"""
print( print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
"\n[TEST] 你可以:\n"
"- 直接回车使用内置的默认测试用例;\n"
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
)
query = input("请输入自定义测试问题(回车则使用默认用例):").strip() query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
if not query: if not query:
return [] return []
@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}") print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}") print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
new_dim = input( new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
).strip()
if new_dim and not new_dim.isdigit(): if new_dim and not new_dim.isdigit():
print("输入的维度不是纯数字,已取消操作。") print("输入的维度不是纯数字,已取消操作。")
return return
@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -28,53 +28,55 @@ from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval") logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题 # 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval(): def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入""" """使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
try: try:
# 先导入 prompt_builder检查 prompt 是否已经初始化 # 先导入 prompt_builder检查 prompt 是否已经初始化
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
# 检查 memory_retrieval 相关的 prompt 是否已经注册 # 检查 memory_retrieval 相关的 prompt 是否已经注册
# 如果已经注册,说明模块可能已经通过其他路径初始化过了 # 如果已经注册,说明模块可能已经通过其他路径初始化过了
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
module_name = "src.memory_system.memory_retrieval" module_name = "src.memory_system.memory_retrieval"
# 如果 prompt 已经初始化,尝试直接使用已加载的模块 # 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules: if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name] existing_module = sys.modules[module_name]
if hasattr(existing_module, 'init_memory_retrieval_prompt'): if hasattr(existing_module, "init_memory_retrieval_prompt"):
return ( return (
existing_module.init_memory_retrieval_prompt, existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question, existing_module._react_agent_solve_question,
existing_module._process_single_question, existing_module._process_single_question,
) )
# 如果模块已经在 sys.modules 中但部分初始化,先移除它 # 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules: if module_name in sys.modules:
existing_module = sys.modules[module_name] existing_module = sys.modules[module_name]
if not hasattr(existing_module, 'init_memory_retrieval_prompt'): if not hasattr(existing_module, "init_memory_retrieval_prompt"):
# 模块部分初始化,移除它 # 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入") logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name] del sys.modules[module_name]
# 清理可能相关的部分初始化模块 # 清理可能相关的部分初始化模块
keys_to_remove = [] keys_to_remove = []
for key in sys.modules.keys(): for key in sys.modules.keys():
if key.startswith('src.memory_system.') and key != 'src.memory_system': if key.startswith("src.memory_system.") and key != "src.memory_system":
keys_to_remove.append(key) keys_to_remove.append(key)
for key in keys_to_remove: for key in keys_to_remove:
try: try:
del sys.modules[key] del sys.modules[key]
except KeyError: except KeyError:
pass pass
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载 # 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们 # 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
try: try:
# 先导入可能触发循环导入的模块,让它们完成初始化 # 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config import src.config.config
import src.chat.utils.prompt_builder import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化 # 如果它们已经导入,就确保它们完全初始化
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
@ -89,11 +91,11 @@ def _import_memory_retrieval():
pass # 如果导入失败,继续 pass # 如果导入失败,继续
except Exception as e: except Exception as e:
logger.warning(f"预加载依赖模块时出现警告: {e}") logger.warning(f"预加载依赖模块时出现警告: {e}")
# 现在尝试导入 memory_retrieval # 现在尝试导入 memory_retrieval
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval # 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
memory_retrieval_module = importlib.import_module(module_name) memory_retrieval_module = importlib.import_module(module_name)
return ( return (
memory_retrieval_module.init_memory_retrieval_prompt, memory_retrieval_module.init_memory_retrieval_prompt,
memory_retrieval_module._react_agent_solve_question, memory_retrieval_module._react_agent_solve_question,
@ -126,16 +128,16 @@ def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStrea
def get_token_usage_since(start_time: float) -> Dict[str, Any]: def get_token_usage_since(start_time: float) -> Dict[str, Any]:
"""获取从指定时间开始的token使用情况 """获取从指定时间开始的token使用情况
Args: Args:
start_time: 开始时间戳 start_time: 开始时间戳
Returns: Returns:
包含token使用统计的字典 包含token使用统计的字典
""" """
try: try:
start_datetime = datetime.fromtimestamp(start_time) start_datetime = datetime.fromtimestamp(start_time)
# 查询从开始时间到现在的所有memory相关的token使用记录 # 查询从开始时间到现在的所有memory相关的token使用记录
records = ( records = (
LLMUsage.select() LLMUsage.select()
@ -150,21 +152,21 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
) )
.order_by(LLMUsage.timestamp.asc()) .order_by(LLMUsage.timestamp.asc())
) )
total_prompt_tokens = 0 total_prompt_tokens = 0
total_completion_tokens = 0 total_completion_tokens = 0
total_tokens = 0 total_tokens = 0
total_cost = 0.0 total_cost = 0.0
request_count = 0 request_count = 0
model_usage = {} # 按模型统计 model_usage = {} # 按模型统计
for record in records: for record in records:
total_prompt_tokens += record.prompt_tokens or 0 total_prompt_tokens += record.prompt_tokens or 0
total_completion_tokens += record.completion_tokens or 0 total_completion_tokens += record.completion_tokens or 0
total_tokens += record.total_tokens or 0 total_tokens += record.total_tokens or 0
total_cost += record.cost or 0.0 total_cost += record.cost or 0.0
request_count += 1 request_count += 1
# 按模型统计 # 按模型统计
model_name = record.model_name or "unknown" model_name = record.model_name or "unknown"
if model_name not in model_usage: if model_name not in model_usage:
@ -180,7 +182,7 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
model_usage[model_name]["total_tokens"] += record.total_tokens or 0 model_usage[model_name]["total_tokens"] += record.total_tokens or 0
model_usage[model_name]["cost"] += record.cost or 0.0 model_usage[model_name]["cost"] += record.cost or 0.0
model_usage[model_name]["request_count"] += 1 model_usage[model_name]["request_count"] += 1
return { return {
"total_prompt_tokens": total_prompt_tokens, "total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens, "total_completion_tokens": total_completion_tokens,
@ -205,25 +207,25 @@ def format_thinking_steps(thinking_steps: list) -> str:
"""格式化思考步骤为可读字符串""" """格式化思考步骤为可读字符串"""
if not thinking_steps: if not thinking_steps:
return "无思考步骤" return "无思考步骤"
lines = [] lines = []
for step in thinking_steps: for step in thinking_steps:
iteration = step.get("iteration", "?") iteration = step.get("iteration", "?")
thought = step.get("thought", "") thought = step.get("thought", "")
actions = step.get("actions", []) actions = step.get("actions", [])
observations = step.get("observations", []) observations = step.get("observations", [])
lines.append(f"\n--- 迭代 {iteration} ---") lines.append(f"\n--- 迭代 {iteration} ---")
if thought: if thought:
lines.append(f"思考: {thought[:200]}...") lines.append(f"思考: {thought[:200]}...")
if actions: if actions:
lines.append("行动:") lines.append("行动:")
for action in actions: for action in actions:
action_type = action.get("action_type", "unknown") action_type = action.get("action_type", "unknown")
action_params = action.get("action_params", {}) action_params = action.get("action_params", {})
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}") lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
if observations: if observations:
lines.append("观察:") lines.append("观察:")
for obs in observations: for obs in observations:
@ -231,7 +233,7 @@ def format_thinking_steps(thinking_steps: list) -> str:
if len(str(obs)) > 200: if len(str(obs)) > 200:
obs_str += "..." obs_str += "..."
lines.append(f" - {obs_str}") lines.append(f" - {obs_str}")
return "\n".join(lines) return "\n".join(lines)
@ -242,31 +244,32 @@ async def test_memory_retrieval(
max_iterations: Optional[int] = None, max_iterations: Optional[int] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""测试记忆检索功能 """测试记忆检索功能
Args: Args:
question: 要查询的问题 question: 要查询的问题
chat_id: 聊天ID chat_id: 聊天ID
context: 上下文信息 context: 上下文信息
max_iterations: 最大迭代次数 max_iterations: 最大迭代次数
Returns: Returns:
包含测试结果的字典 包含测试结果的字典
""" """
print("\n" + "=" * 80) print("\n" + "=" * 80)
print(f"[测试] 记忆检索测试") print("[测试] 记忆检索测试")
print(f"[问题] {question}") print(f"[问题] {question}")
print("=" * 80) print("=" * 80)
# 记录开始时间 # 记录开始时间
start_time = time.time() start_time = time.time()
# 延迟导入并初始化记忆检索prompt这会自动加载 global_config # 延迟导入并初始化记忆检索prompt这会自动加载 global_config
# 注意:必须在函数内部调用,避免在模块级别触发循环导入 # 注意:必须在函数内部调用,避免在模块级别触发循环导入
try: try:
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval() init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
# 检查 prompt 是否已经初始化,避免重复初始化 # 检查 prompt 是否已经初始化,避免重复初始化
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts: if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt() init_memory_retrieval_prompt()
else: else:
@ -274,24 +277,24 @@ async def test_memory_retrieval(
except Exception as e: except Exception as e:
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True) logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
raise raise
# 获取 global_config此时应该已经加载 # 获取 global_config此时应该已经加载
from src.config.config import global_config from src.config.config import global_config
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息 # 直接调用 _react_agent_solve_question 来获取详细的迭代信息
if max_iterations is None: if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations max_iterations = global_config.memory.max_agent_iterations
timeout = global_config.memory.agent_timeout_seconds timeout = global_config.memory.agent_timeout_seconds
print(f"\n[配置]") print("\n[配置]")
print(f" 最大迭代次数: {max_iterations}") print(f" 最大迭代次数: {max_iterations}")
print(f" 超时时间: {timeout}") print(f" 超时时间: {timeout}")
print(f" 聊天ID: {chat_id}") print(f" 聊天ID: {chat_id}")
# 执行检索 # 执行检索
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}") print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question, question=question,
chat_id=chat_id, chat_id=chat_id,
@ -299,14 +302,14 @@ async def test_memory_retrieval(
timeout=timeout, timeout=timeout,
initial_info="", initial_info="",
) )
# 记录结束时间 # 记录结束时间
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
# 获取token使用情况 # 获取token使用情况
token_usage = get_token_usage_since(start_time) token_usage = get_token_usage_since(start_time)
# 构建结果 # 构建结果
result = { result = {
"question": question, "question": question,
@ -318,41 +321,41 @@ async def test_memory_retrieval(
"iteration_count": len(thinking_steps), "iteration_count": len(thinking_steps),
"token_usage": token_usage, "token_usage": token_usage,
} }
# 输出结果 # 输出结果
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}") print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
print(f"\n[结果]") print("\n[结果]")
print(f" 是否找到答案: {'' if found_answer else ''}") print(f" 是否找到答案: {'' if found_answer else ''}")
if found_answer and answer: if found_answer and answer:
print(f" 答案: {answer}") print(f" 答案: {answer}")
else: else:
print(f" 答案: (未找到答案)") print(" 答案: (未找到答案)")
print(f" 是否超时: {'' if is_timeout else ''}") print(f" 是否超时: {'' if is_timeout else ''}")
print(f" 迭代次数: {len(thinking_steps)}") print(f" 迭代次数: {len(thinking_steps)}")
print(f" 总耗时: {elapsed_time:.2f}") print(f" 总耗时: {elapsed_time:.2f}")
print(f"\n[Token使用情况]") print("\n[Token使用情况]")
print(f" 总请求数: {token_usage['request_count']}") print(f" 总请求数: {token_usage['request_count']}")
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}") print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}") print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
print(f" 总Tokens: {token_usage['total_tokens']:,}") print(f" 总Tokens: {token_usage['total_tokens']:,}")
print(f" 总成本: ${token_usage['total_cost']:.6f}") print(f" 总成本: ${token_usage['total_cost']:.6f}")
if token_usage['model_usage']: if token_usage["model_usage"]:
print(f"\n[按模型统计]") print("\n[按模型统计]")
for model_name, usage in token_usage['model_usage'].items(): for model_name, usage in token_usage["model_usage"].items():
print(f" {model_name}:") print(f" {model_name}:")
print(f" 请求数: {usage['request_count']}") print(f" 请求数: {usage['request_count']}")
print(f" Prompt Tokens: {usage['prompt_tokens']:,}") print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
print(f" Completion Tokens: {usage['completion_tokens']:,}") print(f" Completion Tokens: {usage['completion_tokens']:,}")
print(f" 总Tokens: {usage['total_tokens']:,}") print(f" 总Tokens: {usage['total_tokens']:,}")
print(f" 成本: ${usage['cost']:.6f}") print(f" 成本: ${usage['cost']:.6f}")
print(f"\n[迭代详情]") print("\n[迭代详情]")
print(format_thinking_steps(thinking_steps)) print(format_thinking_steps(thinking_steps))
print("\n" + "=" * 80) print("\n" + "=" * 80)
return result return result
@ -375,12 +378,12 @@ def main() -> None:
"-o", "-o",
help="将结果保存到JSON文件可选", help="将结果保存到JSON文件可选",
) )
args = parser.parse_args() args = parser.parse_args()
# 初始化日志(使用较低的详细程度,避免输出过多日志) # 初始化日志(使用较低的详细程度,避免输出过多日志)
initialize_logging(verbose=False) initialize_logging(verbose=False)
# 交互式输入问题 # 交互式输入问题
print("\n" + "=" * 80) print("\n" + "=" * 80)
print("记忆检索测试工具") print("记忆检索测试工具")
@ -389,7 +392,7 @@ def main() -> None:
if not question: if not question:
print("错误: 问题不能为空") print("错误: 问题不能为空")
return return
# 交互式输入最大迭代次数 # 交互式输入最大迭代次数
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip() max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
max_iterations = None max_iterations = None
@ -402,7 +405,7 @@ def main() -> None:
except ValueError: except ValueError:
print("警告: 无效的迭代次数,将使用配置默认值") print("警告: 无效的迭代次数,将使用配置默认值")
max_iterations = None max_iterations = None
# 连接数据库 # 连接数据库
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
@ -410,7 +413,7 @@ def main() -> None:
logger.error(f"数据库连接失败: {e}") logger.error(f"数据库连接失败: {e}")
print(f"错误: 数据库连接失败: {e}") print(f"错误: 数据库连接失败: {e}")
return return
# 运行测试 # 运行测试
try: try:
result = asyncio.run( result = asyncio.run(
@ -421,7 +424,7 @@ def main() -> None:
max_iterations=max_iterations, max_iterations=max_iterations,
) )
) )
# 如果指定了输出文件,保存结果 # 如果指定了输出文件,保存结果
if args.output: if args.output:
# 将thinking_steps转换为可序列化的格式 # 将thinking_steps转换为可序列化的格式
@ -429,7 +432,7 @@ def main() -> None:
with open(args.output, "w", encoding="utf-8") as f: with open(args.output, "w", encoding="utf-8") as f:
json.dump(output_result, f, ensure_ascii=False, indent=2) json.dump(output_result, f, ensure_ascii=False, indent=2)
print(f"\n[结果已保存] {args.output}") print(f"\n[结果已保存] {args.output}")
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n[中断] 用户中断测试") print("\n\n[中断] 用户中断测试")
except Exception as e: except Exception as e:
@ -444,4 +447,3 @@ def main() -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -455,6 +455,7 @@ class ExpressionSelector:
expr_obj.save() expr_obj.save()
logger.debug("表达方式激活: 更新last_active_time in db") logger.debug("表达方式激活: 更新last_active_time in db")
try: try:
expression_selector = ExpressionSelector() expression_selector = ExpressionSelector()
except Exception as e: except Exception as e:

View File

@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
logger = get_logger("jargon") logger = get_logger("jargon")
class JargonExplainer: class JargonExplainer:
"""黑话解释器,用于在回复前识别和解释上下文中的黑话""" """黑话解释器,用于在回复前识别和解释上下文中的黑话"""

View File

@ -60,31 +60,31 @@ def calculate_style_similarity(style1: str, style2: str) -> float:
""" """
计算两个 style 的相似度返回0-1之间的值 计算两个 style 的相似度返回0-1之间的值
在计算前会移除"使用""句式"这两个词参考 expression_similarity_analysis.py 在计算前会移除"使用""句式"这两个词参考 expression_similarity_analysis.py
Args: Args:
style1: 第一个 style style1: 第一个 style
style2: 第二个 style style2: 第二个 style
Returns: Returns:
float: 相似度值范围0-1 float: 相似度值范围0-1
""" """
if not style1 or not style2: if not style1 or not style2:
return 0.0 return 0.0
# 移除"使用"和"句式"这两个词 # 移除"使用"和"句式"这两个词
def remove_ignored_words(text: str) -> str: def remove_ignored_words(text: str) -> str:
"""移除需要忽略的词""" """移除需要忽略的词"""
text = text.replace("使用", "") text = text.replace("使用", "")
text = text.replace("句式", "") text = text.replace("句式", "")
return text.strip() return text.strip()
cleaned_style1 = remove_ignored_words(style1) cleaned_style1 = remove_ignored_words(style1)
cleaned_style2 = remove_ignored_words(style2) cleaned_style2 = remove_ignored_words(style2)
# 如果清理后文本为空返回0 # 如果清理后文本为空返回0
if not cleaned_style1 or not cleaned_style2: if not cleaned_style1 or not cleaned_style2:
return 0.0 return 0.0
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio() return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
@ -495,4 +495,4 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
if content and source_id: if content and source_id:
jargon_entries.append((content, source_id)) jargon_entries.append((content, source_id))
return expressions, jargon_entries return expressions, jargon_entries

View File

@ -2,7 +2,6 @@ import time
import asyncio import asyncio
from typing import List, Any from typing import List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression from src.chat.utils.common_utils import TempMethodsExpression
@ -119,9 +118,7 @@ class MessageRecorder:
# 触发 expression_learner 和 jargon_miner 的处理 # 触发 expression_learner 和 jargon_miner 的处理
if self.enable_expression_learning: if self.enable_expression_learning:
asyncio.create_task( asyncio.create_task(self._trigger_expression_learning(messages))
self._trigger_expression_learning(messages)
)
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
@ -130,9 +127,7 @@ class MessageRecorder:
traceback.print_exc() traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试 # 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning( async def _trigger_expression_learning(self, messages: List[Any]) -> None:
self, messages: List[Any]
) -> None:
""" """
触发 expression 学习使用指定的消息列表 触发 expression 学习使用指定的消息列表

View File

@ -1,5 +1,5 @@
import time import time
from typing import Tuple, Optional, Dict, Any # 增加了 Optional from typing import Tuple, Optional # 增加了 Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
@ -120,7 +120,7 @@ class ActionPlanner:
def _get_personality_prompt(self) -> str: def _get_personality_prompt(self) -> str:
"""获取个性提示信息""" """获取个性提示信息"""
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if ( if (
global_config.personality.states global_config.personality.states
@ -128,7 +128,7 @@ class ActionPlanner:
and random.random() < global_config.personality.state_probability and random.random() < global_config.personality.state_probability
): ):
prompt_personality = random.choice(global_config.personality.states) prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};" return f"你的名字是{bot_name},你{prompt_personality};"
@ -170,13 +170,10 @@ class ActionPlanner:
) )
break break
else: else:
logger.debug( logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
)
except Exception as e: except Exception as e:
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}") logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
# --- 获取超时提示信息 --- # --- 获取超时提示信息 ---
# (这部分逻辑不变) # (这部分逻辑不变)
timeout_context = "" timeout_context = ""

View File

@ -112,10 +112,10 @@ class Conversation:
"user_nickname": msg.user_info.user_nickname if msg.user_info else "", "user_nickname": msg.user_info.user_nickname if msg.user_info else "",
"user_cardname": msg.user_info.user_cardname if msg.user_info else None, "user_cardname": msg.user_info.user_cardname if msg.user_info else None,
"platform": msg.user_info.platform if msg.user_info else "", "platform": msg.user_info.platform if msg.user_info else "",
} },
} }
initial_messages_dict.append(msg_dict) initial_messages_dict.append(msg_dict)
# 将加载的消息填充到 ObservationInfo 的 chat_history # 将加载的消息填充到 ObservationInfo 的 chat_history
self.observation_info.chat_history = initial_messages_dict self.observation_info.chat_history = initial_messages_dict
self.observation_info.chat_history_str = chat_talking_prompt + "\n" self.observation_info.chat_history_str = chat_talking_prompt + "\n"

View File

@ -66,9 +66,9 @@ class DirectMessageSender:
# 发送消息(直接调用底层 API # 发送消息(直接调用底层 API
from src.chat.message_receive.uni_message_sender import _send_message from src.chat.message_receive.uni_message_sender import _send_message
sent = await _send_message(message, show_log=True) sent = await _send_message(message, show_log=True)
if sent: if sent:
# 存储消息 # 存储消息
await self.storage.store_message(message, chat_stream) await self.storage.store_message(message, chat_stream)

View File

@ -5,7 +5,7 @@ from src.common.logger import get_logger
from .chat_observer import ChatObserver from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType, Notification from .chat_states import NotificationHandler, NotificationType, Notification
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo from src.common.data_models.database_data_model import DatabaseMessages
import traceback # 导入 traceback 用于调试 import traceback # 导入 traceback 用于调试
logger = get_logger("observation_info") logger = get_logger("observation_info")
@ -13,15 +13,15 @@ logger = get_logger("observation_info")
def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages: def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages:
"""Convert PFC dict format to DatabaseMessages object """Convert PFC dict format to DatabaseMessages object
Args: Args:
msg_dict: Message in PFC dict format with nested user_info msg_dict: Message in PFC dict format with nested user_info
Returns: Returns:
DatabaseMessages object compatible with build_readable_messages() DatabaseMessages object compatible with build_readable_messages()
""" """
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {}) user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
return DatabaseMessages( return DatabaseMessages(
message_id=msg_dict.get("message_id", ""), message_id=msg_dict.get("message_id", ""),
time=msg_dict.get("time", 0.0), time=msg_dict.get("time", 0.0),

View File

@ -42,9 +42,7 @@ class GoalAnalyzer:
"""对话目标分析器""" """对话目标分析器"""
def __init__(self, stream_id: str, private_name: str): def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest( self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
)
self.personality_info = self._get_personality_prompt() self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname self.name = global_config.bot.nickname
@ -60,7 +58,7 @@ class GoalAnalyzer:
def _get_personality_prompt(self) -> str: def _get_personality_prompt(self) -> str:
"""获取个性提示信息""" """获取个性提示信息"""
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if ( if (
global_config.personality.states global_config.personality.states
@ -68,7 +66,7 @@ class GoalAnalyzer:
and random.random() < global_config.personality.state_probability and random.random() < global_config.personality.state_probability
): ):
prompt_personality = random.choice(global_config.personality.states) prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};" return f"你的名字是{bot_name},你{prompt_personality};"

View File

@ -1,13 +1,11 @@
from typing import List, Tuple, Dict, Any from typing import List, Tuple, Dict, Any
from src.common.logger import get_logger from src.common.logger import get_logger
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
# from src.plugins.memory_system.Hippocampus import HippocampusManager # from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import model_config
from src.chat.message_receive.message import Message
from src.chat.knowledge import qa_manager from src.chat.knowledge import qa_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
logger = get_logger("knowledge_fetcher") logger = get_logger("knowledge_fetcher")
@ -16,9 +14,7 @@ class KnowledgeFetcher:
"""知识调取器""" """知识调取器"""
def __init__(self, private_name: str): def __init__(self, private_name: str):
self.llm = LLMRequest( self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
model_set=model_config.model_task_config.utils
)
self.private_name = private_name self.private_name = private_name
def _lpmm_get_knowledge(self, query: str) -> str: def _lpmm_get_knowledge(self, query: str) -> str:
@ -50,13 +46,7 @@ class KnowledgeFetcher:
Returns: Returns:
Tuple[str, str]: (获取的知识, 知识来源) Tuple[str, str]: (获取的知识, 知识来源)
""" """
db_messages = [dict_to_database_message(m) for m in chat_history] _ = chat_history
chat_history_text = build_readable_messages(
db_messages,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
)
# NOTE: Hippocampus memory system was redesigned in v0.12.2 # NOTE: Hippocampus memory system was redesigned in v0.12.2
# The old get_memory_from_text API no longer exists # The old get_memory_from_text API no longer exists
@ -64,7 +54,7 @@ class KnowledgeFetcher:
# TODO: Integrate with new memory system if needed # TODO: Integrate with new memory system if needed
knowledge_text = "" knowledge_text = ""
sources_text = "无记忆匹配" # 默认值 sources_text = "无记忆匹配" # 默认值
# # 从记忆中获取相关知识 (DISABLED - old Hippocampus API) # # 从记忆中获取相关知识 (DISABLED - old Hippocampus API)
# related_memory = await HippocampusManager.get_instance().get_memory_from_text( # related_memory = await HippocampusManager.get_instance().get_memory_from_text(
# text=f"{query}\n{chat_history_text}", # text=f"{query}\n{chat_history_text}",

View File

@ -14,10 +14,7 @@ class ReplyChecker:
"""回复检查器""" """回复检查器"""
def __init__(self, stream_id: str, private_name: str): def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest( self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
model_set=model_config.model_task_config.utils,
request_type="reply_check"
)
self.personality_info = self._get_personality_prompt() self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname self.name = global_config.bot.nickname
self.private_name = private_name self.private_name = private_name
@ -27,7 +24,7 @@ class ReplyChecker:
def _get_personality_prompt(self) -> str: def _get_personality_prompt(self) -> str:
"""获取个性提示信息""" """获取个性提示信息"""
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if ( if (
global_config.personality.states global_config.personality.states
@ -35,7 +32,7 @@ class ReplyChecker:
and random.random() < global_config.personality.state_probability and random.random() < global_config.personality.state_probability
): ):
prompt_personality = random.choice(global_config.personality.states) prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};" return f"你的名字是{bot_name},你{prompt_personality};"

View File

@ -99,7 +99,7 @@ class ReplyGenerator:
def _get_personality_prompt(self) -> str: def _get_personality_prompt(self) -> str:
"""获取个性提示信息""" """获取个性提示信息"""
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if ( if (
global_config.personality.states global_config.personality.states
@ -107,7 +107,7 @@ class ReplyGenerator:
and random.random() < global_config.personality.state_probability and random.random() < global_config.personality.state_probability
): ):
prompt_personality = random.choice(global_config.personality.states) prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};" return f"你的名字是{bot_name},你{prompt_personality};"

View File

@ -704,10 +704,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断 # 等待指定时间,但可被新消息打断
try: try:
await asyncio.wait_for( await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
self._new_message_event.wait(),
timeout=wait_seconds
)
# 如果事件被触发,说明有新消息到达 # 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待") logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -731,7 +728,9 @@ class BrainChatting:
# 使用默认等待时间 # 使用默认等待时间
wait_seconds = 3 wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)") logger.info(
f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)"
)
# 清除事件状态,准备等待新消息 # 清除事件状态,准备等待新消息
self._new_message_event.clear() self._new_message_event.clear()
@ -749,10 +748,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断 # 等待指定时间,但可被新消息打断
try: try:
await asyncio.wait_for( await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
self._new_message_event.wait(),
timeout=wait_seconds
)
# 如果事件被触发,说明有新消息到达 # 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待") logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError: except asyncio.TimeoutError:

View File

@ -431,15 +431,21 @@ class BrainPlanner:
except Exception as req_e: except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}" extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
return extracted_reasoning, [ return (
ActionPlannerInfo( extracted_reasoning,
action_type="complete_talk", [
reasoning=extracted_reasoning, ActionPlannerInfo(
action_data={}, action_type="complete_talk",
action_message=None, reasoning=extracted_reasoning,
available_actions=available_actions, action_data={},
) action_message=None,
], llm_content, llm_reasoning, llm_duration_ms available_actions=available_actions,
)
],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应 # 解析LLM响应
if llm_content: if llm_content:

View File

@ -105,7 +105,7 @@ class EmbeddingStore:
self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index" self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json" self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
self.dirty = False # 标记是否有新增数据需要重建索引 self.dirty = False # 标记是否有新增数据需要重建索引
# 多线程配置参数验证和设置 # 多线程配置参数验证和设置

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import json import json
import time import time
from typing import List, Union, Dict, Any from typing import List, Union
from .global_logger import logger from .global_logger import logger
from . import prompt_template from . import prompt_template
@ -192,17 +192,15 @@ class IEProcess:
results = [] results = []
total = len(paragraphs) total = len(paragraphs)
for i, pg in enumerate(paragraphs, start=1): for i, pg in enumerate(paragraphs, start=1):
# 打印进度日志,让用户知道没有卡死 # 打印进度日志,让用户知道没有卡死
logger.info(f"[IEProcess] 正在处理第 {i}/{total} 段文本 (长度: {len(pg)})...") logger.info(f"[IEProcess] 正在处理第 {i}/{total} 段文本 (长度: {len(pg)})...")
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁 # 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行 # 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
try: try:
entities, triples = await asyncio.to_thread( entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
)
if entities is not None: if entities is not None:
results.append( results.append(

View File

@ -395,8 +395,7 @@ class KGManager:
appear_cnt = self.ent_appear_cnt.get(ent_hash) appear_cnt = self.ent_appear_cnt.get(ent_hash)
if not appear_cnt or appear_cnt <= 0: if not appear_cnt or appear_cnt <= 0:
logger.debug( logger.debug(
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0" f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0将使用 1.0 作为默认出现次数参与权重计算"
f"将使用 1.0 作为默认出现次数参与权重计算"
) )
appear_cnt = 1.0 appear_cnt = 1.0
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt) ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)

View File

@ -11,31 +11,30 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
logger = get_logger("LPMM-Plugin-API") logger = get_logger("LPMM-Plugin-API")
class LPMMOperations: class LPMMOperations:
""" """
LPMM 内部操作接口 LPMM 内部操作接口
封装了 LPMM 的核心操作供插件系统 API 或其他内部组件调用 封装了 LPMM 的核心操作供插件系统 API 或其他内部组件调用
""" """
def __init__(self): def __init__(self):
self._initialized = False self._initialized = False
async def _run_cancellable_executor( async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
self, func: Callable, *args, **kwargs
) -> Any:
""" """
在线程池中执行可取消的同步操作 在线程池中执行可取消的同步操作
当任务被取消时 Ctrl+C会立即响应并抛出 CancelledError 当任务被取消时 Ctrl+C会立即响应并抛出 CancelledError
注意线程池中的操作可能仍在运行但协程会立即返回不会阻塞主进程 注意线程池中的操作可能仍在运行但协程会立即返回不会阻塞主进程
Args: Args:
func: 要执行的同步函数 func: 要执行的同步函数
*args: 函数的位置参数 *args: 函数的位置参数
**kwargs: 函数的关键字参数 **kwargs: 函数的关键字参数
Returns: Returns:
函数的返回值 函数的返回值
Raises: Raises:
asyncio.CancelledError: 当任务被取消时 asyncio.CancelledError: 当任务被取消时
""" """
@ -51,42 +50,42 @@ class LPMMOperations:
# 如果全局没初始化,尝试初始化 # 如果全局没初始化,尝试初始化
if not global_config.lpmm_knowledge.enable: if not global_config.lpmm_knowledge.enable:
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。") logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
lpmm_start_up() lpmm_start_up()
qa_mgr = get_qa_manager() qa_mgr = get_qa_manager()
if qa_mgr is None: if qa_mgr is None:
raise RuntimeError("无法获取 LPMM QAManager请检查 LPMM 是否已正确安装和配置。") raise RuntimeError("无法获取 LPMM QAManager请检查 LPMM 是否已正确安装和配置。")
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
async def add_content(self, text: str, auto_split: bool = True) -> dict: async def add_content(self, text: str, auto_split: bool = True) -> dict:
""" """
向知识库添加新内容 向知识库添加新内容
Args: Args:
text: 原始文本 text: 原始文本
auto_split: 是否自动按双换行符分割段落 auto_split: 是否自动按双换行符分割段落
- True: 自动分割默认支持多段文本用双换行分隔 - True: 自动分割默认支持多段文本用双换行分隔
- False: 不分割将整个文本作为完整一段处理 - False: 不分割将整个文本作为完整一段处理
Returns: Returns:
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"} dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
""" """
try: try:
embed_mgr, kg_mgr, _ = await self._get_managers() embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 分段处理 # 1. 分段处理
if auto_split: if auto_split:
# 自动按双换行符分割 # 自动按双换行符分割
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()] paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
else: else:
# 不分割,作为完整一段 # 不分割,作为完整一段
text_stripped = text.strip() text_stripped = text.strip()
if not text_stripped: if not text_stripped:
return {"status": "error", "message": "文本内容为空"} return {"status": "error", "message": "文本内容为空"}
paragraphs = [text_stripped] paragraphs = [text_stripped]
if not paragraphs: if not paragraphs:
return {"status": "error", "message": "文本内容为空"} return {"status": "error", "message": "文本内容为空"}
@ -94,14 +93,16 @@ class LPMMOperations:
from src.chat.knowledge.ie_process import IEProcess from src.chat.knowledge.ie_process import IEProcess
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config from src.config.config import model_config
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract") llm_ner = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
ie_process = IEProcess(llm_ner, llm_rdf) ie_process = IEProcess(llm_ner, llm_rdf)
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...") logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
extracted_docs = await ie_process.process_paragraphs(paragraphs) extracted_docs = await ie_process.process_paragraphs(paragraphs)
# 3. 构造并导入数据 # 3. 构造并导入数据
# 这里我们手动实现导入逻辑,不依赖外部脚本 # 这里我们手动实现导入逻辑,不依赖外部脚本
# a. 准备段落 # a. 准备段落
@ -115,7 +116,7 @@ class LPMMOperations:
# store_new_data_set 期望的格式raw_paragraphs 的键是段落hash不带前缀值是段落文本 # store_new_data_set 期望的格式raw_paragraphs 的键是段落hash不带前缀值是段落文本
new_raw_paragraphs = {} new_raw_paragraphs = {}
new_triple_list_data = {} new_triple_list_data = {}
for pg_hash, passage in raw_paragraphs.items(): for pg_hash, passage in raw_paragraphs.items():
key = f"paragraph-{pg_hash}" key = f"paragraph-{pg_hash}"
if key not in embed_mgr.stored_pg_hashes: if key not in embed_mgr.stored_pg_hashes:
@ -128,26 +129,22 @@ class LPMMOperations:
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入 # 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
# store_new_data_set 会自动处理嵌入生成和存储 # store_new_data_set 会自动处理嵌入生成和存储
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
await self._run_cancellable_executor( await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
embed_mgr.store_new_data_set,
new_raw_paragraphs,
new_triple_list_data
)
# 3. 构建知识图谱只需要三元组数据和embedding_manager # 3. 构建知识图谱只需要三元组数据和embedding_manager
await self._run_cancellable_executor( await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
kg_mgr.build_kg,
new_triple_list_data,
embed_mgr
)
# 4. 持久化 # 4. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file) await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file)
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"} return {
"status": "success",
"count": len(new_raw_paragraphs),
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
}
except asyncio.CancelledError: except asyncio.CancelledError:
logger.warning("[Plugin API] 导入操作被用户中断") logger.warning("[Plugin API] 导入操作被用户中断")
return {"status": "cancelled", "message": "导入操作已被用户中断"} return {"status": "cancelled", "message": "导入操作已被用户中断"}
@ -158,11 +155,11 @@ class LPMMOperations:
async def search(self, query: str, top_k: int = 3) -> List[str]: async def search(self, query: str, top_k: int = 3) -> List[str]:
""" """
检索知识库 检索知识库
Args: Args:
query: 查询问题 query: 查询问题
top_k: 返回最相关的条目数 top_k: 返回最相关的条目数
Returns: Returns:
List[str]: 相关文段列表 List[str]: 相关文段列表
""" """
@ -179,21 +176,21 @@ class LPMMOperations:
async def delete(self, keyword: str, exact_match: bool = False) -> dict: async def delete(self, keyword: str, exact_match: bool = False) -> dict:
""" """
根据关键词或完整文段删除知识库内容 根据关键词或完整文段删除知识库内容
Args: Args:
keyword: 匹配关键词或完整文段 keyword: 匹配关键词或完整文段
exact_match: 是否使用完整文段匹配True=完全匹配False=关键词模糊匹配 exact_match: 是否使用完整文段匹配True=完全匹配False=关键词模糊匹配
Returns: Returns:
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"} dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
""" """
try: try:
embed_mgr, kg_mgr, _ = await self._get_managers() embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 查找匹配的段落 # 1. 查找匹配的段落
to_delete_keys = [] to_delete_keys = []
to_delete_hashes = [] to_delete_hashes = []
for key, item in embed_mgr.paragraphs_embedding_store.store.items(): for key, item in embed_mgr.paragraphs_embedding_store.store.items():
if exact_match: if exact_match:
# 完整文段匹配 # 完整文段匹配
@ -205,29 +202,25 @@ class LPMMOperations:
if keyword in item.str: if keyword in item.str:
to_delete_keys.append(key) to_delete_keys.append(key)
to_delete_hashes.append(key.replace("paragraph-", "", 1)) to_delete_hashes.append(key.replace("paragraph-", "", 1))
if not to_delete_keys: if not to_delete_keys:
match_type = "完整文段" if exact_match else "关键词" match_type = "完整文段" if exact_match else "关键词"
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"} return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
# 2. 执行删除 # 2. 执行删除
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# a. 从向量库删除 # a. 从向量库删除
deleted_count, _ = await self._run_cancellable_executor( deleted_count, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items, embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
to_delete_keys
) )
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys()) embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
# b. 从知识图谱删除 # b. 从知识图谱删除
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial( delete_func = partial(
kg_mgr.delete_paragraphs, kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
to_delete_hashes,
ent_hashes=None,
remove_orphan_entities=True
) )
await self._run_cancellable_executor(delete_func) await self._run_cancellable_executor(delete_func)
@ -235,9 +228,13 @@ class LPMMOperations:
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file) await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file)
match_type = "完整文段" if exact_match else "关键词" match_type = "完整文段" if exact_match else "关键词"
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"} return {
"status": "success",
"deleted_count": deleted_count,
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
}
except asyncio.CancelledError: except asyncio.CancelledError:
logger.warning("[Plugin API] 删除操作被用户中断") logger.warning("[Plugin API] 删除操作被用户中断")
@ -249,13 +246,13 @@ class LPMMOperations:
async def clear_all(self) -> dict: async def clear_all(self) -> dict:
""" """
清空整个LPMM知识库删除所有段落实体关系和知识图谱数据 清空整个LPMM知识库删除所有段落实体关系和知识图谱数据
Returns: Returns:
dict: {"status": "success/error", "message": "描述", "stats": {...}} dict: {"status": "success/error", "message": "描述", "stats": {...}}
""" """
try: try:
embed_mgr, kg_mgr, _ = await self._get_managers() embed_mgr, kg_mgr, _ = await self._get_managers()
# 记录清空前的统计信息 # 记录清空前的统计信息
before_stats = { before_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store), "paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
@ -264,40 +261,37 @@ class LPMMOperations:
"kg_nodes": len(kg_mgr.graph.get_node_list()), "kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()), "kg_edges": len(kg_mgr.graph.get_edge_list()),
} }
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# 1. 清空所有向量库 # 1. 清空所有向量库
# 获取所有keys # 获取所有keys
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys()) para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
ent_keys = list(embed_mgr.entities_embedding_store.store.keys()) ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
rel_keys = list(embed_mgr.relation_embedding_store.store.keys()) rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
# 删除所有段落向量 # 删除所有段落向量
para_deleted, _ = await self._run_cancellable_executor( para_deleted, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items, embed_mgr.paragraphs_embedding_store.delete_items, para_keys
para_keys
) )
embed_mgr.stored_pg_hashes.clear() embed_mgr.stored_pg_hashes.clear()
# 删除所有实体向量 # 删除所有实体向量
if ent_keys: if ent_keys:
ent_deleted, _ = await self._run_cancellable_executor( ent_deleted, _ = await self._run_cancellable_executor(
embed_mgr.entities_embedding_store.delete_items, embed_mgr.entities_embedding_store.delete_items, ent_keys
ent_keys
) )
else: else:
ent_deleted = 0 ent_deleted = 0
# 删除所有关系向量 # 删除所有关系向量
if rel_keys: if rel_keys:
rel_deleted, _ = await self._run_cancellable_executor( rel_deleted, _ = await self._run_cancellable_executor(
embed_mgr.relation_embedding_store.delete_items, embed_mgr.relation_embedding_store.delete_items, rel_keys
rel_keys
) )
else: else:
rel_deleted = 0 rel_deleted = 0
# 2. 清空所有 embedding store 的索引和映射 # 2. 清空所有 embedding store 的索引和映射
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件 # 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
def _clear_embedding_indices(): def _clear_embedding_indices():
@ -310,7 +304,7 @@ class LPMMOperations:
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path) os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path): if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path) os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
# 清空实体索引 # 清空实体索引
embed_mgr.entities_embedding_store.faiss_index = None embed_mgr.entities_embedding_store.faiss_index = None
embed_mgr.entities_embedding_store.idx2hash = None embed_mgr.entities_embedding_store.idx2hash = None
@ -320,7 +314,7 @@ class LPMMOperations:
os.remove(embed_mgr.entities_embedding_store.index_file_path) os.remove(embed_mgr.entities_embedding_store.index_file_path)
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path): if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path) os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
# 清空关系索引 # 清空关系索引
embed_mgr.relation_embedding_store.faiss_index = None embed_mgr.relation_embedding_store.faiss_index = None
embed_mgr.relation_embedding_store.idx2hash = None embed_mgr.relation_embedding_store.idx2hash = None
@ -330,9 +324,9 @@ class LPMMOperations:
os.remove(embed_mgr.relation_embedding_store.index_file_path) os.remove(embed_mgr.relation_embedding_store.index_file_path)
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path): if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path) os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
await self._run_cancellable_executor(_clear_embedding_indices) await self._run_cancellable_executor(_clear_embedding_indices)
# 3. 清空知识图谱 # 3. 清空知识图谱
# 获取所有段落hash # 获取所有段落hash
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes) all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
@ -341,24 +335,22 @@ class LPMMOperations:
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial( delete_func = partial(
kg_mgr.delete_paragraphs, kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
all_pg_hashes,
ent_hashes=None,
remove_orphan_entities=True
) )
await self._run_cancellable_executor(delete_func) await self._run_cancellable_executor(delete_func)
# 完全清空KG创建新的空图无论是否有段落hash都要执行 # 完全清空KG创建新的空图无论是否有段落hash都要执行
from quick_algo import di_graph from quick_algo import di_graph
kg_mgr.graph = di_graph.DiGraph() kg_mgr.graph = di_graph.DiGraph()
kg_mgr.stored_paragraph_hashes.clear() kg_mgr.stored_paragraph_hashes.clear()
kg_mgr.ent_appear_cnt.clear() kg_mgr.ent_appear_cnt.clear()
# 4. 保存所有数据此时所有store都是空的索引也是None # 4. 保存所有数据此时所有store都是空的索引也是None
# 注意即使store为空save_to_file也会保存空的DataFrame这是正确的 # 注意即使store为空save_to_file也会保存空的DataFrame这是正确的
await self._run_cancellable_executor(embed_mgr.save_to_file) await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file) await self._run_cancellable_executor(kg_mgr.save_to_file)
after_stats = { after_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store), "paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
"entities": len(embed_mgr.entities_embedding_store.store), "entities": len(embed_mgr.entities_embedding_store.store),
@ -366,14 +358,14 @@ class LPMMOperations:
"kg_nodes": len(kg_mgr.graph.get_node_list()), "kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()), "kg_edges": len(kg_mgr.graph.get_edge_list()),
} }
return { return {
"status": "success", "status": "success",
"message": f"已成功清空LPMM知识库删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)", "message": f"已成功清空LPMM知识库删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
"stats": { "stats": {
"before": before_stats, "before": before_stats,
"after": after_stats, "after": after_stats,
} },
} }
except asyncio.CancelledError: except asyncio.CancelledError:
@ -383,6 +375,6 @@ class LPMMOperations:
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True) logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
# 内部使用的单例 # 内部使用的单例
lpmm_ops = LPMMOperations() lpmm_ops = LPMMOperations()

View File

@ -136,4 +136,3 @@ class PlanReplyLogger:
return str(value) return str(value)
# Fallback to string for other complex types # Fallback to string for other complex types
return str(value) return str(value)

View File

@ -85,17 +85,17 @@ class ChatBot:
async def _create_pfc_chat(self, message: MessageRecv): async def _create_pfc_chat(self, message: MessageRecv):
"""创建或获取PFC对话实例 """创建或获取PFC对话实例
Args: Args:
message: 消息对象 message: 消息对象
""" """
try: try:
chat_id = str(message.chat_stream.stream_id) chat_id = str(message.chat_stream.stream_id)
private_name = str(message.message_info.user_info.user_nickname) private_name = str(message.message_info.user_info.user_nickname)
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}") logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
await self.pfc_manager.get_or_create_conversation(chat_id, private_name) await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
except Exception as e: except Exception as e:
logger.error(f"创建PFC聊天失败: {e}") logger.error(f"创建PFC聊天失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())

View File

@ -96,7 +96,7 @@ class Message(MessageBase):
if processed_text: if processed_text:
return f"{global_config.bot.nickname}: {processed_text}" return f"{global_config.bot.nickname}: {processed_text}"
return None return None
tasks = [process_forward_node(node_dict) for node_dict in segment.data] tasks = [process_forward_node(node_dict) for node_dict in segment.data]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
segments_text = [] segments_text = []

View File

@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 如果未开启 API Server直接跳过 Fallback # 如果未开启 API Server直接跳过 Fallback
if not global_config.maim_message.enable_api_server: if not global_config.maim_message.enable_api_server:
logger.debug(f"[API Server Fallback] API Server未开启跳过fallback") logger.debug("[API Server Fallback] API Server未开启跳过fallback")
if legacy_exception: if legacy_exception:
raise legacy_exception raise legacy_exception
return False return False
@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
extra_server = getattr(global_api, "extra_server", None) extra_server = getattr(global_api, "extra_server", None)
if not extra_server: if not extra_server:
logger.warning(f"[API Server Fallback] extra_server不存在") logger.warning("[API Server Fallback] extra_server不存在")
if legacy_exception: if legacy_exception:
raise legacy_exception raise legacy_exception
return False return False
if not extra_server.is_running(): if not extra_server.is_running():
logger.warning(f"[API Server Fallback] extra_server未运行") logger.warning("[API Server Fallback] extra_server未运行")
if legacy_exception: if legacy_exception:
raise legacy_exception raise legacy_exception
return False return False
@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
) )
# 直接调用 Server 的 send_message 接口,它会自动处理路由 # 直接调用 Server 的 send_message 接口,它会自动处理路由
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...") logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
results = await extra_server.send_message(api_message) results = await extra_server.send_message(api_message)
logger.debug(f"[API Server Fallback] 发送结果: {results}") logger.debug(f"[API Server Fallback] 发送结果: {results}")

View File

@ -35,6 +35,7 @@ logger = get_logger("planner")
install(extra_lines=3) install(extra_lines=3)
class ActionPlanner: class ActionPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager): def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id self.chat_id = chat_id
@ -48,7 +49,7 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = [] self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
# 黑话缓存:使用 OrderedDict 实现 LRU最多缓存10个 # 黑话缓存:使用 OrderedDict 实现 LRU最多缓存10个
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict() self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
self.unknown_words_cache_limit = 10 self.unknown_words_cache_limit = 10
@ -111,20 +112,29 @@ class ActionPlanner:
# 替换 [picid:xxx] 为 [图片:描述] # 替换 [picid:xxx] 为 [图片:描述]
pic_pattern = r"\[picid:([^\]]+)\]" pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(pic_match: re.Match) -> str: def replace_pic_id(pic_match: re.Match) -> str:
pic_id = pic_match.group(1) pic_id = pic_match.group(1)
description = translate_pid_to_description(pic_id) description = translate_pid_to_description(pic_id)
return f"[图片:{description}]" return f"[图片:{description}]"
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text) msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb> # 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq" platform = (
getattr(message, "user_info", None)
and message.user_info.platform
or getattr(message, "chat_info", None)
and message.chat_info.platform
or "qq"
)
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True) msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
# 替换单独的 <用户名:用户ID> 格式replace_user_references 已处理回复<和@<格式) # 替换单独的 <用户名:用户ID> 格式replace_user_references 已处理回复<和@<格式)
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式, # 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
# 这里匹配到的应该都是单独的格式 # 这里匹配到的应该都是单独的格式
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>" user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
def replace_user_ref(user_match: re.Match) -> str: def replace_user_ref(user_match: re.Match) -> str:
user_name = user_match.group(1) user_name = user_match.group(1)
user_id = user_match.group(2) user_id = user_match.group(2)
@ -137,6 +147,7 @@ class ActionPlanner:
except Exception: except Exception:
# 如果解析失败,使用原始昵称 # 如果解析失败,使用原始昵称
return user_name return user_name
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text) msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..." preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
@ -165,7 +176,7 @@ class ActionPlanner:
else: else:
reasoning = "未提供原因" reasoning = "未提供原因"
action_data = {key: value for key, value in action_json.items() if key not in ["action"]} action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
# 非no_reply动作需要target_message_id # 非no_reply动作需要target_message_id
target_message = None target_message = None
@ -244,7 +255,7 @@ class ActionPlanner:
def _update_unknown_words_cache(self, new_words: List[str]) -> None: def _update_unknown_words_cache(self, new_words: List[str]) -> None:
""" """
更新黑话缓存将新的黑话加入缓存 更新黑话缓存将新的黑话加入缓存
Args: Args:
new_words: 新提取的黑话列表 new_words: 新提取的黑话列表
""" """
@ -254,7 +265,7 @@ class ActionPlanner:
word = word.strip() word = word.strip()
if not word: if not word:
continue continue
# 如果已存在移到末尾LRU # 如果已存在移到末尾LRU
if word in self.unknown_words_cache: if word in self.unknown_words_cache:
self.unknown_words_cache.move_to_end(word) self.unknown_words_cache.move_to_end(word)
@ -269,10 +280,10 @@ class ActionPlanner:
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]: def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
""" """
合并新提取的黑话和缓存中的黑话 合并新提取的黑话和缓存中的黑话
Args: Args:
new_words: 新提取的黑话列表可能为None new_words: 新提取的黑话列表可能为None
Returns: Returns:
合并后的黑话列表去重 合并后的黑话列表去重
""" """
@ -284,31 +295,29 @@ class ActionPlanner:
word = word.strip() word = word.strip()
if word: if word:
cleaned_new_words.append(word) cleaned_new_words.append(word)
# 获取缓存中的黑话列表 # 获取缓存中的黑话列表
cached_words = list(self.unknown_words_cache.keys()) cached_words = list(self.unknown_words_cache.keys())
# 合并并去重(保留顺序:新提取的在前,缓存的在后) # 合并并去重(保留顺序:新提取的在前,缓存的在后)
merged_words: List[str] = [] merged_words: List[str] = []
seen = set() seen = set()
# 先添加新提取的 # 先添加新提取的
for word in cleaned_new_words: for word in cleaned_new_words:
if word not in seen: if word not in seen:
merged_words.append(word) merged_words.append(word)
seen.add(word) seen.add(word)
# 再添加缓存的(如果不在新提取的列表中) # 再添加缓存的(如果不在新提取的列表中)
for word in cached_words: for word in cached_words:
if word not in seen: if word not in seen:
merged_words.append(word) merged_words.append(word)
seen.add(word) seen.add(word)
return merged_words return merged_words
def _process_unknown_words_cache( def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
self, actions: List[ActionPlannerInfo]
) -> None:
""" """
处理黑话缓存逻辑 处理黑话缓存逻辑
1. 检查是否有 reply action 提取了 unknown_words 1. 检查是否有 reply action 提取了 unknown_words
@ -316,7 +325,7 @@ class ActionPlanner:
3. 如果缓存数量大于5移除最老的2个 3. 如果缓存数量大于5移除最老的2个
4. 对于每个 reply action合并缓存和新提取的黑话 4. 对于每个 reply action合并缓存和新提取的黑话
5. 更新缓存 5. 更新缓存
Args: Args:
actions: 解析后的动作列表 actions: 解析后的动作列表
""" """
@ -330,7 +339,7 @@ class ActionPlanner:
removed_count += 1 removed_count += 1
if removed_count > 0: if removed_count > 0:
logger.debug(f"{self.log_prefix}缓存数量大于5移除最老的{removed_count}个缓存") logger.debug(f"{self.log_prefix}缓存数量大于5移除最老的{removed_count}个缓存")
# 检查是否有 reply action 提取了 unknown_words # 检查是否有 reply action 提取了 unknown_words
has_extracted_unknown_words = False has_extracted_unknown_words = False
for action in actions: for action in actions:
@ -340,22 +349,22 @@ class ActionPlanner:
if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0: if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0:
has_extracted_unknown_words = True has_extracted_unknown_words = True
break break
# 如果当前 plan 的 reply 没有提取移除最老的1个 # 如果当前 plan 的 reply 没有提取移除最老的1个
if not has_extracted_unknown_words: if not has_extracted_unknown_words:
if len(self.unknown_words_cache) > 0: if len(self.unknown_words_cache) > 0:
self.unknown_words_cache.popitem(last=False) self.unknown_words_cache.popitem(last=False)
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存") logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")
# 对于每个 reply action合并缓存和新提取的黑话 # 对于每个 reply action合并缓存和新提取的黑话
for action in actions: for action in actions:
if action.action_type == "reply": if action.action_type == "reply":
action_data = action.action_data or {} action_data = action.action_data or {}
new_words = action_data.get("unknown_words") new_words = action_data.get("unknown_words")
# 合并新提取的和缓存的黑话列表 # 合并新提取的和缓存的黑话列表
merged_words = self._merge_unknown_words_with_cache(new_words) merged_words = self._merge_unknown_words_with_cache(new_words)
# 更新 action_data # 更新 action_data
if merged_words: if merged_words:
action_data["unknown_words"] = merged_words action_data["unknown_words"] = merged_words
@ -366,7 +375,7 @@ class ActionPlanner:
else: else:
# 如果没有合并后的黑话,移除 unknown_words 字段 # 如果没有合并后的黑话,移除 unknown_words 字段
action_data.pop("unknown_words", None) action_data.pop("unknown_words", None)
# 更新缓存(将新提取的黑话加入缓存) # 更新缓存(将新提取的黑话加入缓存)
if new_words: if new_words:
self._update_unknown_words_cache(new_words) self._update_unknown_words_cache(new_words)
@ -442,15 +451,19 @@ class ActionPlanner:
# 检查是否已经有回复该消息的 action # 检查是否已经有回复该消息的 action
has_reply_to_force_message = False has_reply_to_force_message = False
for action in actions: for action in actions:
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id: if (
action.action_type == "reply"
and action.action_message
and action.action_message.message_id == force_reply_message.message_id
):
has_reply_to_force_message = True has_reply_to_force_message = True
break break
# 如果没有回复该消息,强制添加回复 action # 如果没有回复该消息,强制添加回复 action
if not has_reply_to_force_message: if not has_reply_to_force_message:
# 移除所有 no_reply action如果有 # 移除所有 no_reply action如果有
actions = [a for a in actions if a.action_type != "no_reply"] actions = [a for a in actions if a.action_type != "no_reply"]
# 创建强制回复 action # 创建强制回复 action
available_actions_dict = dict(current_available_actions) available_actions_dict = dict(current_available_actions)
force_reply_action = ActionPlannerInfo( force_reply_action = ActionPlannerInfo(
@ -577,10 +590,11 @@ class ActionPlanner:
if global_config.chat.think_mode == "classic": if global_config.chat.think_mode == "classic":
reply_action_example = "" reply_action_example = ""
if global_config.chat.llm_quote: if global_config.chat.llm_quote:
reply_action_example += "5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n" reply_action_example += (
"5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
)
reply_action_example += ( reply_action_example += (
'{{"action":"reply", "target_message_id":"消息id(m+数字)", ' '{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
'"unknown_words":["词语1","词语2"]'
) )
if global_config.chat.llm_quote: if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"' reply_action_example += ', "quote":"如果需要引用该message设置为true"'
@ -590,7 +604,9 @@ class ActionPlanner:
"5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n" "5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n"
) )
if global_config.chat.llm_quote: if global_config.chat.llm_quote:
reply_action_example += "6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n" reply_action_example += (
"6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
)
reply_action_example += ( reply_action_example += (
'{{"action":"reply", "think_level":数值等级(0或1), ' '{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", ' '"target_message_id":"消息id(m+数字)", '
@ -741,15 +757,21 @@ class ActionPlanner:
except Exception as req_e: except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return f"LLM 请求失败,模型出现问题: {req_e}", [ return (
ActionPlannerInfo( f"LLM 请求失败,模型出现问题: {req_e}",
action_type="no_reply", [
reasoning=f"LLM 请求失败,模型出现问题: {req_e}", ActionPlannerInfo(
action_data={}, action_type="no_reply",
action_message=None, reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
available_actions=available_actions, action_data={},
) action_message=None,
], llm_content, llm_reasoning, llm_duration_ms available_actions=available_actions,
)
],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应 # 解析LLM响应
extracted_reasoning = "" extracted_reasoning = ""

View File

@ -1071,7 +1071,6 @@ class DefaultReplyer:
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2") chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt) chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换 # 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style reply_style = global_config.personality.reply_style
multi_styles = global_config.personality.multiple_reply_style multi_styles = global_config.personality.multiple_reply_style

View File

@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
) )
from src.bw_learner.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known from src.person_info.person_info import Person, is_person_known
from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.base.component_types import ActionInfo, EventType
@ -807,7 +808,7 @@ class PrivateReplyer:
reply_style = global_config.personality.reply_style reply_style = global_config.personality.reply_style
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(platform, user_id): if is_bot_self(platform, user_id):
prompt_template = prompt_manager.get_prompt("private_replyer_self") prompt_template = prompt_manager.get_prompt("private_replyer_self")
prompt_template.add_context("target", target) prompt_template.add_context("target", target)

View File

@ -519,7 +519,7 @@ def _build_readable_messages_internal(
output_lines: List[str] = [] output_lines: List[str] = []
prev_timestamp: Optional[float] = None prev_timestamp: Optional[float] = None
for timestamp, name, content, is_action in detailed_message: for timestamp, name, content, _is_action in detailed_message:
# 检查是否需要插入长时间间隔提示 # 检查是否需要插入长时间间隔提示
if long_time_notice and prev_timestamp is not None: if long_time_notice and prev_timestamp is not None:
time_diff = timestamp - prev_timestamp time_diff = timestamp - prev_timestamp

View File

@ -5,6 +5,7 @@ from src.common.logger import get_logger
logger = get_logger("common_utils") logger = get_logger("common_utils")
class TempMethodsExpression: class TempMethodsExpression:
"""用于临时存放一些方法的类""" """用于临时存放一些方法的类"""

View File

@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
from . import BaseDatabaseDataModel from . import BaseDatabaseDataModel
class MaiChatSession(BaseDatabaseDataModel[ChatSession]): class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None): def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
self.session_id = session_id self.session_id = session_id
@ -33,4 +34,4 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
platform=self.platform, platform=self.platform,
user_id=self.user_id, user_id=self.user_id,
group_id=self.group_id, group_id=self.group_id,
) )

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple, Union from typing import Any, Iterable, List, Optional, Tuple, Union

View File

@ -221,5 +221,7 @@ if not supports_truecolor():
CONVERTED_MODULE_COLORS[name] = escape_str CONVERTED_MODULE_COLORS[name] = escape_str
else: else:
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items(): for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold) escape_str = rgb_pair_to_ansi_truecolor(
CONVERTED_MODULE_COLORS[name] = escape_str hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
)
CONVERTED_MODULE_COLORS[name] = escape_str

View File

@ -9,6 +9,7 @@ from .server import get_global_server
global_api = None global_api = None
def get_global_api() -> MessageServer: # sourcery skip: extract-method def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例""" """获取全局MessageServer实例"""
global global_api global global_api
@ -80,12 +81,12 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}") api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
return False return False
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了 server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
# 3. Setup Message Bridge # 3. Setup Message Bridge
# Initialize refined route map if not exists # Initialize refined route map if not exists
if not hasattr(global_api, "platform_map"): if not hasattr(global_api, "platform_map"):
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法 global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
async def bridge_message_handler(message: APIMessageBase, metadata: dict): async def bridge_message_handler(message: APIMessageBase, metadata: dict):
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase # 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
@ -108,7 +109,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'") api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
if platform: if platform:
global_api.platform_map[platform] = api_key # type: ignore global_api.platform_map[platform] = api_key # type: ignore
api_logger.info(f"Updated platform_map: {platform} -> {api_key}") api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
except Exception as e: except Exception as e:
api_logger.warning(f"Failed to update platform map: {e}") api_logger.warning(f"Failed to update platform map: {e}")
@ -117,21 +118,21 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
if "raw_message" not in msg_dict: if "raw_message" not in msg_dict:
msg_dict["raw_message"] = None msg_dict["raw_message"] = None
await global_api.process_message(msg_dict) # type: ignore await global_api.process_message(msg_dict) # type: ignore
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了 server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
# 3.5. Register custom message handlers (bridge to Legacy handlers) # 3.5. Register custom message handlers (bridge to Legacy handlers)
# message_id_echo: handles message ID echo from adapters # message_id_echo: handles message ID echo from adapters
# 兼容新旧两个版本的 maim_message: # 兼容新旧两个版本的 maim_message:
# - 旧版: handler(payload) # - 旧版: handler(payload)
# - 新版: handler(payload, metadata) # - 新版: handler(payload, metadata)
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
# Bridge to the Legacy custom handler registered in main.py # Bridge to the Legacy custom handler registered in main.py
try: try:
# The Legacy handler expects the payload format directly # The Legacy handler expects the payload format directly
if hasattr(global_api, "_custom_message_handlers"): if hasattr(global_api, "_custom_message_handlers"):
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了 handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
if handler: if handler:
await handler(payload) await handler(payload)
api_logger.debug(f"Processed message_id_echo: {payload}") api_logger.debug(f"Processed message_id_echo: {payload}")
@ -140,7 +141,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
except Exception as e: except Exception as e:
api_logger.warning(f"Failed to process message_id_echo: {e}") api_logger.warning(f"Failed to process message_id_echo: {e}")
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了 server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
# 4. Initialize Server # 4. Initialize Server
extra_server = WebSocketServer(config=server_config) extra_server = WebSocketServer(config=server_config)
@ -167,7 +168,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
global_api.stop = patched_stop global_api.stop = patched_stop
# Attach for reference # Attach for reference
global_api.extra_server = extra_server # type: ignore # 这是什么 global_api.extra_server = extra_server # type: ignore # 这是什么
except ImportError: except ImportError:
get_logger("maim_message").error( get_logger("maim_message").error(

View File

@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
logger = get_logger("file_utils") logger = get_logger("file_utils")
class FileUtils: class FileUtils:
@staticmethod @staticmethod
def save_binary_to_file(file_path: Path, data: bytes): def save_binary_to_file(file_path: Path, data: bytes):
@ -35,7 +36,7 @@ class FileUtils:
except Exception as e: except Exception as e:
logger.error(f"保存文件 {file_path} 失败: {e}") logger.error(f"保存文件 {file_path} 失败: {e}")
raise e raise e
@staticmethod @staticmethod
def get_file_path_by_hash(data_hash: str) -> Path: def get_file_path_by_hash(data_hash: str) -> Path:
""" """
@ -52,4 +53,4 @@ class FileUtils:
if binary_data := session.exec(statement).first(): if binary_data := session.exec(statement).first():
return Path(binary_data.full_path) return Path(binary_data.full_path)
else: else:
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录") raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")

View File

@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
reason = ",".join(reasons) reason = ",".join(reasons)
return MigrationResult(data=data, migrated=migrated_any, reason=reason) return MigrationResult(data=data, migrated=migrated_any, reason=reason)

View File

@ -86,8 +86,8 @@ def init_dream_tools(chat_id: str) -> None:
finish_maintenance = make_finish_maintenance(chat_id) finish_maintenance = make_finish_maintenance(chat_id)
search_jargon = make_search_jargon(chat_id) search_jargon = make_search_jargon(chat_id)
delete_jargon = make_delete_jargon(chat_id) _delete_jargon = make_delete_jargon(chat_id)
update_jargon = make_update_jargon(chat_id) _update_jargon = make_update_jargon(chat_id)
_dream_tool_registry.register_tool( _dream_tool_registry.register_tool(
DreamTool( DreamTool(

View File

@ -54,8 +54,6 @@ async def generate_dream_summary(
) -> None: ) -> None:
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户""" """生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
try: try:
# 第一步:建立工具调用结果映射 (call_id -> result) # 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {} tool_results_map: dict[str, str] = {}
for msg in conversation_messages: for msg in conversation_messages:

View File

@ -4,4 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中通过 make_xxx(chat_id) 工厂函数 每个工具的具体实现放在独立文件中通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数 dream_agent.init_dream_tools 统一注册 生成绑定到特定 chat_id 的协程函数 dream_agent.init_dream_tools 统一注册
""" """

View File

@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}" return f"create_chat_history 执行失败: {e}"
return create_chat_history return create_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}" return f"delete_chat_history 执行失败: {e}"
return delete_chat_history return delete_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}" return f"delete_jargon 执行失败: {e}"
return delete_jargon return delete_jargon

View File

@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg return msg
return finish_maintenance return finish_maintenance

View File

@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
return f"get_chat_history_detail 执行失败: {e}" return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail return get_chat_history_detail

View File

@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}" return f"search_chat_history 执行失败: {e}"
return search_chat_history return search_chat_history

View File

@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}" return f"update_chat_history 执行失败: {e}"
return update_chat_history return update_chat_history

View File

@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}" return f"update_jargon 执行失败: {e}"
return update_jargon return update_jargon

View File

@ -458,8 +458,8 @@ def _default_normal_response_parser(
if not isinstance(arguments, dict): if not isinstance(arguments, dict):
# 此时为了调试方便,建议打印出 arguments 的类型 # 此时为了调试方便,建议打印出 arguments 的类型
raise RespParseException( raise RespParseException(
resp, resp,
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}" f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
) )
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments)) api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
except json.JSONDecodeError as e: except json.JSONDecodeError as e:

View File

@ -2,7 +2,7 @@ import time
import json import json
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple, Callable, cast from typing import List, Dict, Any, Optional, Tuple, Callable
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
@ -34,7 +34,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
try: try:
with get_db_session() as session: with get_db_session() as session:
statement = select(ThinkingQuestion).where( statement = select(ThinkingQuestion).where(
(ThinkingQuestion.found_answer == False) col(ThinkingQuestion.found_answer).is_(False)
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time)) & (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
) )
records = session.exec(statement).all() records = session.exec(statement).all()
@ -786,8 +786,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
str: 格式化的查询历史字符串 str: 格式化的查询历史字符串
""" """
try: try:
current_time = time.time() _current_time = time.time()
start_time = current_time - time_window_seconds
with get_db_session() as session: with get_db_session() as session:
statement = ( statement = (
@ -838,15 +837,14 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
List[str]: 格式化的答案列表每个元素格式为 "问题xxx\n答案xxx" List[str]: 格式化的答案列表每个元素格式为 "问题xxx\n答案xxx"
""" """
try: try:
current_time = time.time() _current_time = time.time()
start_time = current_time - time_window_seconds
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序 # 查询最近时间窗口内已找到答案的记录,按更新时间倒序
with get_db_session() as session: with get_db_session() as session:
statement = ( statement = (
select(ThinkingQuestion) select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id) .where(col(ThinkingQuestion.context) == chat_id)
.where(col(ThinkingQuestion.found_answer) == True) .where(col(ThinkingQuestion.found_answer))
.where(col(ThinkingQuestion.answer).is_not(None)) .where(col(ThinkingQuestion.answer).is_not(None))
.where(col(ThinkingQuestion.answer) != "") .where(col(ThinkingQuestion.answer) != "")
.order_by(col(ThinkingQuestion.updated_timestamp).desc()) .order_by(col(ThinkingQuestion.updated_timestamp).desc())

View File

@ -105,25 +105,27 @@ async def search_chat_history(
# 检查参数 # 检查参数
if not keyword and not participant and not start_time and not end_time: if not keyword and not participant and not start_time and not end_time:
return "未指定查询参数需要提供keyword、participant、start_time或end_time之一" return "未指定查询参数需要提供keyword、participant、start_time或end_time之一"
# 解析时间参数 # 解析时间参数
start_timestamp = None start_timestamp = None
end_timestamp = None end_timestamp = None
if start_time: if start_time:
try: try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp from src.memory_system.memory_utils import parse_datetime_to_timestamp
start_timestamp = parse_datetime_to_timestamp(start_time) start_timestamp = parse_datetime_to_timestamp(start_time)
except ValueError as e: except ValueError as e:
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'" return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
if end_time: if end_time:
try: try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp from src.memory_system.memory_utils import parse_datetime_to_timestamp
end_timestamp = parse_datetime_to_timestamp(end_time) end_timestamp = parse_datetime_to_timestamp(end_time)
except ValueError as e: except ValueError as e:
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'" return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
# 验证时间范围 # 验证时间范围
if start_timestamp and end_timestamp and start_timestamp > end_timestamp: if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
return "开始时间不能晚于结束时间" return "开始时间不能晚于结束时间"
@ -158,23 +160,20 @@ async def search_chat_history(
f"search_chat_history 当前聊天流在黑名单中强制使用本地查询chat_id={chat_id}, keyword={keyword}, participant={participant}" f"search_chat_history 当前聊天流在黑名单中强制使用本地查询chat_id={chat_id}, keyword={keyword}, participant={participant}"
) )
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 添加时间过滤条件 # 添加时间过滤条件
if start_timestamp is not None and end_timestamp is not None: if start_timestamp is not None and end_timestamp is not None:
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集) # 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段 # 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
query = query.where( query = query.where(
( (
(ChatHistory.start_time >= start_timestamp) (ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
& (ChatHistory.start_time <= end_timestamp)
) # 记录开始时间在查询时间段内 ) # 记录开始时间在查询时间段内
| ( | (
(ChatHistory.end_time >= start_timestamp) (ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
& (ChatHistory.end_time <= end_timestamp)
) # 记录结束时间在查询时间段内 ) # 记录结束时间在查询时间段内
| ( | (
(ChatHistory.start_time <= start_timestamp) (ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
& (ChatHistory.end_time >= end_timestamp)
) # 记录完全包含查询时间段 ) # 记录完全包含查询时间段
) )
logger.debug( logger.debug(
@ -302,7 +301,7 @@ async def search_chat_history(
time_desc = f"时间<='{end_str}'" time_desc = f"时间<='{end_str}'"
if time_desc: if time_desc:
conditions.append(time_desc) conditions.append(time_desc)
if conditions: if conditions:
conditions_str = "".join(conditions) conditions_str = "".join(conditions)
return f"未找到满足条件({conditions_str})的聊天记录" return f"未找到满足条件({conditions_str})的聊天记录"

View File

@ -30,7 +30,7 @@ async def query_words(chat_id: str, words: str) -> str:
if separator in words: if separator in words:
words_list = [w.strip() for w in words.split(separator) if w.strip()] words_list = [w.strip() for w in words.split(separator) if w.strip()]
break break
# 如果没有找到分隔符,整个字符串作为一个词语 # 如果没有找到分隔符,整个字符串作为一个词语
if not words_list: if not words_list:
words_list = [words.strip()] words_list = [words.strip()]
@ -76,4 +76,3 @@ def register_tool():
], ],
execute_func=query_words, execute_func=query_words,
) )

View File

@ -123,7 +123,7 @@ async def generate_reply(
# 如果 reply_time_point 未传入,设置为当前时间戳 # 如果 reply_time_point 未传入,设置为当前时间戳
if reply_time_point is None: if reply_time_point is None:
reply_time_point = time.time() reply_time_point = time.time()
# 获取回复器 # 获取回复器
logger.debug("[GeneratorAPI] 开始生成回复") logger.debug("[GeneratorAPI] 开始生成回复")
replyer = get_replyer(chat_stream, chat_id, request_type=request_type) replyer = get_replyer(chat_stream, chat_id, request_type=request_type)

View File

@ -39,6 +39,18 @@ def unregister_service(service_name: str, plugin_name: Optional[str] = None) ->
return plugin_service_registry.unregister_service(service_name, plugin_name) return plugin_service_registry.unregister_service(service_name, plugin_name)
async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: async def call_service(
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务。""" """调用插件服务。"""
return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs) return await plugin_service_registry.call_service(
service_name,
*args,
plugin_name=plugin_name,
caller_plugin=caller_plugin,
**kwargs,
)

View File

@ -558,7 +558,9 @@ class PluginBase(ABC):
if version_spec: if version_spec:
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec) is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
if not is_ok: if not is_ok:
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})") logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
)
return False return False
if min_version or max_version: if min_version or max_version:

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict from typing import Any, Dict, List
@dataclass @dataclass
@ -11,6 +11,8 @@ class PluginServiceInfo:
version: str = "1.0.0" version: str = "1.0.0"
description: str = "" description: str = ""
enabled: bool = True enabled: bool = True
public: bool = False
allowed_callers: List[str] = field(default_factory=list)
params_schema: Dict[str, Any] = field(default_factory=dict) params_schema: Dict[str, Any] = field(default_factory=dict)
return_schema: Dict[str, Any] = field(default_factory=dict) return_schema: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict)

View File

@ -274,6 +274,23 @@ class ComponentRegistry:
logger.error(f"移除组件 {component_name} 时发生错误: {e}") logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False return False
async def remove_components_by_plugin(self, plugin_name: str) -> int:
    """Remove every component registered under ``plugin_name``.

    Args:
        plugin_name: Name of the plugin whose components should be removed.

    Returns:
        The number of components for which ``remove_component`` reported
        success (a truthy result).
    """
    # Collect the (name, type) pairs up front so that no removal happens
    # while self._components is still being iterated.
    owned = []
    for info in self._components.values():
        if info.plugin_name == plugin_name:
            owned.append((info.name, info.component_type))

    removed_count = 0
    for name, kind in owned:
        succeeded = await self.remove_component(name, kind, plugin_name)
        if succeeded:
            removed_count += 1

    # Only log when something was actually removed.
    if removed_count:
        logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}")
    return removed_count
def remove_plugin_registry(self, plugin_name: str) -> bool: def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息 """移除插件注册信息
@ -734,9 +751,7 @@ class ComponentRegistry:
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
"workflow_steps": workflow_step_count, "workflow_steps": workflow_step_count,
"enabled_workflow_steps": enabled_workflow_step_count, "enabled_workflow_steps": enabled_workflow_step_count,
"workflow_steps_by_stage": { "workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
},
} }

View File

@ -200,13 +200,43 @@ class PluginManager:
""" """
重载插件模块 重载插件模块
""" """
old_instance = self.loaded_plugins.get(plugin_name)
if not old_instance:
logger.warning(f"插件 {plugin_name} 未加载,无法重载")
return False
if not await self.remove_registered_plugin(plugin_name): if not await self.remove_registered_plugin(plugin_name):
return False return False
if not self.load_registered_plugin_classes(plugin_name)[0]: if not self.load_registered_plugin_classes(plugin_name)[0]:
logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例")
rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance)
if rollback_ok:
logger.info(f"插件 {plugin_name} 已回滚到旧版本实例")
else:
logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用")
return False return False
logger.debug(f"插件 {plugin_name} 重载成功") logger.debug(f"插件 {plugin_name} 重载成功")
return True return True
async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool:
    """Put the previously loaded plugin instance back after a failed reload.

    Everything currently registered under ``plugin_name`` (components, the
    plugin registry entry and services) is removed first, then the old
    instance is asked to register itself again.

    Args:
        plugin_name: Name of the plugin being rolled back.
        old_instance: The instance that was loaded before the reload attempt.

    Returns:
        True when the old instance re-registered and was stored back in
        ``self.loaded_plugins``; False when re-registration reported failure
        or any step raised.
    """
    try:
        # Clear whatever is registered under this name before re-registering.
        await component_registry.remove_components_by_plugin(plugin_name)
        component_registry.remove_plugin_registry(plugin_name)
        plugin_service_registry.remove_services_by_plugin(plugin_name)

        registered = old_instance.register_plugin()
        if registered:
            self.loaded_plugins[plugin_name] = old_instance
            return True

        logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败")
        return False
    except Exception as rollback_error:
        # Rollback is best-effort: log with traceback and report failure.
        logger.error(f"插件 {plugin_name} 回滚异常: {rollback_error}", exc_info=True)
        return False
def rescan_plugin_directory(self) -> Tuple[int, int]: def rescan_plugin_directory(self) -> Tuple[int, int]:
""" """
重新扫描插件根目录 重新扫描插件根目录
@ -399,7 +429,9 @@ class PluginManager:
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]: def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
"""根据依赖图计算加载顺序,并检测循环依赖。""" """根据依赖图计算加载顺序,并检测循环依赖。"""
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()} indegree: Dict[str, int] = {
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph} reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
for plugin_name, dependencies in dependency_graph.items(): for plugin_name, dependencies in dependency_graph.items():

View File

@ -26,6 +26,9 @@ class PluginServiceRegistry:
if "." in service_info.plugin_name: if "." in service_info.plugin_name:
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False return False
if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]:
logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}")
return False
full_name = service_info.full_name full_name = service_info.full_name
if full_name in self._services: if full_name in self._services:
@ -52,7 +55,9 @@ class PluginServiceRegistry:
full_name = self._resolve_full_name(service_name, plugin_name) full_name = self._resolve_full_name(service_name, plugin_name)
return self._service_handlers.get(full_name) if full_name else None return self._service_handlers.get(full_name) if full_name else None
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]: def list_services(
self, plugin_name: Optional[str] = None, enabled_only: bool = False
) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。""" """列出插件服务。"""
services = self._services.copy() services = self._services.copy()
if plugin_name: if plugin_name:
@ -103,12 +108,33 @@ class PluginServiceRegistry:
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}") logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
return removed_count return removed_count
async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any: async def call_service(
self,
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务(支持同步/异步handler""" """调用插件服务(支持同步/异步handler"""
service_info = self.get_service(service_name, plugin_name) service_info = self.get_service(service_name, plugin_name)
if not service_info: if not service_info:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}") raise ValueError(f"插件服务未注册: {target_name}")
if (
"." not in service_name
and plugin_name is None
and caller_plugin
and service_info.plugin_name != caller_plugin
):
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
if not self._is_call_authorized(service_info, caller_plugin):
raise PermissionError(
f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}"
)
if not service_info.enabled: if not service_info.enabled:
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}") raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
@ -116,8 +142,93 @@ class PluginServiceRegistry:
if not handler: if not handler:
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}") raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
self._validate_input_contract(service_info, args, kwargs)
result = handler(*args, **kwargs) result = handler(*args, **kwargs)
return await result if inspect.isawaitable(result) else result resolved_result = await result if inspect.isawaitable(result) else result
self._validate_output_contract(service_info, resolved_result)
return resolved_result
def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool:
    """Decide whether ``caller_plugin`` may invoke the given service.

    Rules, in order:
      * a call without a caller identity is allowed only for public services;
      * the owning plugin may always call its own service;
      * public services are open to every caller;
      * otherwise the caller must be listed in ``allowed_callers``
        (entries are stripped, blank entries ignored, ``"*"`` matches anyone).

    Args:
        service_info: Metadata of the service being called.
        caller_plugin: Name of the calling plugin, or None when unknown.

    Returns:
        Whether the call is permitted.
    """
    if caller_plugin is None:
        return service_info.public

    if caller_plugin == service_info.plugin_name or service_info.public:
        return True

    whitelist = set()
    for entry in service_info.allowed_callers:
        cleaned = entry.strip()
        if cleaned:
            whitelist.add(cleaned)
    return bool(whitelist & {"*", caller_plugin})
def _validate_input_contract(
    self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
    """Check the call arguments against the service's ``params_schema``.

    Two schema shapes are recognised:
      * an "invocation" schema, whose ``properties`` mention ``args`` and/or
        ``kwargs``: validated against ``{"args": [...], "kwargs": {...}}``;
      * any other schema: validated against ``kwargs`` only, and positional
        arguments are rejected.

    An empty/missing schema disables validation.

    Args:
        service_info: Metadata of the service being called.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Raises:
        ValueError: Positional arguments were passed to a service whose
            schema only describes keyword arguments.
        Exception: Whatever ``_validate_by_schema`` raises on a mismatch.
    """
    schema = service_info.params_schema
    if not schema:
        return

    if isinstance(schema, dict):
        declared = schema.get("properties", {})
    else:
        declared = {}

    if "args" in declared or "kwargs" in declared:
        # Invocation-style schema: validate the whole call as one payload.
        target = {"args": list(args), "kwargs": kwargs}
    elif args:
        raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数")
    else:
        target = kwargs

    self._validate_by_schema(target, schema, path="params")
def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None:
    """Check a service's return value against its ``return_schema``.

    Does nothing when the service declares no (or an empty) return schema.

    Args:
        service_info: Metadata of the service that produced ``value``.
        value: The resolved return value of the service handler.

    Raises:
        Exception: Whatever ``_validate_by_schema`` raises on a mismatch.
    """
    schema = service_info.return_schema
    if schema:
        self._validate_by_schema(value, schema, path="return")
def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None:
    """Validate ``value`` against a simplified JSON-Schema fragment.

    Supported keywords: ``type``, ``enum``, and, depending on ``type``,
    ``properties`` / ``required`` / ``additionalProperties`` for objects and
    ``items`` for arrays. Nested schemas are validated recursively.

    Args:
        value: The data to validate.
        schema: The schema fragment that applies to ``value``.
        path: Dotted/indexed location of ``value``, used in error messages.

    Raises:
        ValueError: ``value`` is outside ``enum``, a required field is
            missing, or an undeclared field appears while
            ``additionalProperties`` is False.
        Exception: Whatever ``_validate_type`` raises on a type mismatch.
    """
    expected_type = schema.get("type")
    if expected_type:
        self._validate_type(value, expected_type, path)

    allowed_values = schema.get("enum")
    if allowed_values is not None and value not in allowed_values:
        raise ValueError(f"{path} 不在枚举范围内: {value}")

    if expected_type == "object":
        declared = schema.get("properties", {})

        # Required keys are checked before any nested value is inspected.
        for key in schema.get("required", []):
            if key not in value:
                raise ValueError(f"{path}.{key} 为必填字段")

        # Undeclared keys are tolerated unless additionalProperties is
        # explicitly False.
        extras_forbidden = schema.get("additionalProperties", True) is False
        for key, item in value.items():
            if key in declared:
                self._validate_by_schema(item, declared[key], f"{path}.{key}")
            elif extras_forbidden:
                raise ValueError(f"{path}.{key} 不允许额外字段")
    elif expected_type == "array":
        element_schema = schema.get("items")
        if element_schema:
            for position, element in enumerate(value):
                self._validate_by_schema(element, element_schema, f"{path}[{position}]")
def _validate_type(self, value: Any, expected_type: str, path: str) -> None:
    """Check that ``value`` matches a JSON-Schema primitive type name.

    Recognised names: ``string``, ``number``, ``integer``, ``boolean``,
    ``object``, ``array`` and ``null``. ``bool`` values are deliberately not
    accepted as ``number``/``integer`` even though ``bool`` subclasses
    ``int`` in Python. Unrecognised type names are not enforced.

    ``expected_type`` may also be a list/tuple of type names (the JSON-Schema
    union form, e.g. ``["string", "null"]``); the value is accepted when it
    matches any member. Previously such a list was used directly as a dict
    key and failed with an unrelated "unhashable type" error.

    Args:
        value: The data to check.
        expected_type: A type name, or a list/tuple of type names.
        path: Location of ``value``, used in the error message.

    Raises:
        TypeError: ``value`` matches none of the expected types.
    """
    type_checkers: Dict[str, Callable[[Any], bool]] = {
        "string": lambda item: isinstance(item, str),
        "number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool),
        "integer": lambda item: isinstance(item, int) and not isinstance(item, bool),
        "boolean": lambda item: isinstance(item, bool),
        "object": lambda item: isinstance(item, dict),
        "array": lambda item: isinstance(item, list),
        "null": lambda item: item is None,
    }

    is_union = isinstance(expected_type, (list, tuple))
    candidates = tuple(expected_type) if is_union else (expected_type,)
    if not candidates:
        # An empty union constrains nothing.
        return

    for type_name in candidates:
        checker = type_checkers.get(type_name)
        # Unknown names stay lenient, as in the single-name case.
        if checker is None or checker(value):
            return

    expected_label = " | ".join(str(name) for name in candidates) if is_union else expected_type
    raise TypeError(f"{path} 类型不匹配,期望 {expected_label},实际 {type(value).__name__}")
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]: def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
"""解析服务全名。""" """解析服务全名。"""

View File

@ -1,7 +1,8 @@
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import asyncio
import inspect
import time import time
import uuid import uuid
import inspect
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, MaiMessages from src.plugin_system.base.component_types import EventType, MaiMessages
@ -95,7 +96,9 @@ class WorkflowEngine:
except Exception as e: except Exception as e:
workflow_context.timings[stage_key] = time.perf_counter() - stage_start workflow_context.timings[stage_key] = time.perf_counter() - stage_start
workflow_context.errors.append(f"{stage_key}: {e}") workflow_context.errors.append(f"{stage_key}: {e}")
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True) logger.error(
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
)
self._execution_history[workflow_context.trace_id]["status"] = "failed" self._execution_history[workflow_context.trace_id]["status"] = "failed"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return ( return (
@ -144,11 +147,19 @@ class WorkflowEngine:
step_timing_key = f"{stage.value}:{step_info.full_name}" step_timing_key = f"{stage.value}:{step_info.full_name}"
step_start = time.perf_counter() step_start = time.perf_counter()
timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None
try: try:
result = handler(context, message) if inspect.iscoroutinefunction(handler):
if inspect.isawaitable(result): coroutine = handler(context, message)
result = await result result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine
else:
if timeout_seconds:
result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds)
else:
result = handler(context, message)
if inspect.isawaitable(result):
result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result
context.timings[step_timing_key] = time.perf_counter() - step_start context.timings[step_timing_key] = time.perf_counter() - step_start
normalized_result = self._normalize_step_result(result) normalized_result = self._normalize_step_result(result)
@ -165,10 +176,30 @@ class WorkflowEngine:
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value) normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
return normalized_result return normalized_result
except asyncio.TimeoutError:
context.timings[step_timing_key] = time.perf_counter() - step_start
timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms"
context.errors.append(f"{step_info.full_name}: {timeout_message}")
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}"
)
return WorkflowStepResult(
status="failed",
return_message=timeout_message,
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.STEP_TIMEOUT.value,
},
)
except Exception as e: except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}") context.errors.append(f"{step_info.full_name}: {e}")
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True) logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
)
return WorkflowStepResult( return WorkflowStepResult(
status="failed", status="failed",
return_message=str(e), return_message=str(e),

View File

@ -117,7 +117,7 @@ class PromptManager:
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None: def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
""" """
添加一个上下文构造函数 添加一个上下文构造函数
Args: Args:
name (str): 上下文名称 name (str): 上下文名称
func (Callable[[str], str | Coroutine[Any, Any, str]]): 构造函数接受 Prompt 名称作为参数返回字符串或返回字符串的协程 func (Callable[[str], str | Coroutine[Any, Any, str]]): 构造函数接受 Prompt 名称作为参数返回字符串或返回字符串的协程
@ -144,7 +144,7 @@ class PromptManager:
def get_prompt(self, prompt_name: str) -> Prompt: def get_prompt(self, prompt_name: str) -> Prompt:
""" """
获取指定名称的 Prompt 实例的克隆 获取指定名称的 Prompt 实例的克隆
Args: Args:
prompt_name (str): 要获取的 Prompt 名称 prompt_name (str): 要获取的 Prompt 名称
Returns: Returns:
@ -161,7 +161,7 @@ class PromptManager:
async def render_prompt(self, prompt: Prompt) -> str: async def render_prompt(self, prompt: Prompt) -> str:
""" """
渲染一个 Prompt 实例 渲染一个 Prompt 实例
Args: Args:
prompt (Prompt): 要渲染的 Prompt 实例 prompt (Prompt): 要渲染的 Prompt 实例
Returns: Returns:

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容 2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载 3. 详情按需加载
""" """
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional from typing import List, Dict, Optional
@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
class ChatSummary(BaseModel): class ChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容""" """聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str chat_id: str
plan_count: int plan_count: int
latest_timestamp: float latest_timestamp: float
@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
class PlanLogSummary(BaseModel): class PlanLogSummary(BaseModel):
"""规划日志摘要""" """规划日志摘要"""
chat_id: str chat_id: str
timestamp: float timestamp: float
filename: str filename: str
@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
class PlanLogDetail(BaseModel): class PlanLogDetail(BaseModel):
"""规划日志详情""" """规划日志详情"""
type: str type: str
chat_id: str chat_id: str
timestamp: float timestamp: float
@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
class PlannerOverview(BaseModel): class PlannerOverview(BaseModel):
"""规划器总览 - 轻量级统计""" """规划器总览 - 轻量级统计"""
total_chats: int total_chats: int
total_plans: int total_plans: int
chats: List[ChatSummary] chats: List[ChatSummary]
@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
class PaginatedChatLogs(BaseModel): class PaginatedChatLogs(BaseModel):
"""分页的聊天日志列表""" """分页的聊天日志列表"""
data: List[PlanLogSummary] data: List[PlanLogSummary]
total: int total: int
page: int page: int
@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float: def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220""" """从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try: try:
timestamp_str = filename.split('_')[0] timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒 # 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000 return float(timestamp_str) / 1000
except (ValueError, IndexError): except (ValueError, IndexError):
@ -86,41 +92,39 @@ async def get_planner_overview():
""" """
if not PLAN_LOG_DIR.exists(): if not PLAN_LOG_DIR.exists():
return PlannerOverview(total_chats=0, total_plans=0, chats=[]) return PlannerOverview(total_chats=0, total_plans=0, chats=[])
chats = [] chats = []
total_plans = 0 total_plans = 0
for chat_dir in PLAN_LOG_DIR.iterdir(): for chat_dir in PLAN_LOG_DIR.iterdir():
if not chat_dir.is_dir(): if not chat_dir.is_dir():
continue continue
# 只统计json文件数量 # 只统计json文件数量
json_files = list(chat_dir.glob("*.json")) json_files = list(chat_dir.glob("*.json"))
plan_count = len(json_files) plan_count = len(json_files)
total_plans += plan_count total_plans += plan_count
if plan_count == 0: if plan_count == 0:
continue continue
# 从文件名获取最新时间戳 # 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name)) latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name) latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ChatSummary( chats.append(
chat_id=chat_dir.name, ChatSummary(
plan_count=plan_count, chat_id=chat_dir.name,
latest_timestamp=latest_timestamp, plan_count=plan_count,
latest_filename=latest_file.name latest_timestamp=latest_timestamp,
)) latest_filename=latest_file.name,
)
)
# 按最新时间戳排序 # 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True) chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return PlannerOverview( return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
total_chats=len(chats),
total_plans=total_plans,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs) @router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
@ -128,7 +132,7 @@ async def get_chat_plan_logs(
chat_id: str, chat_id: str,
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容") search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
): ):
""" """
获取指定聊天的规划日志列表分页 获取指定聊天的规划日志列表分页
@ -137,73 +141,69 @@ async def get_chat_plan_logs(
""" """
chat_dir = PLAN_LOG_DIR / chat_id chat_dir = PLAN_LOG_DIR / chat_id
if not chat_dir.exists(): if not chat_dir.exists():
return PaginatedChatLogs( return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序 # 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json")) json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True) json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件 # 如果有搜索关键词,需要过滤文件
if search: if search:
search_lower = search.lower() search_lower = search.lower()
filtered_files = [] filtered_files = []
for log_file in json_files: for log_file in json_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
prompt = data.get('prompt', '') prompt = data.get("prompt", "")
if search_lower in prompt.lower(): if search_lower in prompt.lower():
filtered_files.append(log_file) filtered_files.append(log_file)
except Exception: except Exception:
continue continue
json_files = filtered_files json_files = filtered_files
total = len(json_files) total = len(json_files)
# 分页 - 只读取当前页的文件 # 分页 - 只读取当前页的文件
offset = (page - 1) * page_size offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size] page_files = json_files[offset : offset + page_size]
logs = [] logs = []
for log_file in page_files: for log_file in page_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
reasoning = data.get('reasoning', '') reasoning = data.get("reasoning", "")
actions = data.get('actions', []) actions = data.get("actions", [])
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')] action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
logs.append(PlanLogSummary( logs.append(
chat_id=data.get('chat_id', chat_id), PlanLogSummary(
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)), chat_id=data.get("chat_id", chat_id),
filename=log_file.name, timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
action_count=len(actions), filename=log_file.name,
action_types=action_types, action_count=len(actions),
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0), action_types=action_types,
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0), total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
reasoning_preview=reasoning[:100] if reasoning else '' llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
)) reasoning_preview=reasoning[:100] if reasoning else "",
)
)
except Exception: except Exception:
# 文件读取失败时使用文件名信息 # 文件读取失败时使用文件名信息
logs.append(PlanLogSummary( logs.append(
chat_id=chat_id, PlanLogSummary(
timestamp=parse_timestamp_from_filename(log_file.name), chat_id=chat_id,
filename=log_file.name, timestamp=parse_timestamp_from_filename(log_file.name),
action_count=0, filename=log_file.name,
action_types=[], action_count=0,
total_plan_ms=0, action_types=[],
llm_duration_ms=0, total_plan_ms=0,
reasoning_preview='[读取失败]' llm_duration_ms=0,
)) reasoning_preview="[读取失败]",
)
return PaginatedChatLogs( )
data=logs,
total=total, return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail) @router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
@ -212,22 +212,23 @@ async def get_log_detail(chat_id: str, filename: str):
log_file = PLAN_LOG_DIR / chat_id / filename log_file = PLAN_LOG_DIR / chat_id / filename
if not log_file.exists(): if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在") raise HTTPException(status_code=404, detail="日志文件不存在")
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
return PlanLogDetail(**data) return PlanLogDetail(**data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e
# ========== 兼容旧接口 ========== # ========== 兼容旧接口 ==========
@router.get("/stats") @router.get("/stats")
async def get_planner_stats(): async def get_planner_stats():
"""获取规划器统计信息 - 兼容旧接口""" """获取规划器统计信息 - 兼容旧接口"""
overview = await get_planner_overview() overview = await get_planner_overview()
# 获取最近10条计划的摘要 # 获取最近10条计划的摘要
recent_plans = [] recent_plans = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取 for chat in overview.chats[:5]: # 从最近5个聊天中获取
@ -236,17 +237,17 @@ async def get_planner_stats():
recent_plans.extend(chat_logs.data) recent_plans.extend(chat_logs.data)
except Exception: except Exception:
continue continue
# 按时间排序取前10 # 按时间排序取前10
recent_plans.sort(key=lambda x: x.timestamp, reverse=True) recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
recent_plans = recent_plans[:10] recent_plans = recent_plans[:10]
return { return {
"total_chats": overview.total_chats, "total_chats": overview.total_chats,
"total_plans": overview.total_plans, "total_plans": overview.total_plans,
"avg_plan_time_ms": 0, "avg_plan_time_ms": 0,
"avg_llm_time_ms": 0, "avg_llm_time_ms": 0,
"recent_plans": recent_plans "recent_plans": recent_plans,
} }
@ -258,44 +259,43 @@ async def get_chat_list():
@router.get("/all-logs") @router.get("/all-logs")
async def get_all_logs( async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100)
):
"""获取所有规划日志 - 兼容旧接口""" """获取所有规划日志 - 兼容旧接口"""
if not PLAN_LOG_DIR.exists(): if not PLAN_LOG_DIR.exists():
return {"data": [], "total": 0, "page": page, "page_size": page_size} return {"data": [], "total": 0, "page": page, "page_size": page_size}
# 收集所有文件 # 收集所有文件
all_files = [] all_files = []
for chat_dir in PLAN_LOG_DIR.iterdir(): for chat_dir in PLAN_LOG_DIR.iterdir():
if chat_dir.is_dir(): if chat_dir.is_dir():
for log_file in chat_dir.glob("*.json"): for log_file in chat_dir.glob("*.json"):
all_files.append((chat_dir.name, log_file)) all_files.append((chat_dir.name, log_file))
# 按时间戳排序 # 按时间戳排序
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True) all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
total = len(all_files) total = len(all_files)
offset = (page - 1) * page_size offset = (page - 1) * page_size
page_files = all_files[offset:offset + page_size] page_files = all_files[offset : offset + page_size]
logs = [] logs = []
for chat_id, log_file in page_files: for chat_id, log_file in page_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
reasoning = data.get('reasoning', '') reasoning = data.get("reasoning", "")
logs.append({ logs.append(
"chat_id": data.get('chat_id', chat_id), {
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)), "chat_id": data.get("chat_id", chat_id),
"filename": log_file.name, "timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
"action_count": len(data.get('actions', [])), "filename": log_file.name,
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0), "action_count": len(data.get("actions", [])),
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0), "total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
"reasoning_preview": reasoning[:100] if reasoning else '' "llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
}) "reasoning_preview": reasoning[:100] if reasoning else "",
}
)
except Exception: except Exception:
continue continue
return {"data": logs, "total": total, "page": page, "page_size": page_size} return {"data": logs, "total": total, "page": page, "page_size": page_size}

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容 2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载 3. 详情按需加载
""" """
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional from typing import List, Dict, Optional
@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
class ReplierChatSummary(BaseModel): class ReplierChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容""" """聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str chat_id: str
reply_count: int reply_count: int
latest_timestamp: float latest_timestamp: float
@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
class ReplyLogSummary(BaseModel): class ReplyLogSummary(BaseModel):
"""回复日志摘要""" """回复日志摘要"""
chat_id: str chat_id: str
timestamp: float timestamp: float
filename: str filename: str
@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
class ReplyLogDetail(BaseModel): class ReplyLogDetail(BaseModel):
"""回复日志详情""" """回复日志详情"""
type: str type: str
chat_id: str chat_id: str
timestamp: float timestamp: float
@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
class ReplierOverview(BaseModel): class ReplierOverview(BaseModel):
"""回复器总览 - 轻量级统计""" """回复器总览 - 轻量级统计"""
total_chats: int total_chats: int
total_replies: int total_replies: int
chats: List[ReplierChatSummary] chats: List[ReplierChatSummary]
@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
class PaginatedReplyLogs(BaseModel): class PaginatedReplyLogs(BaseModel):
"""分页的回复日志列表""" """分页的回复日志列表"""
data: List[ReplyLogSummary] data: List[ReplyLogSummary]
total: int total: int
page: int page: int
@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float: def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220""" """从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try: try:
timestamp_str = filename.split('_')[0] timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒 # 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000 return float(timestamp_str) / 1000
except (ValueError, IndexError): except (ValueError, IndexError):
@ -89,41 +95,39 @@ async def get_replier_overview():
""" """
if not REPLY_LOG_DIR.exists(): if not REPLY_LOG_DIR.exists():
return ReplierOverview(total_chats=0, total_replies=0, chats=[]) return ReplierOverview(total_chats=0, total_replies=0, chats=[])
chats = [] chats = []
total_replies = 0 total_replies = 0
for chat_dir in REPLY_LOG_DIR.iterdir(): for chat_dir in REPLY_LOG_DIR.iterdir():
if not chat_dir.is_dir(): if not chat_dir.is_dir():
continue continue
# 只统计json文件数量 # 只统计json文件数量
json_files = list(chat_dir.glob("*.json")) json_files = list(chat_dir.glob("*.json"))
reply_count = len(json_files) reply_count = len(json_files)
total_replies += reply_count total_replies += reply_count
if reply_count == 0: if reply_count == 0:
continue continue
# 从文件名获取最新时间戳 # 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name)) latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name) latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ReplierChatSummary( chats.append(
chat_id=chat_dir.name, ReplierChatSummary(
reply_count=reply_count, chat_id=chat_dir.name,
latest_timestamp=latest_timestamp, reply_count=reply_count,
latest_filename=latest_file.name latest_timestamp=latest_timestamp,
)) latest_filename=latest_file.name,
)
)
# 按最新时间戳排序 # 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True) chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return ReplierOverview( return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
total_chats=len(chats),
total_replies=total_replies,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs) @router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
@ -131,7 +135,7 @@ async def get_chat_reply_logs(
chat_id: str, chat_id: str,
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容") search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
): ):
""" """
获取指定聊天的回复日志列表分页 获取指定聊天的回复日志列表分页
@ -140,71 +144,67 @@ async def get_chat_reply_logs(
""" """
chat_dir = REPLY_LOG_DIR / chat_id chat_dir = REPLY_LOG_DIR / chat_id
if not chat_dir.exists(): if not chat_dir.exists():
return PaginatedReplyLogs( return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序 # 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json")) json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True) json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件 # 如果有搜索关键词,需要过滤文件
if search: if search:
search_lower = search.lower() search_lower = search.lower()
filtered_files = [] filtered_files = []
for log_file in json_files: for log_file in json_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
prompt = data.get('prompt', '') prompt = data.get("prompt", "")
if search_lower in prompt.lower(): if search_lower in prompt.lower():
filtered_files.append(log_file) filtered_files.append(log_file)
except Exception: except Exception:
continue continue
json_files = filtered_files json_files = filtered_files
total = len(json_files) total = len(json_files)
# 分页 - 只读取当前页的文件 # 分页 - 只读取当前页的文件
offset = (page - 1) * page_size offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size] page_files = json_files[offset : offset + page_size]
logs = [] logs = []
for log_file in page_files: for log_file in page_files:
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
output = data.get('output', '') output = data.get("output", "")
logs.append(ReplyLogSummary( logs.append(
chat_id=data.get('chat_id', chat_id), ReplyLogSummary(
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)), chat_id=data.get("chat_id", chat_id),
filename=log_file.name, timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
model=data.get('model', ''), filename=log_file.name,
success=data.get('success', True), model=data.get("model", ""),
llm_ms=data.get('timing', {}).get('llm_ms', 0), success=data.get("success", True),
overall_ms=data.get('timing', {}).get('overall_ms', 0), llm_ms=data.get("timing", {}).get("llm_ms", 0),
output_preview=output[:100] if output else '' overall_ms=data.get("timing", {}).get("overall_ms", 0),
)) output_preview=output[:100] if output else "",
)
)
except Exception: except Exception:
# 文件读取失败时使用文件名信息 # 文件读取失败时使用文件名信息
logs.append(ReplyLogSummary( logs.append(
chat_id=chat_id, ReplyLogSummary(
timestamp=parse_timestamp_from_filename(log_file.name), chat_id=chat_id,
filename=log_file.name, timestamp=parse_timestamp_from_filename(log_file.name),
model='', filename=log_file.name,
success=False, model="",
llm_ms=0, success=False,
overall_ms=0, llm_ms=0,
output_preview='[读取失败]' overall_ms=0,
)) output_preview="[读取失败]",
)
return PaginatedReplyLogs( )
data=logs,
total=total, return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail) @router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
@ -213,35 +213,36 @@ async def get_reply_log_detail(chat_id: str, filename: str):
log_file = REPLY_LOG_DIR / chat_id / filename log_file = REPLY_LOG_DIR / chat_id / filename
if not log_file.exists(): if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在") raise HTTPException(status_code=404, detail="日志文件不存在")
try: try:
with open(log_file, 'r', encoding='utf-8') as f: with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
return ReplyLogDetail( return ReplyLogDetail(
type=data.get('type', 'reply'), type=data.get("type", "reply"),
chat_id=data.get('chat_id', chat_id), chat_id=data.get("chat_id", chat_id),
timestamp=data.get('timestamp', 0), timestamp=data.get("timestamp", 0),
prompt=data.get('prompt', ''), prompt=data.get("prompt", ""),
output=data.get('output', ''), output=data.get("output", ""),
processed_output=data.get('processed_output', []), processed_output=data.get("processed_output", []),
model=data.get('model', ''), model=data.get("model", ""),
reasoning=data.get('reasoning', ''), reasoning=data.get("reasoning", ""),
think_level=data.get('think_level', 0), think_level=data.get("think_level", 0),
timing=data.get('timing', {}), timing=data.get("timing", {}),
error=data.get('error'), error=data.get("error"),
success=data.get('success', True) success=data.get("success", True),
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e
# ========== 兼容接口 ========== # ========== 兼容接口 ==========
@router.get("/stats") @router.get("/stats")
async def get_replier_stats(): async def get_replier_stats():
"""获取回复器统计信息""" """获取回复器统计信息"""
overview = await get_replier_overview() overview = await get_replier_overview()
# 获取最近10条回复的摘要 # 获取最近10条回复的摘要
recent_replies = [] recent_replies = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取 for chat in overview.chats[:5]: # 从最近5个聊天中获取
@ -250,15 +251,15 @@ async def get_replier_stats():
recent_replies.extend(chat_logs.data) recent_replies.extend(chat_logs.data)
except Exception: except Exception:
continue continue
# 按时间排序取前10 # 按时间排序取前10
recent_replies.sort(key=lambda x: x.timestamp, reverse=True) recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
recent_replies = recent_replies[:10] recent_replies = recent_replies[:10]
return { return {
"total_chats": overview.total_chats, "total_chats": overview.total_chats,
"total_replies": overview.total_replies, "total_replies": overview.total_replies,
"recent_replies": recent_replies "recent_replies": recent_replies,
} }
@ -266,4 +267,4 @@ async def get_replier_stats():
async def get_replier_chat_list(): async def get_replier_chat_list():
"""获取所有聊天ID列表""" """获取所有聊天ID列表"""
overview = await get_replier_overview() overview = await get_replier_overview()
return [chat.chat_id for chat in overview.chats] return [chat.chat_id for chat in overview.chats]

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from fastapi import Depends, Cookie, Header, Request, HTTPException from fastapi import Depends, Cookie, Header, Request
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit from .core import get_current_token, get_token_manager, check_auth_rate_limit
async def require_auth( async def require_auth(

View File

@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
# loose: 宽松模式(较宽松的检测,较高的频率限制) # loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP # basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
# IP白名单配置从配置文件读取逗号分隔 # IP白名单配置从配置文件读取逗号分隔
# 支持格式: # 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100 # - 精确IP127.0.0.1, 192.168.1.100
@ -151,7 +152,7 @@ def _parse_allowed_ips(ip_string: str) -> list:
ip_entry = ip_entry.strip() # 去除空格 ip_entry = ip_entry.strip() # 去除空格
if not ip_entry: if not ip_entry:
continue continue
# 跳过注释行(以#开头) # 跳过注释行(以#开头)
if ip_entry.startswith("#"): if ip_entry.startswith("#"):
continue continue
@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
def _get_anti_crawler_config(): def _get_anti_crawler_config():
"""获取防爬虫配置""" """获取防爬虫配置"""
from src.config.config import global_config from src.config.config import global_config
return { return {
'mode': global_config.webui.anti_crawler_mode, "mode": global_config.webui.anti_crawler_mode,
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips), "allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies), "trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
'trust_xff': global_config.webui.trust_xff "trust_xff": global_config.webui.trust_xff,
} }
# 初始化配置(将在模块加载时执行) # 初始化配置(将在模块加载时执行)
_config = _get_anti_crawler_config() _config = _get_anti_crawler_config()
ANTI_CRAWLER_MODE = _config['mode'] ANTI_CRAWLER_MODE = _config["mode"]
ALLOWED_IPS = _config['allowed_ips'] ALLOWED_IPS = _config["allowed_ips"]
TRUSTED_PROXIES = _config['trusted_proxies'] TRUSTED_PROXIES = _config["trusted_proxies"]
TRUST_XFF = _config['trust_xff'] TRUST_XFF = _config["trust_xff"]
def _get_mode_config(mode: str) -> dict: def _get_mode_config(mode: str) -> dict:

View File

@ -333,7 +333,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
statement = select(func.count()).where( statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_at) == True, col(Messages.is_at),
) )
data.at_count = int(session.exec(statement).first() or 0) data.at_count = int(session.exec(statement).first() or 0)
@ -342,7 +342,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
statement = select(func.count()).where( statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_mentioned) == True, col(Messages.is_mentioned),
) )
data.mentioned_count = int(session.exec(statement).first() or 0) data.mentioned_count = int(session.exec(statement).first() or 0)
@ -552,7 +552,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
# 1. 表情包之王 - 使用次数最多的表情包 # 1. 表情包之王 - 使用次数最多的表情包
with get_db_session() as session: with get_db_session() as session:
statement = ( statement = (
select(Images).where(col(Images.is_registered) == True).order_by(desc(col(Images.query_count))).limit(5) select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
) )
top_emojis = session.exec(statement).all() top_emojis = session.exec(statement).all()
if top_emojis: if top_emojis:
@ -636,7 +636,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
statement = select(func.count()).where( statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_picture) == True, col(Messages.is_picture),
) )
data.image_processed_count = int(session.exec(statement).first() or 0) data.image_processed_count = int(session.exec(statement).first() or 0)
@ -781,12 +781,12 @@ async def get_achievements(year: int = 2025) -> AchievementData:
# 1. 新学到的黑话数量 # 1. 新学到的黑话数量
# Jargon 表没有时间字段,统计全部已确认的黑话 # Jargon 表没有时间字段,统计全部已确认的黑话
with get_db_session() as session: with get_db_session() as session:
statement = select(func.count()).where(col(Jargon.is_jargon) == True) statement = select(func.count()).where(col(Jargon.is_jargon))
data.new_jargon_count = int(session.exec(statement).first() or 0) data.new_jargon_count = int(session.exec(statement).first() or 0)
# 2. 代表性黑话示例 # 2. 代表性黑话示例
with get_db_session() as session: with get_db_session() as session:
statement = select(Jargon).where(col(Jargon.is_jargon) == True).order_by(desc(col(Jargon.count))).limit(5) statement = select(Jargon).where(col(Jargon.is_jargon)).order_by(desc(col(Jargon.count))).limit(5)
jargon_samples = session.exec(statement).all() jargon_samples = session.exec(statement).all()
data.sample_jargons = [ data.sample_jargons = [
{ {

View File

@ -532,7 +532,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
.select_from(Images) .select_from(Images)
.where( .where(
col(Images.image_type) == ImageType.EMOJI, col(Images.image_type) == ImageType.EMOJI,
col(Images.is_registered) == True, col(Images.is_registered),
) )
) )
banned_statement = ( banned_statement = (
@ -540,7 +540,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
.select_from(Images) .select_from(Images)
.where( .where(
col(Images.image_type) == ImageType.EMOJI, col(Images.image_type) == ImageType.EMOJI,
col(Images.is_banned) == True, col(Images.is_banned),
) )
) )
@ -1283,7 +1283,7 @@ async def preheat_thumbnail_cache(
select(Images) select(Images)
.where( .where(
col(Images.image_type) == ImageType.EMOJI, col(Images.image_type) == ImageType.EMOJI,
col(Images.is_banned) == False, col(Images.is_banned).is_(False),
) )
.order_by(col(Images.query_count).desc()) .order_by(col(Images.query_count).desc())
.limit(limit * 2) .limit(limit * 2)

View File

@ -315,15 +315,15 @@ async def get_jargon_stats():
total = session.exec(select(fn.count()).select_from(Jargon)).one() total = session.exec(select(fn.count()).select_from(Jargon)).one()
confirmed_jargon = session.exec( confirmed_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True) select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))
).one() ).one()
confirmed_not_jargon = session.exec( confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False) select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
).one() ).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one() pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
complete_count = session.exec( complete_count = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True) select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))
).one() ).one()
chat_count = session.exec( chat_count = session.exec(

View File

@ -17,36 +17,36 @@ _paragraph_store_cache = None
def _get_paragraph_store(): def _get_paragraph_store():
"""延迟加载段落 embedding store只读模式轻量级 """延迟加载段落 embedding store只读模式轻量级
Returns: Returns:
EmbeddingStore | None: 如果配置启用则返回store否则返回None EmbeddingStore | None: 如果配置启用则返回store否则返回None
""" """
# 检查配置是否启用 # 检查配置是否启用
if not global_config.webui.enable_paragraph_content: if not global_config.webui.enable_paragraph_content:
return None return None
global _paragraph_store_cache global _paragraph_store_cache
if _paragraph_store_cache is not None: if _paragraph_store_cache is not None:
return _paragraph_store_cache return _paragraph_store_cache
try: try:
from src.chat.knowledge.embedding_store import EmbeddingStore from src.chat.knowledge.embedding_store import EmbeddingStore
import os import os
# 获取数据路径 # 获取数据路径
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", "..")) root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
embedding_dir = os.path.join(root_path, "data/embedding") embedding_dir = os.path.join(root_path, "data/embedding")
# 只加载段落 embedding store轻量级 # 只加载段落 embedding store轻量级
paragraph_store = EmbeddingStore( paragraph_store = EmbeddingStore(
namespace="paragraph", namespace="paragraph",
dir_path=embedding_dir, dir_path=embedding_dir,
max_workers=1, # 只读不需要多线程 max_workers=1, # 只读不需要多线程
chunk_size=100 chunk_size=100,
) )
paragraph_store.load_from_file() paragraph_store.load_from_file()
_paragraph_store_cache = paragraph_store _paragraph_store_cache = paragraph_store
logger.info(f"成功加载段落 embedding store包含 {len(paragraph_store.store)} 个段落") logger.info(f"成功加载段落 embedding store包含 {len(paragraph_store.store)} 个段落")
return paragraph_store return paragraph_store
@ -57,10 +57,10 @@ def _get_paragraph_store():
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]: def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
"""从 embedding store 获取段落完整内容 """从 embedding store 获取段落完整内容
Args: Args:
node_id: 段落节点ID格式为 'paragraph-{hash}' node_id: 段落节点ID格式为 'paragraph-{hash}'
Returns: Returns:
tuple[str | None, bool]: (段落完整内容或None, 是否启用了功能) tuple[str | None, bool]: (段落完整内容或None, 是否启用了功能)
""" """
@ -69,12 +69,12 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
if paragraph_store is None: if paragraph_store is None:
# 功能未启用 # 功能未启用
return None, False return None, False
# 从 store 中获取完整内容 # 从 store 中获取完整内容
paragraph_item = paragraph_store.store.get(node_id) paragraph_item = paragraph_store.store.get(node_id)
if paragraph_item is not None: if paragraph_item is not None:
# paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本 # paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本
content: str = getattr(paragraph_item, 'str', '') content: str = getattr(paragraph_item, "str", "")
if content: if content:
return content, True return content, True
return None, True return None, True
@ -156,14 +156,18 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
node_data = graph[node_id] node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph" # 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容 # 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph": if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id) full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else: else:
content = node_data["content"] if "content" in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time)) nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
@ -245,14 +249,18 @@ async def get_knowledge_graph(
try: try:
node_data = graph[node_id] node_data = graph[node_id]
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容 # 对于段落节点,尝试从 embedding store 获取完整内容
if node_type_val == "paragraph": if node_type_val == "paragraph":
full_content, _ = _get_paragraph_content(node_id) full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else: else:
content = node_data["content"] if "content" in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time)) nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
@ -368,11 +376,15 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
try: try:
node_data = graph[node_id] node_data = graph[node_id]
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph" node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容 # 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph": if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id) full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id) content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else: else:
content = node_data["content"] if "content" in node_data else node_id content = node_data["content"] if "content" in node_data else node_id

View File

@ -370,7 +370,7 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
with get_db_session() as session: with get_db_session() as session:
total = len(session.exec(select(PersonInfo.id)).all()) total = len(session.exec(select(PersonInfo.id)).all())
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known) == True)).all()) known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known))).all())
unknown = total - known unknown = total - known
# 按平台统计 # 按平台统计

View File

@ -1762,7 +1762,7 @@ async def update_plugin_config_raw(
try: try:
tomlkit.loads(request.config) tomlkit.loads(request.config)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 备份旧配置 # 备份旧配置
import shutil import shutil

View File

@ -659,4 +659,4 @@ def get_git_mirror_service() -> GitMirrorService:
global _git_mirror_service global _git_mirror_service
if _git_mirror_service is None: if _git_mirror_service is None:
_git_mirror_service = GitMirrorService() _git_mirror_service = GitMirrorService()
return _git_mirror_service return _git_mirror_service