mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'r-dev' of github.com:Mai-with-u/MaiBot into r-dev
commit
04a5bf3c6d
1
bot.py
1
bot.py
|
|
@ -50,6 +50,7 @@ print("警告:Dev进入不稳定开发状态,任何插件与WebUI均可能
|
|||
print("\n\n\n\n\n")
|
||||
print("-----------------------------------------")
|
||||
|
||||
|
||||
def run_runner_process():
|
||||
"""
|
||||
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Core helpers for MCP Bridge Plugin."""
|
||||
|
||||
|
|
|
|||
|
|
@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
|
|||
if not mcp_servers:
|
||||
return ""
|
||||
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -22,21 +22,24 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
try:
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("mcp_tool_chain")
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("mcp_tool_chain")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolChainStep:
|
||||
"""工具链步骤"""
|
||||
|
||||
tool_name: str # 要调用的工具名(如 mcp_server_tool)
|
||||
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
|
||||
output_key: str = "" # 输出存储的键名,供后续步骤引用
|
||||
description: str = "" # 步骤描述
|
||||
optional: bool = False # 是否可选(失败时继续执行)
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"tool_name": self.tool_name,
|
||||
|
|
@ -45,7 +48,7 @@ class ToolChainStep:
|
|||
"description": self.description,
|
||||
"optional": self.optional,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep":
|
||||
return cls(
|
||||
|
|
@ -60,12 +63,13 @@ class ToolChainStep:
|
|||
@dataclass
|
||||
class ToolChainDefinition:
|
||||
"""工具链定义"""
|
||||
|
||||
name: str # 工具链名称(将作为组合工具的名称)
|
||||
description: str # 工具链描述(供 LLM 理解)
|
||||
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
|
||||
input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述}
|
||||
enabled: bool = True # 是否启用
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
|
|
@ -74,7 +78,7 @@ class ToolChainDefinition:
|
|||
"input_params": self.input_params,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition":
|
||||
steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])]
|
||||
|
|
@ -90,12 +94,13 @@ class ToolChainDefinition:
|
|||
@dataclass
|
||||
class ChainExecutionResult:
|
||||
"""工具链执行结果"""
|
||||
|
||||
success: bool
|
||||
final_output: str # 最终输出(最后一个步骤的结果)
|
||||
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
|
||||
error: str = ""
|
||||
total_duration_ms: float = 0.0
|
||||
|
||||
|
||||
def to_summary(self) -> str:
|
||||
"""生成执行摘要"""
|
||||
lines = []
|
||||
|
|
@ -103,7 +108,7 @@ class ChainExecutionResult:
|
|||
status = "✅" if step.get("success") else "❌"
|
||||
tool = step.get("tool_name", "unknown")
|
||||
duration = step.get("duration_ms", 0)
|
||||
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)")
|
||||
lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
|
||||
if not step.get("success") and step.get("error"):
|
||||
lines.append(f" 错误: {step['error'][:50]}")
|
||||
return "\n".join(lines)
|
||||
|
|
@ -111,49 +116,49 @@ class ChainExecutionResult:
|
|||
|
||||
class ToolChainExecutor:
|
||||
"""工具链执行器"""
|
||||
|
||||
|
||||
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
|
||||
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
|
||||
|
||||
VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
|
||||
|
||||
def __init__(self, mcp_manager):
    """Create an executor bound to one MCP manager.

    Args:
        mcp_manager: Object stored as ``self._mcp_manager``. Other methods
            of this class read its ``all_tools`` attribute and await its
            ``call_tool`` method.
    """
    manager = mcp_manager
    self._mcp_manager = manager
|
||||
|
||||
|
||||
def _resolve_tool_key(self, tool_name: str) -> Optional[str]:
    """Map a user-supplied tool name to a key of the MCP manager's tool table.

    Lookup order (first hit wins):
      1. ``tool_name`` exactly as given.
      2. ``tool_name`` with every ``-`` and ``.`` replaced by ``_``
         (the form a name takes after registration).
      3. The first registered key ending in ``_<tool_name>`` or in
         ``_<normalized name>``.

    Args:
        tool_name: Name written in the chain step.

    Returns:
        The matching key, or ``None`` when nothing matches.
    """
    registry = self._mcp_manager.all_tools

    # Exact match first.
    if tool_name in registry:
        return tool_name

    # Then the registered (normalized) spelling.
    canonical = tool_name.translate(str.maketrans("-.", "__"))
    if canonical in registry:
        return canonical

    # Finally accept any key that carries the name as its trailing segment.
    suffixes = (f"_{tool_name}", f"_{canonical}")
    return next((key for key in registry.keys() if key.endswith(suffixes)), None)
|
||||
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
chain: ToolChainDefinition,
|
||||
input_args: Dict[str, Any],
|
||||
) -> ChainExecutionResult:
|
||||
"""执行工具链
|
||||
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
input_args: 用户输入的参数
|
||||
|
||||
|
||||
Returns:
|
||||
ChainExecutionResult: 执行结果
|
||||
"""
|
||||
|
|
@ -164,15 +169,15 @@ class ToolChainExecutor:
|
|||
"step": {}, # 各步骤输出,按 output_key 存储
|
||||
"prev": "", # 上一步的输出
|
||||
}
|
||||
|
||||
|
||||
final_output = ""
|
||||
|
||||
|
||||
# 验证必需的输入参数
|
||||
missing_params = []
|
||||
for param_name in chain.input_params.keys():
|
||||
if param_name not in context["input"]:
|
||||
missing_params.append(param_name)
|
||||
|
||||
|
||||
if missing_params:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
|
|
@ -180,7 +185,7 @@ class ToolChainExecutor:
|
|||
error=f"缺少必需参数: {', '.join(missing_params)}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
step_start = time.time()
|
||||
step_result = {
|
||||
|
|
@ -191,96 +196,96 @@ class ToolChainExecutor:
|
|||
"error": "",
|
||||
"duration_ms": 0,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 替换参数中的变量
|
||||
resolved_args = self._resolve_args(step.args_template, context)
|
||||
step_result["resolved_args"] = resolved_args
|
||||
|
||||
|
||||
# 解析工具名
|
||||
tool_key = self._resolve_tool_key(step.tool_name)
|
||||
if not tool_key:
|
||||
step_result["error"] = f"工具 {step.tool_name} 不存在"
|
||||
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在")
|
||||
|
||||
logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在",
|
||||
error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
step_results.append(step_result)
|
||||
continue
|
||||
|
||||
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}")
|
||||
|
||||
|
||||
logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
|
||||
|
||||
# 调用工具
|
||||
result = await self._mcp_manager.call_tool(tool_key, resolved_args)
|
||||
|
||||
|
||||
step_duration = (time.time() - step_start) * 1000
|
||||
step_result["duration_ms"] = step_duration
|
||||
|
||||
|
||||
if result.success:
|
||||
step_result["success"] = True
|
||||
# 确保 content 不为 None
|
||||
content = result.content if result.content is not None else ""
|
||||
step_result["output"] = content
|
||||
|
||||
|
||||
# 更新上下文
|
||||
context["prev"] = content
|
||||
if step.output_key:
|
||||
context["step"][step.output_key] = content
|
||||
|
||||
|
||||
final_output = content
|
||||
content_preview = content[:100] if content else "(空)"
|
||||
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...")
|
||||
logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
|
||||
else:
|
||||
step_result["error"] = result.error or "未知错误"
|
||||
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}")
|
||||
|
||||
logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}",
|
||||
error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
step_duration = (time.time() - step_start) * 1000
|
||||
step_result["duration_ms"] = step_duration
|
||||
step_result["error"] = str(e)
|
||||
logger.error(f"工具链步骤 {i+1} 异常: {e}")
|
||||
|
||||
logger.error(f"工具链步骤 {i + 1} 异常: {e}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}",
|
||||
error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
|
||||
step_results.append(step_result)
|
||||
|
||||
|
||||
total_duration = (time.time() - start_time) * 1000
|
||||
|
||||
|
||||
return ChainExecutionResult(
|
||||
success=True,
|
||||
final_output=final_output,
|
||||
step_results=step_results,
|
||||
total_duration_ms=total_duration,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """Build the concrete tool arguments by expanding ``${...}`` placeholders.

    Placeholder forms mentioned in this file include ``${input.param_name}``
    (a user-supplied argument), ``${step.output_key}`` (output of an earlier
    step), ``${prev}`` (output of the previous step) and ``${prev.field}``
    (a field of the previous step's JSON output). The expansion of each
    string is delegated to ``self._substitute_vars``.

    The template is walked recursively: strings are substituted, dicts and
    lists are rebuilt with their members resolved, and every other value is
    passed through untouched. Lists are handled with the same recursion as
    dicts, so placeholders inside a dict or list nested in a list are
    expanded too (previously only strings directly inside a list were).

    Args:
        args_template: Argument template of a chain step.
        context: Execution context handed to ``_substitute_vars``.

    Returns:
        A new dict with the same keys; the template is not modified.
    """

    def resolve(value: Any) -> Any:
        if isinstance(value, str):
            return self._substitute_vars(value, context)
        if isinstance(value, dict):
            return {key: resolve(item) for key, item in value.items()}
        if isinstance(value, list):
            return [resolve(item) for item in value]
        # Numbers, booleans, None, ... carry no placeholders.
        return value

    return {key: resolve(value) for key, value in args_template.items()}
|
||||
|
||||
|
||||
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
    """Expand every placeholder matched by ``VAR_PATTERN`` in ``template``.

    For each match, the text captured by the pattern's first group is passed
    to ``self._get_var_value`` together with ``context``; the value returned
    replaces the whole placeholder.

    Args:
        template: String that may contain ``${...}`` placeholders.
        context: Execution context forwarded to ``_get_var_value``.

    Returns:
        ``template`` with all placeholders replaced.
    """
    return self.VAR_PATTERN.sub(
        lambda found: self._get_var_value(found.group(1), context),
        template,
    )
|
||||
|
||||
|
||||
def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str:
|
||||
"""获取变量值
|
||||
|
||||
|
||||
Args:
|
||||
var_path: 变量路径,如 "input.query", "step.search_result", "prev", "prev.id"
|
||||
context: 上下文
|
||||
"""
|
||||
parts = self._parse_var_path(var_path)
|
||||
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
|
||||
|
||||
# 获取根对象
|
||||
root = parts[0]
|
||||
if root not in context:
|
||||
logger.warning(f"变量 {var_path} 的根 '{root}' 不存在")
|
||||
return ""
|
||||
|
||||
|
||||
value = context[root]
|
||||
|
||||
|
||||
# 遍历路径
|
||||
for part in parts[1:]:
|
||||
if isinstance(value, str):
|
||||
|
|
@ -349,7 +352,7 @@ class ToolChainExecutor:
|
|||
value = ""
|
||||
else:
|
||||
value = ""
|
||||
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
|
@ -448,39 +451,39 @@ class ToolChainExecutor:
|
|||
|
||||
class ToolChainManager:
|
||||
"""工具链管理器"""
|
||||
|
||||
|
||||
_instance: Optional["ToolChainManager"] = None
|
||||
|
||||
|
||||
def __new__(cls):
    """Return the shared manager instance, creating it on first use.

    The instance is cached in the class attribute ``_instance``. A freshly
    created instance gets ``_initialized = False`` so that ``__init__`` can
    tell whether it still has to set up state.
    """
    shared = cls._instance
    if shared is None:
        shared = super().__new__(cls)
        cls._instance = shared
        shared._initialized = False
    return shared
|
||||
|
||||
|
||||
def __init__(self):
    """Set up empty state the first time the instance is initialised.

    ``__init__`` runs on every ``ToolChainManager()`` call; the
    ``_initialized`` flag makes later calls no-ops so existing chains and the
    executor are kept.
    """
    if not self._initialized:
        self._initialized = True
        # Registered chains, keyed by chain name.
        self._chains: Dict[str, ToolChainDefinition] = {}
        # Stays None until set_executor() is called.
        self._executor: Optional[ToolChainExecutor] = None
|
||||
|
||||
|
||||
def set_executor(self, mcp_manager) -> None:
    """Install a new ``ToolChainExecutor`` built around ``mcp_manager``.

    Any previously installed executor is replaced.
    """
    executor = ToolChainExecutor(mcp_manager)
    self._executor = executor
|
||||
|
||||
|
||||
def add_chain(self, chain: ToolChainDefinition) -> bool:
    """Register ``chain`` under its name.

    A chain with an empty name is rejected (an error is logged). If the name
    is already registered, a warning is logged and the old chain is replaced.

    Returns:
        ``True`` when the chain was stored, ``False`` when it was rejected.
    """
    key = chain.name
    if key:
        if key in self._chains:
            logger.warning(f"工具链 {chain.name} 已存在,将被覆盖")
        self._chains[key] = chain
        logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)")
        return True

    logger.error("工具链名称不能为空")
    return False
|
||||
|
||||
|
||||
def remove_chain(self, name: str) -> bool:
|
||||
"""移除工具链"""
|
||||
if name in self._chains:
|
||||
|
|
@ -488,19 +491,19 @@ class ToolChainManager:
|
|||
logger.info(f"已移除工具链: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_chain(self, name: str) -> Optional[ToolChainDefinition]:
    """Look up a registered chain by name; ``None`` if it is not registered."""
    found = self._chains.get(name, None)
    return found
|
||||
|
||||
|
||||
def get_all_chains(self) -> Dict[str, ToolChainDefinition]:
    """Return a shallow copy of the name -> chain table.

    Adding or removing keys in the returned dict does not affect the manager;
    the chain objects themselves are shared.
    """
    snapshot = self._chains.copy()
    return snapshot
|
||||
|
||||
|
||||
def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]:
    """Return a new dict holding only the chains whose ``enabled`` flag is truthy."""
    active = {}
    for chain_name, definition in self._chains.items():
        if definition.enabled:
            active[chain_name] = definition
    return active
|
||||
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
chain_name: str,
|
||||
|
|
@ -514,64 +517,64 @@ class ToolChainManager:
|
|||
final_output="",
|
||||
error=f"工具链 {chain_name} 不存在",
|
||||
)
|
||||
|
||||
|
||||
if not chain.enabled:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error=f"工具链 {chain_name} 已禁用",
|
||||
)
|
||||
|
||||
|
||||
if not self._executor:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error="工具链执行器未初始化",
|
||||
)
|
||||
|
||||
|
||||
return await self._executor.execute(chain, input_args)
|
||||
|
||||
|
||||
def load_from_json(self, json_str: str) -> Tuple[int, List[str]]:
    """Register tool chains described by a JSON string.

    A blank string means "no chains". A top-level value that is not a list is
    treated as a single-item list. Each item is converted with
    ``ToolChainDefinition.from_dict``; items without a name or without steps
    are reported and skipped, and any exception raised while handling an item
    is reported without aborting the remaining items.

    Args:
        json_str: JSON text (a list of chain objects or a single object).

    Returns:
        ``(number of chains registered, list of error messages)``. Invalid
        JSON yields ``(0, [message])``.
    """
    if json_str.strip():
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError as exc:
            return 0, [f"JSON 解析失败: {exc}"]
    else:
        parsed = []

    entries = parsed if isinstance(parsed, list) else [parsed]

    problems: List[str] = []
    registered = 0
    # Positions in messages are 1-based.
    for position, raw in enumerate(entries, start=1):
        try:
            chain = ToolChainDefinition.from_dict(raw)
            if not chain.name:
                problems.append(f"第 {position} 个工具链缺少名称")
            elif not chain.steps:
                problems.append(f"工具链 {chain.name} 没有步骤")
            else:
                self.add_chain(chain)
                registered += 1
        except Exception as exc:
            problems.append(f"第 {position} 个工具链解析失败: {exc}")

    return registered, problems
|
||||
|
||||
|
||||
def export_to_json(self, pretty: bool = True) -> str:
    """Serialise every registered chain to a JSON array.

    Each chain contributes the result of its ``to_dict()``. Non-ASCII
    characters are written as-is (``ensure_ascii=False``).

    Args:
        pretty: When true, indent the output by two spaces; otherwise emit
            the compact single-line form.

    Returns:
        The JSON text.
    """
    serialised = []
    for definition in self._chains.values():
        serialised.append(definition.to_dict())

    # indent=None is json.dumps' default, i.e. the compact form.
    indent = 2 if pretty else None
    return json.dumps(serialised, ensure_ascii=False, indent=indent)
|
||||
|
||||
|
||||
def clear(self) -> None:
    """Remove every registered chain.

    The existing dict is emptied in place rather than replaced.
    """
    registry = self._chains
    registry.clear()
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ class TestCommand(BaseCommand):
|
|||
chat_stream=self.message.chat_stream,
|
||||
reply_reason=reply_reason,
|
||||
enable_chinese_typo=False,
|
||||
extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"",
|
||||
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
|
||||
)
|
||||
if result_status:
|
||||
# 发送生成的回复
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ def patch_attrdoc_post_init():
|
|||
|
||||
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
||||
|
||||
|
||||
class SimpleClass(ConfigBase):
|
||||
a: int = 1
|
||||
b: str = "test"
|
||||
|
|
@ -282,7 +283,7 @@ class TestConfigBase:
|
|||
True,
|
||||
"ConfigBase is not Hashable",
|
||||
id="listset-validation-set-configbase-element_reject",
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
||||
|
|
@ -340,7 +341,7 @@ class TestConfigBase:
|
|||
False,
|
||||
None,
|
||||
id="dict-validation-happy-configbase-value",
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
||||
|
|
@ -353,13 +354,11 @@ class TestConfigBase:
|
|||
field_name = "mapping"
|
||||
|
||||
if expect_error:
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
|
||||
# Act
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
|
||||
|
|
@ -392,7 +391,7 @@ class TestConfigBase:
|
|||
|
||||
# Assert
|
||||
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
|
||||
|
||||
|
||||
def test_discourage_any_usage_suppressed_warning(self, caplog):
|
||||
class Sample(ConfigBase):
|
||||
_validate_any: bool = False
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import importlib
|
|||
import pytest
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
import asyncio
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
|
|
@ -71,6 +70,7 @@ class DummyLLMRequest:
|
|||
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
|
||||
return ("dummy description", {})
|
||||
|
||||
|
||||
class DummySelect:
|
||||
def __init__(self, *a, **k):
|
||||
pass
|
||||
|
|
@ -81,6 +81,7 @@ class DummySelect:
|
|||
def limit(self, n):
|
||||
return self
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_external_dependencies(monkeypatch):
|
||||
# Provide dummy implementations as modules so that importing image_manager is safe
|
||||
|
|
@ -103,11 +104,11 @@ def patch_external_dependencies(monkeypatch):
|
|||
# Patch MaiImage data model
|
||||
data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage)
|
||||
monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod)
|
||||
|
||||
|
||||
# Patch SQLModel select function
|
||||
sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect())
|
||||
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
|
||||
|
||||
|
||||
# Patch config values used at import-time
|
||||
cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style"))
|
||||
model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm"))
|
||||
|
|
@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None):
|
|||
if tmp_path is not None:
|
||||
tmpdir = Path(tmp_path)
|
||||
tmpdir.mkdir(parents=True, exist_ok=True)
|
||||
setattr(mod, "IMAGE_DIR", tmpdir)
|
||||
mod.IMAGE_DIR = tmpdir
|
||||
except Exception:
|
||||
pass
|
||||
return mod
|
||||
|
|
@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
|
|||
|
||||
# cleanup should run without error
|
||||
mgr.cleanup_invalid_descriptions_in_db()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import pytest
|
||||
|
||||
from src.config.official_configs import ChatConfig
|
||||
from src.config.config import Config
|
||||
from src.webui.config_schema import ConfigSchemaGenerator
|
||||
|
|
|
|||
|
|
@ -387,7 +387,7 @@ def test_auth_required_list(client):
|
|||
"""测试未认证访问列表端点(401)"""
|
||||
# Without mock_token_verify fixture
|
||||
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
|
||||
response = client.get("/emoji/list")
|
||||
client.get("/emoji/list")
|
||||
# verify_auth_token 返回 False 会触发 HTTPException
|
||||
# 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现
|
||||
# 这里假设它抛出 401
|
||||
|
|
@ -397,7 +397,7 @@ def test_auth_required_update(client, sample_emojis):
|
|||
"""测试未认证访问更新端点(401)"""
|
||||
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
|
||||
emoji_id = sample_emojis[0].id
|
||||
response = client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
|
||||
client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
|
||||
# Should be unauthorized
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Expression routes pytest tests"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
|
@ -12,7 +11,6 @@ from sqlalchemy import text
|
|||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import get_db_session
|
||||
|
||||
|
||||
def create_test_app() -> FastAPI:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from typing import Dict, List, Set, Tuple
|
|||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.logger import get_logger # noqa: E402
|
||||
|
||||
logger = get_logger("evaluation_stats_analyzer")
|
||||
|
||||
|
|
@ -38,10 +38,10 @@ def parse_datetime(dt_str: str) -> datetime | None:
|
|||
def analyze_single_file(file_path: str) -> Dict:
|
||||
"""
|
||||
分析单个JSON文件的统计信息
|
||||
|
||||
|
||||
Args:
|
||||
file_path: JSON文件路径
|
||||
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
|
|
@ -65,40 +65,40 @@ def analyze_single_file(file_path: str) -> Dict:
|
|||
"has_reason": False,
|
||||
"reason_count": 0,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# 基本信息
|
||||
stats["last_updated"] = data.get("last_updated")
|
||||
stats["total_count"] = data.get("total_count", 0)
|
||||
|
||||
|
||||
results = data.get("manual_results", [])
|
||||
stats["actual_count"] = len(results)
|
||||
|
||||
|
||||
if not results:
|
||||
return stats
|
||||
|
||||
|
||||
# 统计通过/不通过
|
||||
suitable_count = sum(1 for r in results if r.get("suitable") is True)
|
||||
unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
|
||||
stats["suitable_count"] = suitable_count
|
||||
stats["unsuitable_count"] = unsuitable_count
|
||||
stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0
|
||||
|
||||
|
||||
# 统计唯一的(situation, style)对
|
||||
pairs: Set[Tuple[str, str]] = set()
|
||||
for r in results:
|
||||
if "situation" in r and "style" in r:
|
||||
pairs.add((r["situation"], r["style"]))
|
||||
stats["unique_pairs"] = len(pairs)
|
||||
|
||||
|
||||
# 统计评估者
|
||||
for r in results:
|
||||
evaluator = r.get("evaluator", "unknown")
|
||||
stats["evaluators"][evaluator] += 1
|
||||
|
||||
|
||||
# 统计评估时间
|
||||
evaluation_dates = []
|
||||
for r in results:
|
||||
|
|
@ -107,7 +107,7 @@ def analyze_single_file(file_path: str) -> Dict:
|
|||
dt = parse_datetime(evaluated_at)
|
||||
if dt:
|
||||
evaluation_dates.append(dt)
|
||||
|
||||
|
||||
stats["evaluation_dates"] = evaluation_dates
|
||||
if evaluation_dates:
|
||||
min_date = min(evaluation_dates)
|
||||
|
|
@ -115,18 +115,18 @@ def analyze_single_file(file_path: str) -> Dict:
|
|||
stats["date_range"] = {
|
||||
"start": min_date.isoformat(),
|
||||
"end": max_date.isoformat(),
|
||||
"duration_days": (max_date - min_date).days + 1
|
||||
"duration_days": (max_date - min_date).days + 1,
|
||||
}
|
||||
|
||||
|
||||
# 检查字段存在性
|
||||
stats["has_expression_id"] = any("expression_id" in r for r in results)
|
||||
stats["has_reason"] = any(r.get("reason") for r in results)
|
||||
stats["reason_count"] = sum(1 for r in results if r.get("reason"))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
stats["error"] = str(e)
|
||||
logger.error(f"分析文件 {file_name} 时出错: {e}")
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
|
|
@ -136,57 +136,57 @@ def print_file_stats(stats: Dict, index: int = None):
|
|||
print(f"\n{'=' * 80}")
|
||||
print(f"{prefix}文件: {stats['file_name']}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
|
||||
if stats["error"]:
|
||||
print(f"✗ 错误: {stats['error']}")
|
||||
return
|
||||
|
||||
|
||||
print(f"文件路径: {stats['file_path']}")
|
||||
print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")
|
||||
|
||||
|
||||
if stats["last_updated"]:
|
||||
print(f"最后更新: {stats['last_updated']}")
|
||||
|
||||
|
||||
print("\n【记录统计】")
|
||||
print(f" 文件中的 total_count: {stats['total_count']}")
|
||||
print(f" 实际记录数: {stats['actual_count']}")
|
||||
|
||||
if stats['total_count'] != stats['actual_count']:
|
||||
diff = stats['total_count'] - stats['actual_count']
|
||||
|
||||
if stats["total_count"] != stats["actual_count"]:
|
||||
diff = stats["total_count"] - stats["actual_count"]
|
||||
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
|
||||
|
||||
|
||||
print("\n【评估结果统计】")
|
||||
print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
|
||||
print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")
|
||||
|
||||
|
||||
print("\n【唯一性统计】")
|
||||
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']} 条")
|
||||
if stats['actual_count'] > 0:
|
||||
duplicate_count = stats['actual_count'] - stats['unique_pairs']
|
||||
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["actual_count"] > 0:
|
||||
duplicate_count = stats["actual_count"] - stats["unique_pairs"]
|
||||
duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
|
||||
|
||||
|
||||
print("\n【评估者统计】")
|
||||
if stats['evaluators']:
|
||||
for evaluator, count in stats['evaluators'].most_common():
|
||||
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["evaluators"]:
|
||||
for evaluator, count in stats["evaluators"].most_common():
|
||||
rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
|
||||
else:
|
||||
print(" 无评估者信息")
|
||||
|
||||
|
||||
print("\n【时间统计】")
|
||||
if stats['date_range']:
|
||||
if stats["date_range"]:
|
||||
print(f" 最早评估时间: {stats['date_range']['start']}")
|
||||
print(f" 最晚评估时间: {stats['date_range']['end']}")
|
||||
print(f" 评估时间跨度: {stats['date_range']['duration_days']} 天")
|
||||
else:
|
||||
print(" 无时间信息")
|
||||
|
||||
|
||||
print("\n【字段统计】")
|
||||
print(f" 包含 expression_id: {'是' if stats['has_expression_id'] else '否'}")
|
||||
print(f" 包含 reason: {'是' if stats['has_reason'] else '否'}")
|
||||
if stats['has_reason']:
|
||||
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["has_reason"]:
|
||||
rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
|
||||
|
||||
|
||||
|
|
@ -195,35 +195,35 @@ def print_summary(all_stats: List[Dict]):
|
|||
print(f"\n{'=' * 80}")
|
||||
print("汇总统计")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
|
||||
total_files = len(all_stats)
|
||||
valid_files = [s for s in all_stats if not s.get("error")]
|
||||
error_files = [s for s in all_stats if s.get("error")]
|
||||
|
||||
|
||||
print("\n【文件统计】")
|
||||
print(f" 总文件数: {total_files}")
|
||||
print(f" 成功解析: {len(valid_files)}")
|
||||
print(f" 解析失败: {len(error_files)}")
|
||||
|
||||
|
||||
if error_files:
|
||||
print("\n 失败文件列表:")
|
||||
for stats in error_files:
|
||||
print(f" - {stats['file_name']}: {stats['error']}")
|
||||
|
||||
|
||||
if not valid_files:
|
||||
print("\n没有成功解析的文件")
|
||||
return
|
||||
|
||||
|
||||
# 汇总记录统计
|
||||
total_records = sum(s['actual_count'] for s in valid_files)
|
||||
total_suitable = sum(s['suitable_count'] for s in valid_files)
|
||||
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
|
||||
total_records = sum(s["actual_count"] for s in valid_files)
|
||||
total_suitable = sum(s["suitable_count"] for s in valid_files)
|
||||
total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
|
||||
total_unique_pairs = set()
|
||||
|
||||
|
||||
# 收集所有唯一的(situation, style)对
|
||||
for stats in valid_files:
|
||||
try:
|
||||
with open(stats['file_path'], "r", encoding="utf-8") as f:
|
||||
with open(stats["file_path"], "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
results = data.get("manual_results", [])
|
||||
for r in results:
|
||||
|
|
@ -231,23 +231,31 @@ def print_summary(all_stats: List[Dict]):
|
|||
total_unique_pairs.add((r["situation"], r["style"]))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
print("\n【记录汇总】")
|
||||
print(f" 总记录数: {total_records:,} 条")
|
||||
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
|
||||
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
|
||||
print(
|
||||
f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
|
||||
if total_records > 0
|
||||
else " 通过: 0 条"
|
||||
)
|
||||
print(
|
||||
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
|
||||
if total_records > 0
|
||||
else " 不通过: 0 条"
|
||||
)
|
||||
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,} 条")
|
||||
|
||||
|
||||
if total_records > 0:
|
||||
duplicate_count = total_records - len(total_unique_pairs)
|
||||
duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0
|
||||
print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")
|
||||
|
||||
|
||||
# 汇总评估者统计
|
||||
all_evaluators = Counter()
|
||||
for stats in valid_files:
|
||||
all_evaluators.update(stats['evaluators'])
|
||||
|
||||
all_evaluators.update(stats["evaluators"])
|
||||
|
||||
print("\n【评估者汇总】")
|
||||
if all_evaluators:
|
||||
for evaluator, count in all_evaluators.most_common():
|
||||
|
|
@ -255,12 +263,12 @@ def print_summary(all_stats: List[Dict]):
|
|||
print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)")
|
||||
else:
|
||||
print(" 无评估者信息")
|
||||
|
||||
|
||||
# 汇总时间范围
|
||||
all_dates = []
|
||||
for stats in valid_files:
|
||||
all_dates.extend(stats['evaluation_dates'])
|
||||
|
||||
all_dates.extend(stats["evaluation_dates"])
|
||||
|
||||
if all_dates:
|
||||
min_date = min(all_dates)
|
||||
max_date = max(all_dates)
|
||||
|
|
@ -268,9 +276,9 @@ def print_summary(all_stats: List[Dict]):
|
|||
print(f" 最早评估时间: {min_date.isoformat()}")
|
||||
print(f" 最晚评估时间: {max_date.isoformat()}")
|
||||
print(f" 总时间跨度: {(max_date - min_date).days + 1} 天")
|
||||
|
||||
|
||||
# 文件大小汇总
|
||||
total_size = sum(s['file_size'] for s in valid_files)
|
||||
total_size = sum(s["file_size"] for s in valid_files)
|
||||
avg_size = total_size / len(valid_files) if valid_files else 0
|
||||
print("\n【文件大小汇总】")
|
||||
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
|
||||
|
|
@ -282,35 +290,35 @@ def main():
|
|||
logger.info("=" * 80)
|
||||
logger.info("开始分析评估结果统计信息")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
print(f"\n✗ 错误:未找到temp目录: {TEMP_DIR}")
|
||||
logger.error(f"未找到temp目录: {TEMP_DIR}")
|
||||
return
|
||||
|
||||
|
||||
# 查找所有JSON文件
|
||||
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
|
||||
|
||||
|
||||
if not json_files:
|
||||
print(f"\n✗ 错误:temp目录下未找到JSON文件: {TEMP_DIR}")
|
||||
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
|
||||
return
|
||||
|
||||
|
||||
json_files.sort() # 按文件名排序
|
||||
|
||||
|
||||
print(f"\n找到 {len(json_files)} 个JSON文件")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
# 分析每个文件
|
||||
all_stats = []
|
||||
for i, json_file in enumerate(json_files, 1):
|
||||
stats = analyze_single_file(json_file)
|
||||
all_stats.append(stats)
|
||||
print_file_stats(stats, index=i)
|
||||
|
||||
|
||||
# 打印汇总统计
|
||||
print_summary(all_stats)
|
||||
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print("分析完成")
|
||||
print(f"{'=' * 80}")
|
||||
|
|
@ -318,5 +326,3 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -171,7 +171,9 @@ def main():
|
|||
sys.exit(1)
|
||||
|
||||
if not args.raw_index:
|
||||
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
|
||||
logger.info(
|
||||
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# 解析索引列表(1-based)
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ from collections import defaultdict
|
|||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import db
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.database.database_model import Expression # noqa: E402
|
||||
from src.common.database.database import db # noqa: E402
|
||||
from src.common.logger import get_logger # noqa: E402
|
||||
from src.llm_models.utils_model import LLMRequest # noqa: E402
|
||||
from src.config.config import model_config # noqa: E402
|
||||
|
||||
logger = get_logger("expression_evaluator_count_analysis_llm")
|
||||
|
||||
|
|
@ -38,13 +38,13 @@ COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.
|
|||
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
"""
|
||||
加载已有的评估结果
|
||||
|
||||
|
||||
Returns:
|
||||
(已有结果列表, 已评估的项目(situation, style)元组集合)
|
||||
"""
|
||||
if not os.path.exists(COUNT_ANALYSIS_FILE):
|
||||
return [], set()
|
||||
|
||||
|
||||
try:
|
||||
with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
|
@ -61,22 +61,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
|||
def save_results(evaluation_results: List[Dict]):
|
||||
"""
|
||||
保存评估结果到文件
|
||||
|
||||
|
||||
Args:
|
||||
evaluation_results: 评估结果列表
|
||||
"""
|
||||
try:
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
|
||||
data = {
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"total_count": len(evaluation_results),
|
||||
"evaluation_results": evaluation_results
|
||||
"evaluation_results": evaluation_results,
|
||||
}
|
||||
|
||||
|
||||
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
|
||||
print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
|
||||
except Exception as e:
|
||||
|
|
@ -84,70 +84,70 @@ def save_results(evaluation_results: List[Dict]):
|
|||
print(f"\n✗ 保存评估结果失败: {e}")
|
||||
|
||||
|
||||
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] | None = None) -> List[Expression]:
    """
    Select the expressions that should be evaluated in this run.

    Every not-yet-evaluated expression with count > 1 is selected. In addition, a random
    sample of not-yet-evaluated count == 1 expressions is drawn, twice as large as the
    count > 1 group (or all of them when fewer are available). The combined selection is
    shuffled before it is returned.

    Args:
        evaluated_pairs: (situation, style) pairs that were already evaluated and must be
            skipped. None is treated as an empty set.

    Returns:
        The selected expressions in random order. An empty list is returned when there are
        no records, when everything has already been evaluated, or when an error occurs
        (the error and its traceback are logged).
    """
    if evaluated_pairs is None:
        evaluated_pairs = set()

    try:
        # Load every expression record.
        all_expressions = list(Expression.select())

        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []

        # Keep only the records whose (situation, style) pair has not been evaluated yet.
        unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]

        if not unevaluated:
            logger.warning("所有项目都已评估完成")
            return []

        # Split by count.
        count_eq1 = [expr for expr in unevaluated if expr.count == 1]
        count_gt1 = [expr for expr in unevaluated if expr.count > 1]

        logger.info(f"未评估项目中:count=1的有{len(count_eq1)}条,count>1的有{len(count_gt1)}条")

        # All count > 1 records are selected.
        selected_count_gt1 = count_gt1.copy()

        # The count == 1 sample is twice the size of the count > 1 group.
        count_gt1_count = len(selected_count_gt1)
        count_eq1_needed = count_gt1_count * 2

        if len(count_eq1) < count_eq1_needed:
            logger.warning(
                f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条"
            )
            # Clamp so random.sample is never asked for more items than exist.
            count_eq1_needed = len(count_eq1)

        # Randomly pick the count == 1 records.
        selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []

        selected = selected_count_gt1 + selected_count_eq1
        random.shuffle(selected)  # Randomize the evaluation order.

        logger.info(
            f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)"
        )

        return selected

    except Exception as e:
        logger.error(f"选择表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
|
||||
|
||||
|
|
@ -155,11 +155,11 @@ def select_expressions_for_evaluation(
|
|||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
|
|
@ -181,34 +181,32 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
|||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
llm: LLM请求实例
|
||||
|
||||
|
||||
Returns:
|
||||
(suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
|
||||
"""
|
||||
try:
|
||||
prompt = create_evaluation_prompt(situation, style)
|
||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
|
||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.6,
|
||||
max_tokens=1024
|
||||
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
|
|
@ -218,13 +216,13 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
|||
evaluation = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
|
||||
|
||||
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
|
||||
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
|
||||
return False, f"评估过程出错: {str(e)}", str(e)
|
||||
|
|
@ -233,23 +231,25 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
|||
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
|
||||
"""
|
||||
使用LLM评估单个表达方式
|
||||
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
llm: LLM请求实例
|
||||
|
||||
|
||||
Returns:
|
||||
评估结果字典
|
||||
"""
|
||||
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
|
||||
|
||||
logger.info(
|
||||
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
|
||||
)
|
||||
|
||||
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
|
||||
|
||||
|
||||
if error:
|
||||
suitable = False
|
||||
|
||||
|
||||
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
|
||||
|
||||
|
||||
return {
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
|
|
@ -258,28 +258,28 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
|
|||
"reason": reason,
|
||||
"error": error,
|
||||
"evaluator": "llm",
|
||||
"evaluated_at": datetime.now().isoformat()
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
"""
|
||||
对评估结果进行统计分析
|
||||
|
||||
|
||||
Args:
|
||||
evaluation_results: 评估结果列表
|
||||
"""
|
||||
if not evaluation_results:
|
||||
print("\n没有评估结果可供分析")
|
||||
return
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("统计分析结果")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 按count分组统计
|
||||
count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
|
||||
|
||||
|
||||
for result in evaluation_results:
|
||||
count = result.get("count", 1)
|
||||
suitable = result.get("suitable", False)
|
||||
|
|
@ -288,7 +288,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
count_groups[count]["suitable"] += 1
|
||||
else:
|
||||
count_groups[count]["unsuitable"] += 1
|
||||
|
||||
|
||||
# 显示每个count的统计
|
||||
print("\n【按count分组统计】")
|
||||
print("-" * 60)
|
||||
|
|
@ -298,21 +298,21 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
suitable = group["suitable"]
|
||||
unsuitable = group["unsuitable"]
|
||||
pass_rate = (suitable / total * 100) if total > 0 else 0
|
||||
|
||||
|
||||
print(f"Count = {count}:")
|
||||
print(f" 总数: {total}")
|
||||
print(f" 通过: {suitable} ({pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
|
||||
# 比较count=1和count>1
|
||||
count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
|
||||
count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
|
||||
|
||||
|
||||
for result in evaluation_results:
|
||||
count = result.get("count", 1)
|
||||
suitable = result.get("suitable", False)
|
||||
|
||||
|
||||
if count == 1:
|
||||
count_eq1_group["total"] += 1
|
||||
if suitable:
|
||||
|
|
@ -325,34 +325,34 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
count_gt1_group["suitable"] += 1
|
||||
else:
|
||||
count_gt1_group["unsuitable"] += 1
|
||||
|
||||
|
||||
print("\n【Count=1 vs Count>1 对比】")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
eq1_total = count_eq1_group["total"]
|
||||
eq1_suitable = count_eq1_group["suitable"]
|
||||
eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
|
||||
|
||||
|
||||
gt1_total = count_gt1_group["total"]
|
||||
gt1_suitable = count_gt1_group["suitable"]
|
||||
gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
|
||||
|
||||
|
||||
print("Count = 1:")
|
||||
print(f" 总数: {eq1_total}")
|
||||
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
|
||||
print()
|
||||
print("Count > 1:")
|
||||
print(f" 总数: {gt1_total}")
|
||||
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
|
||||
# 进行卡方检验(简化版,使用2x2列联表)
|
||||
if eq1_total > 0 and gt1_total > 0:
|
||||
print("【统计显著性检验】")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
# 构建2x2列联表
|
||||
# 通过 不通过
|
||||
# count=1 a b
|
||||
|
|
@ -361,7 +361,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
b = eq1_total - eq1_suitable
|
||||
c = gt1_suitable
|
||||
d = gt1_total - gt1_suitable
|
||||
|
||||
|
||||
# 计算卡方统计量(简化版,使用Pearson卡方检验)
|
||||
n = eq1_total + gt1_total
|
||||
if n > 0:
|
||||
|
|
@ -370,13 +370,13 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
e_b = (eq1_total * (b + d)) / n
|
||||
e_c = (gt1_total * (a + c)) / n
|
||||
e_d = (gt1_total * (b + d)) / n
|
||||
|
||||
|
||||
# 检查期望频数是否足够大(卡方检验要求每个期望频数>=5)
|
||||
min_expected = min(e_a, e_b, e_c, e_d)
|
||||
if min_expected < 5:
|
||||
print("警告:期望频数小于5,卡方检验可能不准确")
|
||||
print("建议使用Fisher精确检验")
|
||||
|
||||
|
||||
# 计算卡方值
|
||||
chi_square = 0
|
||||
if e_a > 0:
|
||||
|
|
@ -387,26 +387,26 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
chi_square += ((c - e_c) ** 2) / e_c
|
||||
if e_d > 0:
|
||||
chi_square += ((d - e_d) ** 2) / e_d
|
||||
|
||||
|
||||
# 自由度 = (行数-1) * (列数-1) = 1
|
||||
df = 1
|
||||
|
||||
|
||||
# 临界值(α=0.05)
|
||||
chi_square_critical_005 = 3.841
|
||||
chi_square_critical_001 = 6.635
|
||||
|
||||
|
||||
print(f"卡方统计量: {chi_square:.4f}")
|
||||
print(f"自由度: {df}")
|
||||
print(f"临界值 (α=0.05): {chi_square_critical_005}")
|
||||
print(f"临界值 (α=0.01): {chi_square_critical_001}")
|
||||
|
||||
|
||||
if chi_square >= chi_square_critical_001:
|
||||
print("结论: 在α=0.01水平下,count=1和count>1的合格率存在显著差异(p<0.01)")
|
||||
elif chi_square >= chi_square_critical_005:
|
||||
print("结论: 在α=0.05水平下,count=1和count>1的合格率存在显著差异(p<0.05)")
|
||||
else:
|
||||
print("结论: 在α=0.05水平下,count=1和count>1的合格率不存在显著差异(p≥0.05)")
|
||||
|
||||
|
||||
# 计算差异大小
|
||||
diff = abs(eq1_pass_rate - gt1_pass_rate)
|
||||
print(f"\n合格率差异: {diff:.2f}%")
|
||||
|
|
@ -420,16 +420,16 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
|||
print("数据不足,无法进行统计检验")
|
||||
else:
|
||||
print("数据不足,无法进行count=1和count>1的对比分析")
|
||||
|
||||
|
||||
# 保存统计分析结果
|
||||
analysis_result = {
|
||||
"analysis_time": datetime.now().isoformat(),
|
||||
"count_groups": {str(k): v for k, v in count_groups.items()},
|
||||
"count_eq1": count_eq1_group,
|
||||
"count_gt1": count_gt1_group,
|
||||
"total_evaluated": len(evaluation_results)
|
||||
"total_evaluated": len(evaluation_results),
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
|
||||
with open(analysis_file, "w", encoding="utf-8") as f:
|
||||
|
|
@ -444,7 +444,7 @@ async def main():
|
|||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式按count分组的LLM评估和统计分析")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 初始化数据库连接
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
|
|
@ -452,97 +452,95 @@ async def main():
|
|||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 加载已有评估结果
|
||||
existing_results, evaluated_pairs = load_existing_results()
|
||||
evaluation_results = existing_results.copy()
|
||||
|
||||
|
||||
if evaluated_pairs:
|
||||
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
|
||||
print(f"已评估项目数: {len(evaluated_pairs)}")
|
||||
|
||||
|
||||
# 检查是否需要继续评估(检查是否还有未评估的count>1项目)
|
||||
# 先查询未评估的count>1项目数量
|
||||
try:
|
||||
all_expressions = list(Expression.select())
|
||||
unevaluated_count_gt1 = [
|
||||
expr for expr in all_expressions
|
||||
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
]
|
||||
has_unevaluated = len(unevaluated_count_gt1) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"查询未评估项目失败: {e}")
|
||||
has_unevaluated = False
|
||||
|
||||
|
||||
if has_unevaluated:
|
||||
print("\n" + "=" * 60)
|
||||
print("开始LLM评估")
|
||||
print("=" * 60)
|
||||
print("评估结果会自动保存到文件\n")
|
||||
|
||||
|
||||
# 创建LLM实例
|
||||
print("创建LLM实例...")
|
||||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_count_analysis_llm"
|
||||
request_type="expression_evaluator_count_analysis_llm",
|
||||
)
|
||||
print("✓ LLM实例创建成功\n")
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
print(f"\n✗ 创建LLM实例失败: {e}")
|
||||
db.close()
|
||||
return
|
||||
|
||||
|
||||
# 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目)
|
||||
expressions = select_expressions_for_evaluation(
|
||||
evaluated_pairs=evaluated_pairs
|
||||
)
|
||||
|
||||
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
|
||||
|
||||
if not expressions:
|
||||
print("\n没有可评估的项目")
|
||||
else:
|
||||
print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
|
||||
print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条")
|
||||
print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n")
|
||||
|
||||
|
||||
batch_results = []
|
||||
for i, expression in enumerate(expressions, 1):
|
||||
print(f"LLM评估进度: {i}/{len(expressions)}")
|
||||
print(f" Situation: {expression.situation}")
|
||||
print(f" Style: {expression.style}")
|
||||
print(f" Count: {expression.count}")
|
||||
|
||||
|
||||
llm_result = await llm_evaluate_expression(expression, llm)
|
||||
|
||||
|
||||
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
|
||||
if llm_result.get('error'):
|
||||
if llm_result.get("error"):
|
||||
print(f" 错误: {llm_result['error']}")
|
||||
print()
|
||||
|
||||
|
||||
batch_results.append(llm_result)
|
||||
# 使用 (situation, style) 作为唯一标识
|
||||
evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
|
||||
|
||||
|
||||
# 添加延迟以避免API限流
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
# 将当前批次结果添加到总结果中
|
||||
evaluation_results.extend(batch_results)
|
||||
|
||||
|
||||
# 保存结果
|
||||
save_results(evaluation_results)
|
||||
else:
|
||||
print(f"\n所有count>1的项目都已评估完成,已有 {len(evaluation_results)} 条评估结果")
|
||||
|
||||
|
||||
# 进行统计分析
|
||||
if len(evaluation_results) > 0:
|
||||
perform_statistical_analysis(evaluation_results)
|
||||
else:
|
||||
print("\n没有评估结果可供分析")
|
||||
|
||||
|
||||
# 关闭数据库连接
|
||||
try:
|
||||
db.close()
|
||||
|
|
@ -553,4 +551,3 @@ async def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ from typing import List, Dict, Set, Tuple
|
|||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest # noqa: E402
|
||||
from src.config.config import model_config # noqa: E402
|
||||
from src.common.logger import get_logger # noqa: E402
|
||||
|
||||
logger = get_logger("expression_evaluator_llm")
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
|
|||
def load_manual_results() -> List[Dict]:
|
||||
"""
|
||||
加载人工评估结果(自动读取temp目录下所有JSON文件并合并)
|
||||
|
||||
|
||||
Returns:
|
||||
人工评估结果列表(已去重)
|
||||
"""
|
||||
|
|
@ -42,62 +42,62 @@ def load_manual_results() -> List[Dict]:
|
|||
print("\n✗ 错误:未找到temp目录")
|
||||
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
|
||||
return []
|
||||
|
||||
|
||||
# 查找所有JSON文件
|
||||
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
|
||||
|
||||
|
||||
if not json_files:
|
||||
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
|
||||
print("\n✗ 错误:temp目录下未找到JSON文件")
|
||||
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
|
||||
return []
|
||||
|
||||
|
||||
logger.info(f"找到 {len(json_files)} 个JSON文件")
|
||||
print(f"\n找到 {len(json_files)} 个JSON文件:")
|
||||
for json_file in json_files:
|
||||
print(f" - {os.path.basename(json_file)}")
|
||||
|
||||
|
||||
# 读取并合并所有JSON文件
|
||||
all_results = []
|
||||
seen_pairs: Set[Tuple[str, str]] = set() # 用于去重
|
||||
|
||||
|
||||
for json_file in json_files:
|
||||
try:
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
results = data.get("manual_results", [])
|
||||
|
||||
|
||||
# 去重:使用(situation, style)作为唯一标识
|
||||
for result in results:
|
||||
if "situation" not in result or "style" not in result:
|
||||
logger.warning(f"跳过无效数据(缺少必要字段): {result}")
|
||||
continue
|
||||
|
||||
|
||||
pair = (result["situation"], result["style"])
|
||||
if pair not in seen_pairs:
|
||||
seen_pairs.add(pair)
|
||||
all_results.append(result)
|
||||
|
||||
|
||||
logger.info(f"从 {os.path.basename(json_file)} 加载了 {len(results)} 条结果")
|
||||
except Exception as e:
|
||||
logger.error(f"加载文件 {json_file} 失败: {e}")
|
||||
print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
|
||||
print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")
|
||||
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
|
|
@ -119,51 +119,50 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
|||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """
    Run a single LLM evaluation of one (situation, style) expression.

    The prompt is built with create_evaluation_prompt and sent through
    llm.generate_response_async. The response is parsed as JSON; if that fails, the first
    brace-delimited fragment containing a "suitable" key is extracted and parsed instead.

    Args:
        situation: The situation text of the expression.
        style: The style text of the expression.
        llm: The LLM request instance used to generate the response.

    Returns:
        A (suitable, reason, error) tuple. On success error is None. On any failure
        suitable is False, reason carries a failure message and error holds the
        exception text.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        # Only the response text is used; the accompanying metadata is ignored.
        response, (_reasoning, _model_name, _) = await llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )

        logger.debug(f"LLM响应: {response}")

        # Parse the JSON response.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            import re

            # Fall back to the first flat {...} fragment that mentions "suitable",
            # for responses that wrap the JSON in extra text.
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e

        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as e:
        # Any failure (request, parsing, unexpected payload shape) is reported as "not suitable".
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
|
||||
|
|
@ -172,68 +171,68 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
|||
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
    """
    Evaluate a single expression with the LLM and package the outcome as a dict.

    Args:
        situation: The situation text of the expression.
        style: The style text of the expression.
        llm: The LLM request instance passed through to _single_llm_evaluation.

    Returns:
        A dict with the keys "situation", "style", "suitable", "reason", "error" and
        "evaluator" (always "llm"). "suitable" is forced to False whenever an error
        was reported.
    """
    logger.info(f"开始评估表达方式: situation={situation}, style={style}")

    suitable, reason, error = await _single_llm_evaluation(situation, style, llm)

    # A failed evaluation never counts as a pass.
    if error:
        suitable = False

    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")

    return {
        "situation": situation,
        "style": style,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
    }
|
||||
|
||||
|
||||
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
|
||||
"""
|
||||
对比人工评估和LLM评估的结果
|
||||
|
||||
|
||||
Args:
|
||||
manual_results: 人工评估结果列表
|
||||
llm_results: LLM评估结果列表
|
||||
method_name: 评估方法名称(用于标识)
|
||||
|
||||
|
||||
Returns:
|
||||
对比分析结果字典
|
||||
"""
|
||||
# 按(situation, style)建立映射
|
||||
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
|
||||
|
||||
|
||||
total = len(manual_results)
|
||||
matched = 0
|
||||
true_positives = 0
|
||||
true_negatives = 0
|
||||
false_positives = 0
|
||||
false_negatives = 0
|
||||
|
||||
|
||||
for manual_result in manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
|
||||
manual_suitable = manual_result["suitable"]
|
||||
llm_suitable = llm_result["suitable"]
|
||||
|
||||
|
||||
if manual_suitable == llm_suitable:
|
||||
matched += 1
|
||||
|
||||
|
||||
if manual_suitable and llm_suitable:
|
||||
true_positives += 1
|
||||
elif not manual_suitable and not llm_suitable:
|
||||
|
|
@ -242,30 +241,36 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
|||
false_positives += 1
|
||||
elif manual_suitable and not llm_suitable:
|
||||
false_negatives += 1
|
||||
|
||||
|
||||
accuracy = (matched / total * 100) if total > 0 else 0
|
||||
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||
precision = (
|
||||
(true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||
)
|
||||
recall = (
|
||||
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||
)
|
||||
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
|
||||
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||
|
||||
specificity = (
|
||||
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||
)
|
||||
|
||||
# 计算人工效标的不合适率
|
||||
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
|
||||
manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0
|
||||
|
||||
|
||||
# 计算经过LLM删除后剩余项目中的不合适率
|
||||
# 在所有项目中,移除LLM判定为不合适的项目后,剩下的项目 = TP + FP(LLM判定为合适的项目)
|
||||
# 在这些剩下的项目中,按人工评定的不合适项目 = FP(人工认为不合适,但LLM认为合适)
|
||||
llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数(保留的项目)
|
||||
llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0
|
||||
|
||||
|
||||
# 两者百分比相减(评估LLM评定修正后的不合适率是否有降低)
|
||||
rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
|
||||
|
||||
|
||||
random_baseline = 50.0
|
||||
accuracy_above_random = accuracy - random_baseline
|
||||
accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0
|
||||
|
||||
|
||||
return {
|
||||
"method": method_name,
|
||||
"total": total,
|
||||
|
|
@ -283,29 +288,29 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
|||
"specificity": specificity,
|
||||
"manual_unsuitable_rate": manual_unsuitable_rate,
|
||||
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
|
||||
"rate_difference": rate_difference
|
||||
"rate_difference": rate_difference,
|
||||
}
|
||||
|
||||
|
||||
async def main(count: int | None = None):
|
||||
"""
|
||||
主函数
|
||||
|
||||
|
||||
Args:
|
||||
count: 随机选取的数据条数,如果为None则使用全部数据
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式LLM评估")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 1. 加载人工评估结果
|
||||
print("\n步骤1: 加载人工评估结果")
|
||||
manual_results = load_manual_results()
|
||||
if not manual_results:
|
||||
return
|
||||
|
||||
|
||||
print(f"成功加载 {len(manual_results)} 条人工评估结果")
|
||||
|
||||
|
||||
# 如果指定了数量,随机选择指定数量的数据
|
||||
if count is not None:
|
||||
if count <= 0:
|
||||
|
|
@ -317,7 +322,7 @@ async def main(count: int | None = None):
|
|||
random.seed() # 使用系统时间作为随机种子
|
||||
manual_results = random.sample(manual_results, count)
|
||||
print(f"随机选取 {len(manual_results)} 条数据进行评估")
|
||||
|
||||
|
||||
# 验证数据完整性
|
||||
valid_manual_results = []
|
||||
for r in manual_results:
|
||||
|
|
@ -325,62 +330,58 @@ async def main(count: int | None = None):
|
|||
valid_manual_results.append(r)
|
||||
else:
|
||||
logger.warning(f"跳过无效数据: {r}")
|
||||
|
||||
|
||||
if len(valid_manual_results) != len(manual_results):
|
||||
print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
|
||||
|
||||
|
||||
print(f"有效数据: {len(valid_manual_results)} 条")
|
||||
|
||||
|
||||
# 2. 创建LLM实例并评估
|
||||
print("\n步骤2: 创建LLM实例")
|
||||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_llm"
|
||||
)
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
|
||||
print("\n步骤3: 开始LLM评估")
|
||||
llm_results = []
|
||||
for i, manual_result in enumerate(valid_manual_results, 1):
|
||||
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
|
||||
llm_results.append(await evaluate_expression_llm(
|
||||
manual_result["situation"],
|
||||
manual_result["style"],
|
||||
llm
|
||||
))
|
||||
llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
# 5. 输出FP和FN项目(在评估结果之前)
|
||||
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
|
||||
|
||||
|
||||
# 5.1 输出FP项目(人工评估不通过但LLM误判为通过)
|
||||
print("\n" + "=" * 60)
|
||||
print("人工评估不通过但LLM误判为通过的项目(FP - False Positive)")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
fp_items = []
|
||||
for manual_result in valid_manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
|
||||
# 人工评估不通过,但LLM评估通过(FP情况)
|
||||
if not manual_result["suitable"] and llm_result["suitable"]:
|
||||
fp_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
|
||||
fp_items.append(
|
||||
{
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error"),
|
||||
}
|
||||
)
|
||||
|
||||
if fp_items:
|
||||
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
|
||||
for idx, item in enumerate(fp_items, 1):
|
||||
|
|
@ -389,36 +390,38 @@ async def main(count: int | None = None):
|
|||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 不通过 ❌")
|
||||
print("LLM评估: 通过 ✅ (误判)")
|
||||
if item.get('llm_error'):
|
||||
if item.get("llm_error"):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
else:
|
||||
print("\n✓ 没有误判项目(所有人工评估不通过的项目都被LLM正确识别为不通过)")
|
||||
|
||||
|
||||
# 5.2 输出FN项目(人工评估通过但LLM误判为不通过)
|
||||
print("\n" + "=" * 60)
|
||||
print("人工评估通过但LLM误判为不通过的项目(FN - False Negative)")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
fn_items = []
|
||||
for manual_result in valid_manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
|
||||
# 人工评估通过,但LLM评估不通过(FN情况)
|
||||
if manual_result["suitable"] and not llm_result["suitable"]:
|
||||
fn_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
|
||||
fn_items.append(
|
||||
{
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error"),
|
||||
}
|
||||
)
|
||||
|
||||
if fn_items:
|
||||
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
|
||||
for idx, item in enumerate(fn_items, 1):
|
||||
|
|
@ -427,33 +430,41 @@ async def main(count: int | None = None):
|
|||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 通过 ✅")
|
||||
print("LLM评估: 不通过 ❌ (误删)")
|
||||
if item.get('llm_error'):
|
||||
if item.get("llm_error"):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
else:
|
||||
print("\n✓ 没有误删项目(所有人工评估通过的项目都被LLM正确识别为通过)")
|
||||
|
||||
|
||||
# 6. 对比分析并输出结果
|
||||
comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("评估结果(以人工评估为标准)")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 详细评估结果(核心指标优先)
|
||||
print(f"\n--- {comparison['method']} ---")
|
||||
print(f" 总数: {comparison['total']} 条")
|
||||
print()
|
||||
# print(" 【核心能力指标】")
|
||||
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
|
||||
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
|
||||
print(
|
||||
f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个"
|
||||
)
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
|
||||
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
|
||||
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
|
||||
print(
|
||||
f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个"
|
||||
)
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(" 【其他指标】")
|
||||
|
|
@ -464,12 +475,18 @@ async def main(count: int | None = None):
|
|||
print()
|
||||
print(" 【不合适率分析】")
|
||||
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
|
||||
print(
|
||||
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
|
||||
)
|
||||
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
|
||||
print()
|
||||
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print(
|
||||
f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
|
||||
)
|
||||
print()
|
||||
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
|
||||
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
|
|
@ -480,21 +497,22 @@ async def main(count: int | None = None):
|
|||
print(f" TN (正确识别为不合适): {comparison['true_negatives']} ⭐")
|
||||
print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
|
||||
print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
|
||||
|
||||
|
||||
# 7. 保存结果到JSON文件
|
||||
output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump({
|
||||
"manual_results": valid_manual_results,
|
||||
"llm_results": llm_results,
|
||||
"comparison": comparison
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
json.dump(
|
||||
{"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
logger.info(f"\n评估结果已保存到: {output_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存结果到文件失败: {e}")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("评估完成")
|
||||
print("=" * 60)
|
||||
|
|
@ -509,15 +527,9 @@ if __name__ == "__main__":
|
|||
python evaluate_expressions_llm_v6.py # 使用全部数据
|
||||
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
|
||||
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
|
||||
"""
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--count",
|
||||
type=int,
|
||||
default=None,
|
||||
help="随机选取的数据条数(默认:使用全部数据)"
|
||||
)
|
||||
|
||||
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
|
||||
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(count=args.count))
|
||||
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@ from datetime import datetime
|
|||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import db
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression # noqa: E402
|
||||
from src.common.database.database import db # noqa: E402
|
||||
from src.common.logger import get_logger # noqa: E402
|
||||
|
||||
logger = get_logger("expression_evaluator_manual")
|
||||
|
||||
|
|
@ -32,13 +32,13 @@ MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
|
|||
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
"""
|
||||
加载已有的评估结果
|
||||
|
||||
|
||||
Returns:
|
||||
(已有结果列表, 已评估的项目(situation, style)元组集合)
|
||||
"""
|
||||
if not os.path.exists(MANUAL_EVAL_FILE):
|
||||
return [], set()
|
||||
|
||||
|
||||
try:
|
||||
with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
|
@ -55,22 +55,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
|||
def save_results(manual_results: List[Dict]):
|
||||
"""
|
||||
保存评估结果到文件
|
||||
|
||||
|
||||
Args:
|
||||
manual_results: 评估结果列表
|
||||
"""
|
||||
try:
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
|
||||
data = {
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"total_count": len(manual_results),
|
||||
"manual_results": manual_results
|
||||
"manual_results": manual_results,
|
||||
}
|
||||
|
||||
|
||||
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
|
||||
print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
|
||||
except Exception as e:
|
||||
|
|
@ -81,45 +81,43 @@ def save_results(manual_results: List[Dict]):
|
|||
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
    """
    Pick the next batch of expressions that still need a manual verdict.

    Args:
        evaluated_pairs: (situation, style) tuples that have already been evaluated.
        batch_size: Maximum number of expressions to hand back.

    Returns:
        All pending expressions when there are at most ``batch_size`` of them,
        otherwise a random sample of ``batch_size``. An empty list is returned
        when there are no records, nothing is pending, or any exception is
        raised while querying (the error and its traceback are logged).
    """
    try:
        # Load every record, then keep only those whose (situation, style)
        # pair has not been evaluated yet.
        records = list(Expression.select())
        if not records:
            logger.warning("数据库中没有表达方式记录")
            return []

        pending = []
        for record in records:
            if (record.situation, record.style) in evaluated_pairs:
                continue
            pending.append(record)

        if not pending:
            logger.info("所有项目都已评估完成")
            return []

        # More pending items than requested: draw a random subset.
        if len(pending) > batch_size:
            chosen = random.sample(pending, batch_size)
            logger.info(f"从 {len(pending)} 条未评估项目中随机选择了 {len(chosen)} 条")
            return chosen

        # Otherwise the remainder fits in one batch, so return it whole.
        logger.info(f"剩余 {len(pending)} 条未评估项目,全部返回")
        return pending

    except Exception as e:
        logger.error(f"获取未评估表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
|
||||
|
||||
|
|
@ -127,12 +125,12 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
|
|||
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
|
||||
"""
|
||||
人工评估单个表达方式
|
||||
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
index: 当前索引(从1开始)
|
||||
total: 总数
|
||||
|
||||
|
||||
Returns:
|
||||
评估结果字典,如果用户退出则返回 None
|
||||
"""
|
||||
|
|
@ -146,38 +144,38 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
|
|||
print(" 输入 'n' 或 'no' 或 '0' 表示不合适(不通过)")
|
||||
print(" 输入 'q' 或 'quit' 退出评估")
|
||||
print(" 输入 's' 或 'skip' 跳过当前项目")
|
||||
|
||||
|
||||
while True:
|
||||
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
|
||||
|
||||
if user_input in ['q', 'quit']:
|
||||
|
||||
if user_input in ["q", "quit"]:
|
||||
print("退出评估")
|
||||
return None
|
||||
|
||||
if user_input in ['s', 'skip']:
|
||||
|
||||
if user_input in ["s", "skip"]:
|
||||
print("跳过当前项目")
|
||||
return "skip"
|
||||
|
||||
if user_input in ['y', 'yes', '1', '是', '通过']:
|
||||
|
||||
if user_input in ["y", "yes", "1", "是", "通过"]:
|
||||
suitable = True
|
||||
break
|
||||
elif user_input in ['n', 'no', '0', '否', '不通过']:
|
||||
elif user_input in ["n", "no", "0", "否", "不通过"]:
|
||||
suitable = False
|
||||
break
|
||||
else:
|
||||
print("输入无效,请重新输入 (y/n/q/s)")
|
||||
|
||||
|
||||
result = {
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
"suitable": suitable,
|
||||
"reason": None,
|
||||
"evaluator": "manual",
|
||||
"evaluated_at": datetime.now().isoformat()
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -186,7 +184,7 @@ def main():
|
|||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式人工评估")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 初始化数据库连接
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
|
|
@ -194,41 +192,41 @@ def main():
|
|||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 加载已有评估结果
|
||||
existing_results, evaluated_pairs = load_existing_results()
|
||||
manual_results = existing_results.copy()
|
||||
|
||||
|
||||
if evaluated_pairs:
|
||||
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
|
||||
print(f"已评估项目数: {len(evaluated_pairs)}")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("开始人工评估")
|
||||
print("=" * 60)
|
||||
print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
|
||||
print("评估结果会自动保存到文件\n")
|
||||
|
||||
|
||||
batch_size = 10
|
||||
batch_count = 0
|
||||
|
||||
|
||||
while True:
|
||||
# 获取未评估的项目
|
||||
expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)
|
||||
|
||||
|
||||
if not expressions:
|
||||
print("\n" + "=" * 60)
|
||||
print("所有项目都已评估完成!")
|
||||
print("=" * 60)
|
||||
break
|
||||
|
||||
|
||||
batch_count += 1
|
||||
print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")
|
||||
|
||||
|
||||
batch_results = []
|
||||
for i, expression in enumerate(expressions, 1):
|
||||
manual_result = manual_evaluate_expression(expression, i, len(expressions))
|
||||
|
||||
|
||||
if manual_result is None:
|
||||
# 用户退出
|
||||
print("\n评估已中断")
|
||||
|
|
@ -237,34 +235,34 @@ def main():
|
|||
manual_results.extend(batch_results)
|
||||
save_results(manual_results)
|
||||
return
|
||||
|
||||
|
||||
if manual_result == "skip":
|
||||
# 跳过当前项目
|
||||
continue
|
||||
|
||||
|
||||
batch_results.append(manual_result)
|
||||
# 使用 (situation, style) 作为唯一标识
|
||||
evaluated_pairs.add((manual_result["situation"], manual_result["style"]))
|
||||
|
||||
|
||||
# 将当前批次结果添加到总结果中
|
||||
manual_results.extend(batch_results)
|
||||
|
||||
|
||||
# 保存结果
|
||||
save_results(manual_results)
|
||||
|
||||
|
||||
print(f"\n当前批次完成,已评估总数: {len(manual_results)} 条")
|
||||
|
||||
|
||||
# 询问是否继续
|
||||
while True:
|
||||
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
|
||||
if continue_input in ['y', 'yes', '1', '是', '继续']:
|
||||
if continue_input in ["y", "yes", "1", "是", "继续"]:
|
||||
break
|
||||
elif continue_input in ['n', 'no', '0', '否', '退出']:
|
||||
elif continue_input in ["n", "no", "0", "否", "退出"]:
|
||||
print("\n评估结束")
|
||||
return
|
||||
else:
|
||||
print("输入无效,请重新输入 (y/n)")
|
||||
|
||||
|
||||
# 关闭数据库连接
|
||||
try:
|
||||
db.close()
|
||||
|
|
@ -275,4 +273,3 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
|
|
@ -134,9 +134,7 @@ def handle_import_openie(
|
|||
# 在非交互模式下,不再询问用户,而是直接报错终止
|
||||
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
|
||||
)
|
||||
logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
|
||||
sys.exit(1)
|
||||
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
||||
user_choice = input().strip().lower()
|
||||
|
|
@ -189,9 +187,7 @@ def handle_import_openie(
|
|||
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
|
||||
# 新增确认提示
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
|
||||
)
|
||||
logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
|
||||
else:
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
|
|
@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
|
|||
def main(argv: Optional[list[str]] = None) -> None:
|
||||
"""主函数 - 解析参数并运行异步主流程。"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
|
||||
"将其导入到 LPMM 的向量库与知识图中。"
|
||||
)
|
||||
description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-interactive",
|
||||
|
|
|
|||
|
|
@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
|
|||
ensure_dirs() # 确保目录存在
|
||||
# 新增用户确认提示
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
|
||||
)
|
||||
logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
|
||||
else:
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
from typing import Set
|
||||
|
||||
# 保证可以导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
|
@ -32,7 +31,6 @@ def main() -> None:
|
|||
# KG 统计
|
||||
nodes = kg.graph.get_node_list()
|
||||
edges = kg.graph.get_edge_list()
|
||||
node_set: Set[str] = set(nodes)
|
||||
|
||||
para_nodes = [n for n in nodes if n.startswith("paragraph-")]
|
||||
ent_nodes = [n for n in nodes if n.startswith("entity-")]
|
||||
|
|
@ -68,4 +66,3 @@ def main() -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ except ImportError as e:
|
|||
|
||||
logger = get_logger("lpmm_interactive_manager")
|
||||
|
||||
|
||||
async def interactive_add():
|
||||
"""交互式导入知识"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -38,7 +39,7 @@ async def interactive_add():
|
|||
print(" - 支持多段落,段落间请保留空行。")
|
||||
print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
|
|
@ -48,7 +49,7 @@ async def interactive_add():
|
|||
lines.append(line)
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
|
||||
text = "\n".join(lines).strip()
|
||||
if not text:
|
||||
print("\n[!] 内容为空,操作已取消。")
|
||||
|
|
@ -58,7 +59,7 @@ async def interactive_add():
|
|||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.add_content(text)
|
||||
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
print(f" 实际新增段落数: {result.get('count', 0)}")
|
||||
|
|
@ -68,6 +69,7 @@ async def interactive_add():
|
|||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"add_content 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_delete():
|
||||
"""交互式删除知识"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -77,10 +79,10 @@ async def interactive_delete():
|
|||
print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)")
|
||||
print(" 2. 完整文段匹配(删除完全匹配的段落)")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
mode = input("请选择删除模式 (1/2): ").strip()
|
||||
exact_match = False
|
||||
|
||||
|
||||
if mode == "2":
|
||||
exact_match = True
|
||||
print("\n[完整文段匹配模式]")
|
||||
|
|
@ -102,14 +104,18 @@ async def interactive_delete():
|
|||
print("\n[!] 无效选择,默认使用关键词模糊匹配模式。")
|
||||
print("\n[关键词模糊匹配模式]")
|
||||
keyword = input("请输入匹配关键词: ").strip()
|
||||
|
||||
|
||||
if not keyword:
|
||||
print("\n[!] 输入为空,操作已取消。")
|
||||
return
|
||||
|
||||
|
||||
print("-" * 40)
|
||||
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower()
|
||||
if confirm != 'y':
|
||||
confirm = (
|
||||
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
if confirm != "y":
|
||||
print("\n[!] 已取消删除操作。")
|
||||
return
|
||||
|
||||
|
|
@ -117,7 +123,7 @@ async def interactive_delete():
|
|||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.delete(keyword, exact_match=exact_match)
|
||||
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
print(f" 删除条数: {result.get('deleted_count', 0)}")
|
||||
|
|
@ -129,6 +135,7 @@ async def interactive_delete():
|
|||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"delete 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_clear():
|
||||
"""交互式清空知识库"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -141,40 +148,45 @@ async def interactive_clear():
|
|||
print(" - 整个知识图谱")
|
||||
print(" - 此操作不可恢复!")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
# 双重确认
|
||||
confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip()
|
||||
if confirm1 != "YES":
|
||||
print("\n[!] 已取消清空操作。")
|
||||
return
|
||||
|
||||
|
||||
print("\n" + "=" * 40)
|
||||
confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip()
|
||||
if confirm2 != "CLEAR":
|
||||
print("\n[!] 已取消清空操作。")
|
||||
return
|
||||
|
||||
|
||||
print("\n[进度] 正在清空知识库...")
|
||||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.clear_all()
|
||||
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
stats = result.get("stats", {})
|
||||
before = stats.get("before", {})
|
||||
after = stats.get("after", {})
|
||||
print("\n[统计信息]")
|
||||
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}")
|
||||
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}")
|
||||
print(
|
||||
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
|
||||
)
|
||||
print(
|
||||
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
|
||||
)
|
||||
else:
|
||||
print(f"\n[×] 失败:{result['message']}")
|
||||
except Exception as e:
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"clear_all 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_search():
|
||||
"""交互式查询知识"""
|
||||
print("\n" + "=" * 40)
|
||||
|
|
@ -182,25 +194,25 @@ async def interactive_search():
|
|||
print("=" * 40)
|
||||
print("说明:输入查询问题或关键词,系统会返回相关的知识段落。")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
# 确保 LPMM 已初始化
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
print("\n[!] 警告:LPMM 知识库在配置中未启用。")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
lpmm_start_up()
|
||||
except Exception as e:
|
||||
print(f"\n[!] LPMM 初始化失败: {e}")
|
||||
logger.error(f"LPMM 初始化失败: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
|
||||
query = input("请输入查询问题或关键词: ").strip()
|
||||
|
||||
|
||||
if not query:
|
||||
print("\n[!] 查询内容为空,操作已取消。")
|
||||
return
|
||||
|
||||
|
||||
# 询问返回条数
|
||||
print("-" * 40)
|
||||
limit_str = input("希望返回的相关知识条数(默认3,直接回车使用默认值): ").strip()
|
||||
|
|
@ -210,11 +222,11 @@ async def interactive_search():
|
|||
except ValueError:
|
||||
limit = 3
|
||||
print("[!] 输入无效,使用默认值 3。")
|
||||
|
||||
|
||||
print("\n[进度] 正在查询知识库...")
|
||||
try:
|
||||
result = await query_lpmm_knowledge(query, limit=limit)
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("[查询结果]")
|
||||
print("=" * 60)
|
||||
|
|
@ -224,6 +236,7 @@ async def interactive_search():
|
|||
print(f"\n[×] 查询失败: {e}")
|
||||
logger.error(f"查询异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主循环"""
|
||||
while True:
|
||||
|
|
@ -236,9 +249,9 @@ async def main():
|
|||
print("║ 4. 清空知识库 (Clear All) ⚠️ ║")
|
||||
print("║ 0. 退出 (Exit) ║")
|
||||
print("╚" + "═" * 38 + "╝")
|
||||
|
||||
|
||||
choice = input("请选择操作编号: ").strip()
|
||||
|
||||
|
||||
if choice == "1":
|
||||
await interactive_add()
|
||||
elif choice == "2":
|
||||
|
|
@ -253,6 +266,7 @@ async def main():
|
|||
else:
|
||||
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 运行主循环
|
||||
|
|
@ -262,4 +276,3 @@ if __name__ == "__main__":
|
|||
except Exception as e:
|
||||
print(f"\n[!] 程序运行出错: {e}")
|
||||
logger.error(f"Main loop 异常: {e}", exc_info=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,18 +21,18 @@ PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
|
|||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from src.common.logger import get_logger # type: ignore
|
||||
from src.config.config import global_config, model_config # type: ignore
|
||||
from src.common.logger import get_logger # type: ignore # noqa: E402
|
||||
from src.config.config import global_config, model_config # type: ignore # noqa: E402
|
||||
|
||||
# 引入各功能脚本的入口函数
|
||||
from import_openie import main as import_openie_main # type: ignore
|
||||
from info_extraction import main as info_extraction_main # type: ignore
|
||||
from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore
|
||||
from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore
|
||||
from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore
|
||||
from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore
|
||||
from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore
|
||||
from raw_data_preprocessor import load_raw_data # type: ignore
|
||||
from import_openie import main as import_openie_main # type: ignore # noqa: E402
|
||||
from info_extraction import main as info_extraction_main # type: ignore # noqa: E402
|
||||
from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore # noqa: E402
|
||||
from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore # noqa: E402
|
||||
from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore # noqa: E402
|
||||
from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore # noqa: E402
|
||||
from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore # noqa: E402
|
||||
from raw_data_preprocessor import load_raw_data # type: ignore # noqa: E402
|
||||
|
||||
|
||||
logger = get_logger("lpmm_manager")
|
||||
|
|
@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
|
|||
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
|
||||
txt_files = list(raw_dir.glob("*.txt"))
|
||||
if not txt_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
|
||||
"info_extraction 可能立即退出或无数据可处理。"
|
||||
)
|
||||
msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,info_extraction 可能立即退出或无数据可处理。"
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
|
||||
)
|
||||
logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
|
||||
return False
|
||||
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
|
|
@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
|
|||
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
|
||||
json_files = list(openie_dir.glob("*.json"))
|
||||
if not json_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
|
||||
"import_openie 可能会因为找不到批次而失败。"
|
||||
)
|
||||
msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,import_openie 可能会因为找不到批次而失败。"
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
|
||||
)
|
||||
logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
|
||||
return False
|
||||
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
|
|
@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
|
|||
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
|
||||
try:
|
||||
if not getattr(global_config.lpmm_knowledge, "enable", False):
|
||||
print(
|
||||
"[WARN] 当前配置 lpmm_knowledge.enable = false,"
|
||||
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
|
||||
)
|
||||
print("[WARN] 当前配置 lpmm_knowledge.enable = false,刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
|
||||
except Exception:
|
||||
# 配置异常时不阻断主流程,仅忽略提示
|
||||
pass
|
||||
|
|
@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
|||
if action == "prepare_raw":
|
||||
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||
elif action == "info_extract":
|
||||
if not _check_before_info_extract("--non-interactive" in extra_args):
|
||||
print("已根据用户选择,取消执行信息提取。")
|
||||
|
|
@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
|||
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
|
||||
logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||
non_interactive = "--non-interactive" in extra_args
|
||||
if not _check_before_info_extract(non_interactive):
|
||||
print("已根据用户选择,取消 full_import(信息提取阶段被取消)。")
|
||||
|
|
@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
|
|||
)
|
||||
|
||||
# 快速选项:按推荐方式清理所有相关实体/关系
|
||||
quick_all = input(
|
||||
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
|
||||
).strip().lower()
|
||||
quick_all = (
|
||||
input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
|
||||
)
|
||||
if quick_all in ("", "y", "yes"):
|
||||
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
|
||||
else:
|
||||
|
|
@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
|
|||
|
||||
def _interactive_build_batch_inspect_args() -> List[str]:
    """Build the ``--openie-file`` argument list for inspect_lpmm_batch.

    Passes a prompt string to ``_interactive_choose_openie_file`` and wraps
    its result.

    Returns:
        ``["--openie-file", path]`` when the helper returns a truthy path,
        otherwise an empty list.
    """
    chosen_path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
    return ["--openie-file", chosen_path] if chosen_path else []
|
||||
|
|
@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
|
|||
|
||||
def _interactive_build_test_args() -> List[str]:
|
||||
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
|
||||
print(
|
||||
"\n[TEST] 你可以:\n"
|
||||
"- 直接回车使用内置的默认测试用例;\n"
|
||||
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
|
||||
)
|
||||
print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
|
||||
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
|
||||
if not query:
|
||||
return []
|
||||
|
|
@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
|
|||
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
|
||||
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
|
||||
|
||||
new_dim = input(
|
||||
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
|
||||
).strip()
|
||||
new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
|
||||
if new_dim and not new_dim.isdigit():
|
||||
print("输入的维度不是纯数字,已取消操作。")
|
||||
return
|
||||
|
|
@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,53 +28,55 @@ from maim_message import UserInfo, GroupInfo
|
|||
|
||||
logger = get_logger("test_memory_retrieval")
|
||||
|
||||
|
||||
# 使用 importlib 动态导入,避免循环导入问题
|
||||
def _import_memory_retrieval():
|
||||
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
||||
try:
|
||||
# 先导入 prompt_builder,检查 prompt 是否已经初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
|
||||
# 检查 memory_retrieval 相关的 prompt 是否已经注册
|
||||
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
|
||||
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
|
||||
|
||||
|
||||
module_name = "src.memory_system.memory_retrieval"
|
||||
|
||||
|
||||
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
||||
if prompt_already_init and module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
||||
if hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
return (
|
||||
existing_module.init_memory_retrieval_prompt,
|
||||
existing_module._react_agent_solve_question,
|
||||
existing_module._process_single_question,
|
||||
)
|
||||
|
||||
|
||||
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
|
||||
if module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
||||
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
# 模块部分初始化,移除它
|
||||
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
|
||||
del sys.modules[module_name]
|
||||
# 清理可能相关的部分初始化模块
|
||||
keys_to_remove = []
|
||||
for key in sys.modules.keys():
|
||||
if key.startswith('src.memory_system.') and key != 'src.memory_system':
|
||||
if key.startswith("src.memory_system.") and key != "src.memory_system":
|
||||
keys_to_remove.append(key)
|
||||
for key in keys_to_remove:
|
||||
try:
|
||||
del sys.modules[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
|
||||
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
|
||||
try:
|
||||
# 先导入可能触发循环导入的模块,让它们完成初始化
|
||||
import src.config.config
|
||||
import src.chat.utils.prompt_builder
|
||||
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
# 如果它们已经导入,就确保它们完全初始化
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
|
|
@ -89,11 +91,11 @@ def _import_memory_retrieval():
|
|||
pass # 如果导入失败,继续
|
||||
except Exception as e:
|
||||
logger.warning(f"预加载依赖模块时出现警告: {e}")
|
||||
|
||||
|
||||
# 现在尝试导入 memory_retrieval
|
||||
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
|
||||
memory_retrieval_module = importlib.import_module(module_name)
|
||||
|
||||
|
||||
return (
|
||||
memory_retrieval_module.init_memory_retrieval_prompt,
|
||||
memory_retrieval_module._react_agent_solve_question,
|
||||
|
|
@ -126,16 +128,16 @@ def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStrea
|
|||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
||||
"""获取从指定时间开始的token使用情况
|
||||
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
|
||||
|
||||
Returns:
|
||||
包含token使用统计的字典
|
||||
"""
|
||||
try:
|
||||
start_datetime = datetime.fromtimestamp(start_time)
|
||||
|
||||
|
||||
# 查询从开始时间到现在的所有memory相关的token使用记录
|
||||
records = (
|
||||
LLMUsage.select()
|
||||
|
|
@ -150,21 +152,21 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
|||
)
|
||||
.order_by(LLMUsage.timestamp.asc())
|
||||
)
|
||||
|
||||
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
total_cost = 0.0
|
||||
request_count = 0
|
||||
model_usage = {} # 按模型统计
|
||||
|
||||
|
||||
for record in records:
|
||||
total_prompt_tokens += record.prompt_tokens or 0
|
||||
total_completion_tokens += record.completion_tokens or 0
|
||||
total_tokens += record.total_tokens or 0
|
||||
total_cost += record.cost or 0.0
|
||||
request_count += 1
|
||||
|
||||
|
||||
# 按模型统计
|
||||
model_name = record.model_name or "unknown"
|
||||
if model_name not in model_usage:
|
||||
|
|
@ -180,7 +182,7 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
|||
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
|
||||
model_usage[model_name]["cost"] += record.cost or 0.0
|
||||
model_usage[model_name]["request_count"] += 1
|
||||
|
||||
|
||||
return {
|
||||
"total_prompt_tokens": total_prompt_tokens,
|
||||
"total_completion_tokens": total_completion_tokens,
|
||||
|
|
@ -205,25 +207,25 @@ def format_thinking_steps(thinking_steps: list) -> str:
|
|||
"""格式化思考步骤为可读字符串"""
|
||||
if not thinking_steps:
|
||||
return "无思考步骤"
|
||||
|
||||
|
||||
lines = []
|
||||
for step in thinking_steps:
|
||||
iteration = step.get("iteration", "?")
|
||||
thought = step.get("thought", "")
|
||||
actions = step.get("actions", [])
|
||||
observations = step.get("observations", [])
|
||||
|
||||
|
||||
lines.append(f"\n--- 迭代 {iteration} ---")
|
||||
if thought:
|
||||
lines.append(f"思考: {thought[:200]}...")
|
||||
|
||||
|
||||
if actions:
|
||||
lines.append("行动:")
|
||||
for action in actions:
|
||||
action_type = action.get("action_type", "unknown")
|
||||
action_params = action.get("action_params", {})
|
||||
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
|
||||
|
||||
|
||||
if observations:
|
||||
lines.append("观察:")
|
||||
for obs in observations:
|
||||
|
|
@ -231,7 +233,7 @@ def format_thinking_steps(thinking_steps: list) -> str:
|
|||
if len(str(obs)) > 200:
|
||||
obs_str += "..."
|
||||
lines.append(f" - {obs_str}")
|
||||
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
|
|
@ -242,31 +244,32 @@ async def test_memory_retrieval(
|
|||
max_iterations: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""测试记忆检索功能
|
||||
|
||||
|
||||
Args:
|
||||
question: 要查询的问题
|
||||
chat_id: 聊天ID
|
||||
context: 上下文信息
|
||||
max_iterations: 最大迭代次数
|
||||
|
||||
|
||||
Returns:
|
||||
包含测试结果的字典
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"[测试] 记忆检索测试")
|
||||
print("[测试] 记忆检索测试")
|
||||
print(f"[问题] {question}")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 延迟导入并初始化记忆检索prompt(这会自动加载 global_config)
|
||||
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
|
||||
try:
|
||||
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
|
||||
|
||||
|
||||
# 检查 prompt 是否已经初始化,避免重复初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
|
||||
init_memory_retrieval_prompt()
|
||||
else:
|
||||
|
|
@ -274,24 +277,24 @@ async def test_memory_retrieval(
|
|||
except Exception as e:
|
||||
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# 获取 global_config(此时应该已经加载)
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.memory.max_agent_iterations
|
||||
|
||||
|
||||
timeout = global_config.memory.agent_timeout_seconds
|
||||
|
||||
print(f"\n[配置]")
|
||||
|
||||
print("\n[配置]")
|
||||
print(f" 最大迭代次数: {max_iterations}")
|
||||
print(f" 超时时间: {timeout}秒")
|
||||
print(f" 聊天ID: {chat_id}")
|
||||
|
||||
|
||||
# 执行检索
|
||||
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||
|
||||
|
||||
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
||||
question=question,
|
||||
chat_id=chat_id,
|
||||
|
|
@ -299,14 +302,14 @@ async def test_memory_retrieval(
|
|||
timeout=timeout,
|
||||
initial_info="",
|
||||
)
|
||||
|
||||
|
||||
# 记录结束时间
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
|
||||
# 获取token使用情况
|
||||
token_usage = get_token_usage_since(start_time)
|
||||
|
||||
|
||||
# 构建结果
|
||||
result = {
|
||||
"question": question,
|
||||
|
|
@ -318,41 +321,41 @@ async def test_memory_retrieval(
|
|||
"iteration_count": len(thinking_steps),
|
||||
"token_usage": token_usage,
|
||||
}
|
||||
|
||||
|
||||
# 输出结果
|
||||
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||
print(f"\n[结果]")
|
||||
print("\n[结果]")
|
||||
print(f" 是否找到答案: {'是' if found_answer else '否'}")
|
||||
if found_answer and answer:
|
||||
print(f" 答案: {answer}")
|
||||
else:
|
||||
print(f" 答案: (未找到答案)")
|
||||
print(" 答案: (未找到答案)")
|
||||
print(f" 是否超时: {'是' if is_timeout else '否'}")
|
||||
print(f" 迭代次数: {len(thinking_steps)}")
|
||||
print(f" 总耗时: {elapsed_time:.2f}秒")
|
||||
|
||||
print(f"\n[Token使用情况]")
|
||||
|
||||
print("\n[Token使用情况]")
|
||||
print(f" 总请求数: {token_usage['request_count']}")
|
||||
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
|
||||
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
|
||||
print(f" 总Tokens: {token_usage['total_tokens']:,}")
|
||||
print(f" 总成本: ${token_usage['total_cost']:.6f}")
|
||||
|
||||
if token_usage['model_usage']:
|
||||
print(f"\n[按模型统计]")
|
||||
for model_name, usage in token_usage['model_usage'].items():
|
||||
|
||||
if token_usage["model_usage"]:
|
||||
print("\n[按模型统计]")
|
||||
for model_name, usage in token_usage["model_usage"].items():
|
||||
print(f" {model_name}:")
|
||||
print(f" 请求数: {usage['request_count']}")
|
||||
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
|
||||
print(f" Completion Tokens: {usage['completion_tokens']:,}")
|
||||
print(f" 总Tokens: {usage['total_tokens']:,}")
|
||||
print(f" 成本: ${usage['cost']:.6f}")
|
||||
|
||||
print(f"\n[迭代详情]")
|
||||
|
||||
print("\n[迭代详情]")
|
||||
print(format_thinking_steps(thinking_steps))
|
||||
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -375,12 +378,12 @@ def main() -> None:
|
|||
"-o",
|
||||
help="将结果保存到JSON文件(可选)",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# 初始化日志(使用较低的详细程度,避免输出过多日志)
|
||||
initialize_logging(verbose=False)
|
||||
|
||||
|
||||
# 交互式输入问题
|
||||
print("\n" + "=" * 80)
|
||||
print("记忆检索测试工具")
|
||||
|
|
@ -389,7 +392,7 @@ def main() -> None:
|
|||
if not question:
|
||||
print("错误: 问题不能为空")
|
||||
return
|
||||
|
||||
|
||||
# 交互式输入最大迭代次数
|
||||
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
|
||||
max_iterations = None
|
||||
|
|
@ -402,7 +405,7 @@ def main() -> None:
|
|||
except ValueError:
|
||||
print("警告: 无效的迭代次数,将使用配置默认值")
|
||||
max_iterations = None
|
||||
|
||||
|
||||
# 连接数据库
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
|
|
@ -410,7 +413,7 @@ def main() -> None:
|
|||
logger.error(f"数据库连接失败: {e}")
|
||||
print(f"错误: 数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 运行测试
|
||||
try:
|
||||
result = asyncio.run(
|
||||
|
|
@ -421,7 +424,7 @@ def main() -> None:
|
|||
max_iterations=max_iterations,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# 如果指定了输出文件,保存结果
|
||||
if args.output:
|
||||
# 将thinking_steps转换为可序列化的格式
|
||||
|
|
@ -429,7 +432,7 @@ def main() -> None:
|
|||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(output_result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n[结果已保存] {args.output}")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n[中断] 用户中断测试")
|
||||
except Exception as e:
|
||||
|
|
@ -444,4 +447,3 @@ def main() -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
|
|
@ -455,6 +455,7 @@ class ExpressionSelector:
|
|||
expr_obj.save()
|
||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
|
|||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
class JargonExplainer:
|
||||
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""
|
||||
|
||||
|
|
|
|||
|
|
@ -60,31 +60,31 @@ def calculate_style_similarity(style1: str, style2: str) -> float:
|
|||
"""
|
||||
计算两个 style 的相似度,返回0-1之间的值
|
||||
在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py)
|
||||
|
||||
|
||||
Args:
|
||||
style1: 第一个 style
|
||||
style2: 第二个 style
|
||||
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
if not style1 or not style2:
|
||||
return 0.0
|
||||
|
||||
|
||||
# 移除"使用"和"句式"这两个词
|
||||
def remove_ignored_words(text: str) -> str:
|
||||
"""移除需要忽略的词"""
|
||||
text = text.replace("使用", "")
|
||||
text = text.replace("句式", "")
|
||||
return text.strip()
|
||||
|
||||
|
||||
cleaned_style1 = remove_ignored_words(style1)
|
||||
cleaned_style2 = remove_ignored_words(style2)
|
||||
|
||||
|
||||
# 如果清理后文本为空,返回0
|
||||
if not cleaned_style1 or not cleaned_style2:
|
||||
return 0.0
|
||||
|
||||
|
||||
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
|
||||
|
||||
|
||||
|
|
@ -495,4 +495,4 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
|
|||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
|
||||
return expressions, jargon_entries
|
||||
return expressions, jargon_entries
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import time
|
|||
import asyncio
|
||||
from typing import List, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
|
@ -119,9 +118,7 @@ class MessageRecorder:
|
|||
|
||||
# 触发 expression_learner 和 jargon_miner 的处理
|
||||
if self.enable_expression_learning:
|
||||
asyncio.create_task(
|
||||
self._trigger_expression_learning(messages)
|
||||
)
|
||||
asyncio.create_task(self._trigger_expression_learning(messages))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||
|
|
@ -130,9 +127,7 @@ class MessageRecorder:
|
|||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def _trigger_expression_learning(
|
||||
self, messages: List[Any]
|
||||
) -> None:
|
||||
async def _trigger_expression_learning(self, messages: List[Any]) -> None:
|
||||
"""
|
||||
触发 expression 学习,使用指定的消息列表
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import time
|
||||
from typing import Tuple, Optional, Dict, Any # 增加了 Optional
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
|
|
@ -120,7 +120,7 @@ class ActionPlanner:
|
|||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
|
|
@ -128,7 +128,7 @@ class ActionPlanner:
|
|||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
|
|
@ -170,13 +170,10 @@ class ActionPlanner:
|
|||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
|
||||
)
|
||||
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
|
|
|
|||
|
|
@ -112,10 +112,10 @@ class Conversation:
|
|||
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
|
||||
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
|
||||
"platform": msg.user_info.platform if msg.user_info else "",
|
||||
}
|
||||
},
|
||||
}
|
||||
initial_messages_dict.append(msg_dict)
|
||||
|
||||
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages_dict
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
|
|
|
|||
|
|
@ -66,9 +66,9 @@ class DirectMessageSender:
|
|||
|
||||
# 发送消息(直接调用底层 API)
|
||||
from src.chat.message_receive.uni_message_sender import _send_message
|
||||
|
||||
|
||||
sent = await _send_message(message, show_log=True)
|
||||
|
||||
|
||||
if sent:
|
||||
# 存储消息
|
||||
await self.storage.store_message(message, chat_stream)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from src.common.logger import get_logger
|
|||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_logger("observation_info")
|
||||
|
|
@ -13,15 +13,15 @@ logger = get_logger("observation_info")
|
|||
|
||||
def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages:
|
||||
"""Convert PFC dict format to DatabaseMessages object
|
||||
|
||||
|
||||
Args:
|
||||
msg_dict: Message in PFC dict format with nested user_info
|
||||
|
||||
|
||||
Returns:
|
||||
DatabaseMessages object compatible with build_readable_messages()
|
||||
"""
|
||||
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
|
||||
|
||||
|
||||
return DatabaseMessages(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
time=msg_dict.get("time", 0.0),
|
||||
|
|
|
|||
|
|
@ -42,9 +42,7 @@ class GoalAnalyzer:
|
|||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
|
||||
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
|
|
@ -60,7 +58,7 @@ class GoalAnalyzer:
|
|||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
|
|
@ -68,7 +66,7 @@ class GoalAnalyzer:
|
|||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
from typing import List, Tuple, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import Message
|
||||
from src.config.config import model_config
|
||||
from src.chat.knowledge import qa_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
|
||||
|
||||
logger = get_logger("knowledge_fetcher")
|
||||
|
||||
|
|
@ -16,9 +14,7 @@ class KnowledgeFetcher:
|
|||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
|
||||
self.private_name = private_name
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
|
|
@ -50,13 +46,7 @@ class KnowledgeFetcher:
|
|||
Returns:
|
||||
Tuple[str, str]: (获取的知识, 知识来源)
|
||||
"""
|
||||
db_messages = [dict_to_database_message(m) for m in chat_history]
|
||||
chat_history_text = build_readable_messages(
|
||||
db_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
_ = chat_history
|
||||
|
||||
# NOTE: Hippocampus memory system was redesigned in v0.12.2
|
||||
# The old get_memory_from_text API no longer exists
|
||||
|
|
@ -64,7 +54,7 @@ class KnowledgeFetcher:
|
|||
# TODO: Integrate with new memory system if needed
|
||||
knowledge_text = ""
|
||||
sources_text = "无记忆匹配" # 默认值
|
||||
|
||||
|
||||
# # 从记忆中获取相关知识 (DISABLED - old Hippocampus API)
|
||||
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
# text=f"{query}\n{chat_history_text}",
|
||||
|
|
|
|||
|
|
@ -14,10 +14,7 @@ class ReplyChecker:
|
|||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="reply_check"
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
|
|
@ -27,7 +24,7 @@ class ReplyChecker:
|
|||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
|
|
@ -35,7 +32,7 @@ class ReplyChecker:
|
|||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class ReplyGenerator:
|
|||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
|
|
@ -107,7 +107,7 @@ class ReplyGenerator:
|
|||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
|
|
|
|||
|
|
@ -704,10 +704,7 @@ class BrainChatting:
|
|||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._new_message_event.wait(),
|
||||
timeout=wait_seconds
|
||||
)
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -731,7 +728,9 @@ class BrainChatting:
|
|||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)"
|
||||
)
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
|
@ -749,10 +748,7 @@ class BrainChatting:
|
|||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._new_message_event.wait(),
|
||||
timeout=wait_seconds
|
||||
)
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
|
|
|
|||
|
|
@ -431,15 +431,21 @@ class BrainPlanner:
|
|||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
return extracted_reasoning, [
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
], llm_content, llm_reasoning, llm_duration_ms
|
||||
return (
|
||||
extracted_reasoning,
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class EmbeddingStore:
|
|||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
|
||||
|
||||
|
||||
self.dirty = False # 标记是否有新增数据需要重建索引
|
||||
|
||||
# 多线程配置参数验证和设置
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union, Dict, Any
|
||||
from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
|
|
@ -192,17 +192,15 @@ class IEProcess:
|
|||
|
||||
results = []
|
||||
total = len(paragraphs)
|
||||
|
||||
|
||||
for i, pg in enumerate(paragraphs, start=1):
|
||||
# 打印进度日志,让用户知道没有卡死
|
||||
logger.info(f"[IEProcess] 正在处理第 {i}/{total} 段文本 (长度: {len(pg)})...")
|
||||
|
||||
|
||||
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
|
||||
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
|
||||
try:
|
||||
entities, triples = await asyncio.to_thread(
|
||||
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
|
||||
)
|
||||
entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
|
||||
|
||||
if entities is not None:
|
||||
results.append(
|
||||
|
|
|
|||
|
|
@ -395,8 +395,7 @@ class KGManager:
|
|||
appear_cnt = self.ent_appear_cnt.get(ent_hash)
|
||||
if not appear_cnt or appear_cnt <= 0:
|
||||
logger.debug(
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,"
|
||||
f"将使用 1.0 作为默认出现次数参与权重计算"
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,将使用 1.0 作为默认出现次数参与权重计算"
|
||||
)
|
||||
appear_cnt = 1.0
|
||||
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)
|
||||
|
|
|
|||
|
|
@ -11,31 +11,30 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
|
|||
|
||||
logger = get_logger("LPMM-Plugin-API")
|
||||
|
||||
|
||||
class LPMMOperations:
|
||||
"""
|
||||
LPMM 内部操作接口。
|
||||
封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
async def _run_cancellable_executor(
|
||||
self, func: Callable, *args, **kwargs
|
||||
) -> Any:
|
||||
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
在线程池中执行可取消的同步操作。
|
||||
当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
|
||||
注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。
|
||||
|
||||
|
||||
Args:
|
||||
func: 要执行的同步函数
|
||||
*args: 函数的位置参数
|
||||
**kwargs: 函数的关键字参数
|
||||
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: 当任务被取消时
|
||||
"""
|
||||
|
|
@ -51,42 +50,42 @@ class LPMMOperations:
|
|||
# 如果全局没初始化,尝试初始化
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
|
||||
|
||||
|
||||
lpmm_start_up()
|
||||
qa_mgr = get_qa_manager()
|
||||
|
||||
|
||||
if qa_mgr is None:
|
||||
raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。")
|
||||
|
||||
|
||||
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
|
||||
|
||||
async def add_content(self, text: str, auto_split: bool = True) -> dict:
|
||||
"""
|
||||
向知识库添加新内容。
|
||||
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
auto_split: 是否自动按双换行符分割段落。
|
||||
- True: 自动分割(默认),支持多段文本(用双换行分隔)
|
||||
- False: 不分割,将整个文本作为完整一段处理
|
||||
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
|
||||
# 1. 分段处理
|
||||
if auto_split:
|
||||
# 自动按双换行符分割
|
||||
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
else:
|
||||
# 不分割,作为完整一段
|
||||
text_stripped = text.strip()
|
||||
if not text_stripped:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
paragraphs = [text_stripped]
|
||||
|
||||
|
||||
if not paragraphs:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
|
||||
|
|
@ -94,14 +93,16 @@ class LPMMOperations:
|
|||
from src.chat.knowledge.ie_process import IEProcess
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
|
||||
|
||||
llm_ner = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
)
|
||||
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||
|
||||
|
||||
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
|
||||
extracted_docs = await ie_process.process_paragraphs(paragraphs)
|
||||
|
||||
|
||||
# 3. 构造并导入数据
|
||||
# 这里我们手动实现导入逻辑,不依赖外部脚本
|
||||
# a. 准备段落
|
||||
|
|
@ -115,7 +116,7 @@ class LPMMOperations:
|
|||
# store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本
|
||||
new_raw_paragraphs = {}
|
||||
new_triple_list_data = {}
|
||||
|
||||
|
||||
for pg_hash, passage in raw_paragraphs.items():
|
||||
key = f"paragraph-{pg_hash}"
|
||||
if key not in embed_mgr.stored_pg_hashes:
|
||||
|
|
@ -128,26 +129,22 @@ class LPMMOperations:
|
|||
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
|
||||
# store_new_data_set 会自动处理嵌入生成和存储
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
await self._run_cancellable_executor(
|
||||
embed_mgr.store_new_data_set,
|
||||
new_raw_paragraphs,
|
||||
new_triple_list_data
|
||||
)
|
||||
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
|
||||
|
||||
# 3. 构建知识图谱(只需要三元组数据和embedding_manager)
|
||||
await self._run_cancellable_executor(
|
||||
kg_mgr.build_kg,
|
||||
new_triple_list_data,
|
||||
embed_mgr
|
||||
)
|
||||
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
|
||||
|
||||
# 4. 持久化
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(new_raw_paragraphs),
|
||||
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 导入操作被用户中断")
|
||||
return {"status": "cancelled", "message": "导入操作已被用户中断"}
|
||||
|
|
@ -158,11 +155,11 @@ class LPMMOperations:
|
|||
async def search(self, query: str, top_k: int = 3) -> List[str]:
|
||||
"""
|
||||
检索知识库。
|
||||
|
||||
|
||||
Args:
|
||||
query: 查询问题。
|
||||
top_k: 返回最相关的条目数。
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 相关文段列表。
|
||||
"""
|
||||
|
|
@ -179,21 +176,21 @@ class LPMMOperations:
|
|||
async def delete(self, keyword: str, exact_match: bool = False) -> dict:
|
||||
"""
|
||||
根据关键词或完整文段删除知识库内容。
|
||||
|
||||
|
||||
Args:
|
||||
keyword: 匹配关键词或完整文段。
|
||||
exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。
|
||||
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
|
||||
# 1. 查找匹配的段落
|
||||
to_delete_keys = []
|
||||
to_delete_hashes = []
|
||||
|
||||
|
||||
for key, item in embed_mgr.paragraphs_embedding_store.store.items():
|
||||
if exact_match:
|
||||
# 完整文段匹配
|
||||
|
|
@ -205,29 +202,25 @@ class LPMMOperations:
|
|||
if keyword in item.str:
|
||||
to_delete_keys.append(key)
|
||||
to_delete_hashes.append(key.replace("paragraph-", "", 1))
|
||||
|
||||
|
||||
if not to_delete_keys:
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
|
||||
|
||||
# 2. 执行删除
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
|
||||
# a. 从向量库删除
|
||||
deleted_count, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
||||
to_delete_keys
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
|
||||
|
||||
# b. 从知识图谱删除
|
||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs,
|
||||
to_delete_hashes,
|
||||
ent_hashes=None,
|
||||
remove_orphan_entities=True
|
||||
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
|
|
@ -235,9 +228,13 @@ class LPMMOperations:
|
|||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
|
||||
return {
|
||||
"status": "success",
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 删除操作被用户中断")
|
||||
|
|
@ -249,13 +246,13 @@ class LPMMOperations:
|
|||
async def clear_all(self) -> dict:
|
||||
"""
|
||||
清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。
|
||||
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "message": "描述", "stats": {...}}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
|
||||
# 记录清空前的统计信息
|
||||
before_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
|
|
@ -264,40 +261,37 @@ class LPMMOperations:
|
|||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
|
||||
# 1. 清空所有向量库
|
||||
# 获取所有keys
|
||||
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
|
||||
rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
|
||||
|
||||
|
||||
# 删除所有段落向量
|
||||
para_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
||||
para_keys
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes.clear()
|
||||
|
||||
|
||||
# 删除所有实体向量
|
||||
if ent_keys:
|
||||
ent_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.entities_embedding_store.delete_items,
|
||||
ent_keys
|
||||
embed_mgr.entities_embedding_store.delete_items, ent_keys
|
||||
)
|
||||
else:
|
||||
ent_deleted = 0
|
||||
|
||||
|
||||
# 删除所有关系向量
|
||||
if rel_keys:
|
||||
rel_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.relation_embedding_store.delete_items,
|
||||
rel_keys
|
||||
embed_mgr.relation_embedding_store.delete_items, rel_keys
|
||||
)
|
||||
else:
|
||||
rel_deleted = 0
|
||||
|
||||
|
||||
# 2. 清空所有 embedding store 的索引和映射
|
||||
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
|
||||
def _clear_embedding_indices():
|
||||
|
|
@ -310,7 +304,7 @@ class LPMMOperations:
|
|||
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
|
||||
|
||||
|
||||
# 清空实体索引
|
||||
embed_mgr.entities_embedding_store.faiss_index = None
|
||||
embed_mgr.entities_embedding_store.idx2hash = None
|
||||
|
|
@ -320,7 +314,7 @@ class LPMMOperations:
|
|||
os.remove(embed_mgr.entities_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
|
||||
|
||||
|
||||
# 清空关系索引
|
||||
embed_mgr.relation_embedding_store.faiss_index = None
|
||||
embed_mgr.relation_embedding_store.idx2hash = None
|
||||
|
|
@ -330,9 +324,9 @@ class LPMMOperations:
|
|||
os.remove(embed_mgr.relation_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
|
||||
|
||||
|
||||
await self._run_cancellable_executor(_clear_embedding_indices)
|
||||
|
||||
|
||||
# 3. 清空知识图谱
|
||||
# 获取所有段落hash
|
||||
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
|
||||
|
|
@ -341,24 +335,22 @@ class LPMMOperations:
|
|||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs,
|
||||
all_pg_hashes,
|
||||
ent_hashes=None,
|
||||
remove_orphan_entities=True
|
||||
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
|
||||
# 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
|
||||
from quick_algo import di_graph
|
||||
|
||||
kg_mgr.graph = di_graph.DiGraph()
|
||||
kg_mgr.stored_paragraph_hashes.clear()
|
||||
kg_mgr.ent_appear_cnt.clear()
|
||||
|
||||
|
||||
# 4. 保存所有数据(此时所有store都是空的,索引也是None)
|
||||
# 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
|
||||
after_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
"entities": len(embed_mgr.entities_embedding_store.store),
|
||||
|
|
@ -366,14 +358,14 @@ class LPMMOperations:
|
|||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
|
||||
"stats": {
|
||||
"before": before_stats,
|
||||
"after": after_stats,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
|
|
@ -383,6 +375,6 @@ class LPMMOperations:
|
|||
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# 内部使用的单例
|
||||
lpmm_ops = LPMMOperations()
|
||||
|
||||
|
|
|
|||
|
|
@ -136,4 +136,3 @@ class PlanReplyLogger:
|
|||
return str(value)
|
||||
# Fallback to string for other complex types
|
||||
return str(value)
|
||||
|
||||
|
|
|
|||
|
|
@ -85,17 +85,17 @@ class ChatBot:
|
|||
|
||||
async def _create_pfc_chat(self, message: MessageRecv):
|
||||
"""创建或获取PFC对话实例
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
chat_id = str(message.chat_stream.stream_id)
|
||||
private_name = str(message.message_info.user_info.user_nickname)
|
||||
|
||||
|
||||
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
|
||||
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class Message(MessageBase):
|
|||
if processed_text:
|
||||
return f"{global_config.bot.nickname}: {processed_text}"
|
||||
return None
|
||||
|
||||
|
||||
tasks = [process_forward_node(node_dict) for node_dict in segment.data]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
segments_text = []
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
|||
|
||||
# 如果未开启 API Server,直接跳过 Fallback
|
||||
if not global_config.maim_message.enable_api_server:
|
||||
logger.debug(f"[API Server Fallback] API Server未开启,跳过fallback")
|
||||
logger.debug("[API Server Fallback] API Server未开启,跳过fallback")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
|
@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
|||
extra_server = getattr(global_api, "extra_server", None)
|
||||
|
||||
if not extra_server:
|
||||
logger.warning(f"[API Server Fallback] extra_server不存在")
|
||||
logger.warning("[API Server Fallback] extra_server不存在")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
if not extra_server.is_running():
|
||||
logger.warning(f"[API Server Fallback] extra_server未运行")
|
||||
logger.warning("[API Server Fallback] extra_server未运行")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
|
@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
|||
)
|
||||
|
||||
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...")
|
||||
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
|
||||
results = await extra_server.send_message(api_message)
|
||||
logger.debug(f"[API Server Fallback] 发送结果: {results}")
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ logger = get_logger("planner")
|
|||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
|
|
@ -48,7 +49,7 @@ class ActionPlanner:
|
|||
self.last_obs_time_mark = 0.0
|
||||
|
||||
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
||||
|
||||
|
||||
# 黑话缓存:使用 OrderedDict 实现 LRU,最多缓存10个
|
||||
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
|
||||
self.unknown_words_cache_limit = 10
|
||||
|
|
@ -111,20 +112,29 @@ class ActionPlanner:
|
|||
|
||||
# 替换 [picid:xxx] 为 [图片:描述]
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(pic_match: re.Match) -> str:
|
||||
pic_id = pic_match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
|
||||
platform = (
|
||||
getattr(message, "user_info", None)
|
||||
and message.user_info.platform
|
||||
or getattr(message, "chat_info", None)
|
||||
and message.chat_info.platform
|
||||
or "qq"
|
||||
)
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||
# 这里匹配到的应该都是单独的格式
|
||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||
|
||||
def replace_user_ref(user_match: re.Match) -> str:
|
||||
user_name = user_match.group(1)
|
||||
user_id = user_match.group(2)
|
||||
|
|
@ -137,6 +147,7 @@ class ActionPlanner:
|
|||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
return user_name
|
||||
|
||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
|
|
@ -165,7 +176,7 @@ class ActionPlanner:
|
|||
else:
|
||||
reasoning = "未提供原因"
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
|
||||
|
||||
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
|
|
@ -244,7 +255,7 @@ class ActionPlanner:
|
|||
def _update_unknown_words_cache(self, new_words: List[str]) -> None:
|
||||
"""
|
||||
更新黑话缓存,将新的黑话加入缓存
|
||||
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表
|
||||
"""
|
||||
|
|
@ -254,7 +265,7 @@ class ActionPlanner:
|
|||
word = word.strip()
|
||||
if not word:
|
||||
continue
|
||||
|
||||
|
||||
# 如果已存在,移到末尾(LRU)
|
||||
if word in self.unknown_words_cache:
|
||||
self.unknown_words_cache.move_to_end(word)
|
||||
|
|
@ -269,10 +280,10 @@ class ActionPlanner:
|
|||
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
|
||||
"""
|
||||
合并新提取的黑话和缓存中的黑话
|
||||
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表(可能为None)
|
||||
|
||||
|
||||
Returns:
|
||||
合并后的黑话列表(去重)
|
||||
"""
|
||||
|
|
@ -284,31 +295,29 @@ class ActionPlanner:
|
|||
word = word.strip()
|
||||
if word:
|
||||
cleaned_new_words.append(word)
|
||||
|
||||
|
||||
# 获取缓存中的黑话列表
|
||||
cached_words = list(self.unknown_words_cache.keys())
|
||||
|
||||
|
||||
# 合并并去重(保留顺序:新提取的在前,缓存的在后)
|
||||
merged_words: List[str] = []
|
||||
seen = set()
|
||||
|
||||
|
||||
# 先添加新提取的
|
||||
for word in cleaned_new_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
|
||||
# 再添加缓存的(如果不在新提取的列表中)
|
||||
for word in cached_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
|
||||
return merged_words
|
||||
|
||||
def _process_unknown_words_cache(
|
||||
self, actions: List[ActionPlannerInfo]
|
||||
) -> None:
|
||||
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
|
||||
"""
|
||||
处理黑话缓存逻辑:
|
||||
1. 检查是否有 reply action 提取了 unknown_words
|
||||
|
|
@ -316,7 +325,7 @@ class ActionPlanner:
|
|||
3. 如果缓存数量大于5,移除最老的2个
|
||||
4. 对于每个 reply action,合并缓存和新提取的黑话
|
||||
5. 更新缓存
|
||||
|
||||
|
||||
Args:
|
||||
actions: 解析后的动作列表
|
||||
"""
|
||||
|
|
@ -330,7 +339,7 @@ class ActionPlanner:
|
|||
removed_count += 1
|
||||
if removed_count > 0:
|
||||
logger.debug(f"{self.log_prefix}缓存数量大于5,移除最老的{removed_count}个缓存")
|
||||
|
||||
|
||||
# 检查是否有 reply action 提取了 unknown_words
|
||||
has_extracted_unknown_words = False
|
||||
for action in actions:
|
||||
|
|
@ -340,22 +349,22 @@ class ActionPlanner:
|
|||
if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0:
|
||||
has_extracted_unknown_words = True
|
||||
break
|
||||
|
||||
|
||||
# 如果当前 plan 的 reply 没有提取,移除最老的1个
|
||||
if not has_extracted_unknown_words:
|
||||
if len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存")
|
||||
|
||||
|
||||
# 对于每个 reply action,合并缓存和新提取的黑话
|
||||
for action in actions:
|
||||
if action.action_type == "reply":
|
||||
action_data = action.action_data or {}
|
||||
new_words = action_data.get("unknown_words")
|
||||
|
||||
|
||||
# 合并新提取的和缓存的黑话列表
|
||||
merged_words = self._merge_unknown_words_with_cache(new_words)
|
||||
|
||||
|
||||
# 更新 action_data
|
||||
if merged_words:
|
||||
action_data["unknown_words"] = merged_words
|
||||
|
|
@ -366,7 +375,7 @@ class ActionPlanner:
|
|||
else:
|
||||
# 如果没有合并后的黑话,移除 unknown_words 字段
|
||||
action_data.pop("unknown_words", None)
|
||||
|
||||
|
||||
# 更新缓存(将新提取的黑话加入缓存)
|
||||
if new_words:
|
||||
self._update_unknown_words_cache(new_words)
|
||||
|
|
@ -442,15 +451,19 @@ class ActionPlanner:
|
|||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = False
|
||||
for action in actions:
|
||||
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
|
||||
if (
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
):
|
||||
has_reply_to_force_message = True
|
||||
break
|
||||
|
||||
|
||||
# 如果没有回复该消息,强制添加回复 action
|
||||
if not has_reply_to_force_message:
|
||||
# 移除所有 no_reply action(如果有)
|
||||
actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
|
||||
|
||||
# 创建强制回复 action
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
force_reply_action = ActionPlannerInfo(
|
||||
|
|
@ -577,10 +590,11 @@ class ActionPlanner:
|
|||
if global_config.chat.think_mode == "classic":
|
||||
reply_action_example = ""
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
reply_action_example += (
|
||||
"5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]'
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
|
|
@ -590,7 +604,9 @@ class ActionPlanner:
|
|||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
reply_action_example += (
|
||||
"6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "think_level":数值等级(0或1), '
|
||||
'"target_message_id":"消息id(m+数字)", '
|
||||
|
|
@ -741,15 +757,21 @@ class ActionPlanner:
|
|||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return f"LLM 请求失败,模型出现问题: {req_e}", [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
], llm_content, llm_reasoning, llm_duration_ms
|
||||
return (
|
||||
f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
extracted_reasoning = ""
|
||||
|
|
|
|||
|
|
@ -1071,7 +1071,6 @@ class DefaultReplyer:
|
|||
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
|
||||
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
|
||||
|
||||
|
||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||
reply_style = global_config.personality.reply_style
|
||||
multi_styles = global_config.personality.multiple_reply_style
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
|
|||
)
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
|
|
@ -807,7 +808,7 @@ class PrivateReplyer:
|
|||
reply_style = global_config.personality.reply_style
|
||||
|
||||
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
|
||||
|
||||
|
||||
if is_bot_self(platform, user_id):
|
||||
prompt_template = prompt_manager.get_prompt("private_replyer_self")
|
||||
prompt_template.add_context("target", target)
|
||||
|
|
|
|||
|
|
@ -519,7 +519,7 @@ def _build_readable_messages_internal(
|
|||
output_lines: List[str] = []
|
||||
|
||||
prev_timestamp: Optional[float] = None
|
||||
for timestamp, name, content, is_action in detailed_message:
|
||||
for timestamp, name, content, _is_action in detailed_message:
|
||||
# 检查是否需要插入长时间间隔提示
|
||||
if long_time_notice and prev_timestamp is not None:
|
||||
time_diff = timestamp - prev_timestamp
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from src.common.logger import get_logger
|
|||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
|
||||
class TempMethodsExpression:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
|
|||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
||||
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||
self.session_id = session_id
|
||||
|
|
@ -33,4 +34,4 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
|||
platform=self.platform,
|
||||
user_id=self.user_id,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
|
|
|
|||
|
|
@ -221,5 +221,7 @@ if not supports_truecolor():
|
|||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
else:
|
||||
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
|
||||
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold)
|
||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
escape_str = rgb_pair_to_ansi_truecolor(
|
||||
hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
|
||||
)
|
||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from .server import get_global_server
|
|||
|
||||
global_api = None
|
||||
|
||||
|
||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
"""获取全局MessageServer实例"""
|
||||
global global_api
|
||||
|
|
@ -80,12 +81,12 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
|
||||
return False
|
||||
|
||||
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
|
||||
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3. Setup Message Bridge
|
||||
# Initialize refined route map if not exists
|
||||
if not hasattr(global_api, "platform_map"):
|
||||
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
|
||||
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
|
||||
|
||||
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
|
||||
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
|
||||
|
|
@ -108,7 +109,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
|
||||
|
||||
if platform:
|
||||
global_api.platform_map[platform] = api_key # type: ignore
|
||||
global_api.platform_map[platform] = api_key # type: ignore
|
||||
api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to update platform map: {e}")
|
||||
|
|
@ -117,21 +118,21 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
if "raw_message" not in msg_dict:
|
||||
msg_dict["raw_message"] = None
|
||||
|
||||
await global_api.process_message(msg_dict) # type: ignore
|
||||
await global_api.process_message(msg_dict) # type: ignore
|
||||
|
||||
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
|
||||
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3.5. Register custom message handlers (bridge to Legacy handlers)
|
||||
# message_id_echo: handles message ID echo from adapters
|
||||
# 兼容新旧两个版本的 maim_message:
|
||||
# - 旧版: handler(payload)
|
||||
# - 新版: handler(payload, metadata)
|
||||
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
|
||||
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
|
||||
# Bridge to the Legacy custom handler registered in main.py
|
||||
try:
|
||||
# The Legacy handler expects the payload format directly
|
||||
if hasattr(global_api, "_custom_message_handlers"):
|
||||
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
|
||||
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
|
||||
if handler:
|
||||
await handler(payload)
|
||||
api_logger.debug(f"Processed message_id_echo: {payload}")
|
||||
|
|
@ -140,7 +141,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
except Exception as e:
|
||||
api_logger.warning(f"Failed to process message_id_echo: {e}")
|
||||
|
||||
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
|
||||
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 4. Initialize Server
|
||||
extra_server = WebSocketServer(config=server_config)
|
||||
|
|
@ -167,7 +168,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||
global_api.stop = patched_stop
|
||||
|
||||
# Attach for reference
|
||||
global_api.extra_server = extra_server # type: ignore # 这是什么
|
||||
global_api.extra_server = extra_server # type: ignore # 这是什么
|
||||
|
||||
except ImportError:
|
||||
get_logger("maim_message").error(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
|
|||
|
||||
logger = get_logger("file_utils")
|
||||
|
||||
|
||||
class FileUtils:
|
||||
@staticmethod
|
||||
def save_binary_to_file(file_path: Path, data: bytes):
|
||||
|
|
@ -35,7 +36,7 @@ class FileUtils:
|
|||
except Exception as e:
|
||||
logger.error(f"保存文件 {file_path} 失败: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_file_path_by_hash(data_hash: str) -> Path:
|
||||
"""
|
||||
|
|
@ -52,4 +53,4 @@ class FileUtils:
|
|||
if binary_data := session.exec(statement).first():
|
||||
return Path(binary_data.full_path)
|
||||
else:
|
||||
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")
|
||||
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")
|
||||
|
|
|
|||
|
|
@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
|||
|
||||
reason = ",".join(reasons)
|
||||
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
||||
|
||||
|
|
|
|||
|
|
@ -86,8 +86,8 @@ def init_dream_tools(chat_id: str) -> None:
|
|||
finish_maintenance = make_finish_maintenance(chat_id)
|
||||
|
||||
search_jargon = make_search_jargon(chat_id)
|
||||
delete_jargon = make_delete_jargon(chat_id)
|
||||
update_jargon = make_update_jargon(chat_id)
|
||||
_delete_jargon = make_delete_jargon(chat_id)
|
||||
_update_jargon = make_update_jargon(chat_id)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
|
|
|
|||
|
|
@ -54,8 +54,6 @@ async def generate_dream_summary(
|
|||
) -> None:
|
||||
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
|
||||
try:
|
||||
|
||||
|
||||
# 第一步:建立工具调用结果映射 (call_id -> result)
|
||||
tool_results_map: dict[str, str] = {}
|
||||
for msg in conversation_messages:
|
||||
|
|
|
|||
|
|
@ -4,4 +4,3 @@ dream agent 工具实现模块。
|
|||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
|
|||
return f"create_chat_history 执行失败: {e}"
|
||||
|
||||
return create_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||
return f"delete_chat_history 执行失败: {e}"
|
||||
|
||||
return delete_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||
return f"delete_jargon 执行失败: {e}"
|
||||
|
||||
return delete_jargon
|
||||
|
||||
|
|
|
|||
|
|
@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
|
|||
return msg
|
||||
|
||||
return finish_maintenance
|
||||
|
||||
|
|
|
|||
|
|
@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
|||
return f"get_chat_history_detail 执行失败: {e}"
|
||||
|
||||
return get_chat_history_detail
|
||||
|
||||
|
|
|
|||
|
|
@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
|
|||
return f"search_chat_history 执行失败: {e}"
|
||||
|
||||
return search_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||
return f"update_chat_history 执行失败: {e}"
|
||||
|
||||
return update_chat_history
|
||||
|
||||
|
|
|
|||
|
|
@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||
return f"update_jargon 执行失败: {e}"
|
||||
|
||||
return update_jargon
|
||||
|
||||
|
|
|
|||
|
|
@ -458,8 +458,8 @@ def _default_normal_response_parser(
|
|||
if not isinstance(arguments, dict):
|
||||
# 此时为了调试方便,建议打印出 arguments 的类型
|
||||
raise RespParseException(
|
||||
resp,
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
|
||||
resp,
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
|
||||
)
|
||||
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
|
||||
except json.JSONDecodeError as e:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import time
|
|||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
|
@ -34,7 +34,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
|||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(ThinkingQuestion).where(
|
||||
(ThinkingQuestion.found_answer == False)
|
||||
col(ThinkingQuestion.found_answer).is_(False)
|
||||
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
|
||||
)
|
||||
records = session.exec(statement).all()
|
||||
|
|
@ -786,8 +786,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
|
|||
str: 格式化的查询历史字符串
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
start_time = current_time - time_window_seconds
|
||||
_current_time = time.time()
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
|
|
@ -838,15 +837,14 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
|
|||
List[str]: 格式化的答案列表,每个元素格式为 "问题:xxx\n答案:xxx"
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
start_time = current_time - time_window_seconds
|
||||
_current_time = time.time()
|
||||
|
||||
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ThinkingQuestion)
|
||||
.where(col(ThinkingQuestion.context) == chat_id)
|
||||
.where(col(ThinkingQuestion.found_answer) == True)
|
||||
.where(col(ThinkingQuestion.found_answer))
|
||||
.where(col(ThinkingQuestion.answer).is_not(None))
|
||||
.where(col(ThinkingQuestion.answer) != "")
|
||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||
|
|
|
|||
|
|
@ -105,25 +105,27 @@ async def search_chat_history(
|
|||
# 检查参数
|
||||
if not keyword and not participant and not start_time and not end_time:
|
||||
return "未指定查询参数(需要提供keyword、participant、start_time或end_time之一)"
|
||||
|
||||
|
||||
# 解析时间参数
|
||||
start_timestamp = None
|
||||
end_timestamp = None
|
||||
|
||||
|
||||
if start_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
|
||||
start_timestamp = parse_datetime_to_timestamp(start_time)
|
||||
except ValueError as e:
|
||||
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
|
||||
if end_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
|
||||
end_timestamp = parse_datetime_to_timestamp(end_time)
|
||||
except ValueError as e:
|
||||
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
|
||||
# 验证时间范围
|
||||
if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
|
||||
return "开始时间不能晚于结束时间"
|
||||
|
|
@ -158,23 +160,20 @@ async def search_chat_history(
|
|||
f"search_chat_history 当前聊天流在黑名单中,强制使用本地查询,chat_id={chat_id}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
|
||||
# 添加时间过滤条件
|
||||
if start_timestamp is not None and end_timestamp is not None:
|
||||
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
|
||||
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
|
||||
query = query.where(
|
||||
(
|
||||
(ChatHistory.start_time >= start_timestamp)
|
||||
& (ChatHistory.start_time <= end_timestamp)
|
||||
(ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
|
||||
) # 记录开始时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.end_time >= start_timestamp)
|
||||
& (ChatHistory.end_time <= end_timestamp)
|
||||
(ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
|
||||
) # 记录结束时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.start_time <= start_timestamp)
|
||||
& (ChatHistory.end_time >= end_timestamp)
|
||||
(ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
|
||||
) # 记录完全包含查询时间段
|
||||
)
|
||||
logger.debug(
|
||||
|
|
@ -302,7 +301,7 @@ async def search_chat_history(
|
|||
time_desc = f"时间<='{end_str}'"
|
||||
if time_desc:
|
||||
conditions.append(time_desc)
|
||||
|
||||
|
||||
if conditions:
|
||||
conditions_str = "且".join(conditions)
|
||||
return f"未找到满足条件({conditions_str})的聊天记录"
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ async def query_words(chat_id: str, words: str) -> str:
|
|||
if separator in words:
|
||||
words_list = [w.strip() for w in words.split(separator) if w.strip()]
|
||||
break
|
||||
|
||||
|
||||
# 如果没有找到分隔符,整个字符串作为一个词语
|
||||
if not words_list:
|
||||
words_list = [words.strip()]
|
||||
|
|
@ -76,4 +76,3 @@ def register_tool():
|
|||
],
|
||||
execute_func=query_words,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ async def generate_reply(
|
|||
# 如果 reply_time_point 未传入,设置为当前时间戳
|
||||
if reply_time_point is None:
|
||||
reply_time_point = time.time()
|
||||
|
||||
|
||||
# 获取回复器
|
||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,18 @@ def unregister_service(service_name: str, plugin_name: Optional[str] = None) ->
|
|||
return plugin_service_registry.unregister_service(service_name, plugin_name)
|
||||
|
||||
|
||||
async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
|
||||
async def call_service(
|
||||
service_name: str,
|
||||
*args: Any,
|
||||
plugin_name: Optional[str] = None,
|
||||
caller_plugin: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""调用插件服务。"""
|
||||
return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs)
|
||||
return await plugin_service_registry.call_service(
|
||||
service_name,
|
||||
*args,
|
||||
plugin_name=plugin_name,
|
||||
caller_plugin=caller_plugin,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -558,7 +558,9 @@ class PluginBase(ABC):
|
|||
if version_spec:
|
||||
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
|
||||
if not is_ok:
|
||||
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
|
||||
logger.error(
|
||||
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
|
||||
)
|
||||
return False
|
||||
|
||||
if min_version or max_version:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -11,6 +11,8 @@ class PluginServiceInfo:
|
|||
version: str = "1.0.0"
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
public: bool = False
|
||||
allowed_callers: List[str] = field(default_factory=list)
|
||||
params_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
return_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
|
|
|||
|
|
@ -274,6 +274,23 @@ class ComponentRegistry:
|
|||
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
async def remove_components_by_plugin(self, plugin_name: str) -> int:
|
||||
"""移除某插件注册的所有组件。"""
|
||||
targets = [
|
||||
(component_info.name, component_info.component_type)
|
||||
for component_info in self._components.values()
|
||||
if component_info.plugin_name == plugin_name
|
||||
]
|
||||
|
||||
removed_count = 0
|
||||
for component_name, component_type in targets:
|
||||
if await self.remove_component(component_name, component_type, plugin_name):
|
||||
removed_count += 1
|
||||
|
||||
if removed_count:
|
||||
logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}")
|
||||
return removed_count
|
||||
|
||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||
"""移除插件注册信息
|
||||
|
||||
|
|
@ -734,9 +751,7 @@ class ComponentRegistry:
|
|||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
"workflow_steps": workflow_step_count,
|
||||
"enabled_workflow_steps": enabled_workflow_step_count,
|
||||
"workflow_steps_by_stage": {
|
||||
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
|
||||
},
|
||||
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -200,13 +200,43 @@ class PluginManager:
|
|||
"""
|
||||
重载插件模块
|
||||
"""
|
||||
old_instance = self.loaded_plugins.get(plugin_name)
|
||||
if not old_instance:
|
||||
logger.warning(f"插件 {plugin_name} 未加载,无法重载")
|
||||
return False
|
||||
|
||||
if not await self.remove_registered_plugin(plugin_name):
|
||||
return False
|
||||
|
||||
if not self.load_registered_plugin_classes(plugin_name)[0]:
|
||||
logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例")
|
||||
rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance)
|
||||
if rollback_ok:
|
||||
logger.info(f"插件 {plugin_name} 已回滚到旧版本实例")
|
||||
else:
|
||||
logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用")
|
||||
return False
|
||||
|
||||
logger.debug(f"插件 {plugin_name} 重载成功")
|
||||
return True
|
||||
|
||||
async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool:
|
||||
"""重载失败后回滚旧实例。"""
|
||||
try:
|
||||
await component_registry.remove_components_by_plugin(plugin_name)
|
||||
component_registry.remove_plugin_registry(plugin_name)
|
||||
plugin_service_registry.remove_services_by_plugin(plugin_name)
|
||||
|
||||
if not old_instance.register_plugin():
|
||||
logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败")
|
||||
return False
|
||||
|
||||
self.loaded_plugins[plugin_name] = old_instance
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件根目录
|
||||
|
|
@ -399,7 +429,9 @@ class PluginManager:
|
|||
|
||||
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
|
||||
"""根据依赖图计算加载顺序,并检测循环依赖。"""
|
||||
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
|
||||
indegree: Dict[str, int] = {
|
||||
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
|
||||
}
|
||||
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
|
||||
|
||||
for plugin_name, dependencies in dependency_graph.items():
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ class PluginServiceRegistry:
|
|||
if "." in service_info.plugin_name:
|
||||
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]:
|
||||
logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}")
|
||||
return False
|
||||
|
||||
full_name = service_info.full_name
|
||||
if full_name in self._services:
|
||||
|
|
@ -52,7 +55,9 @@ class PluginServiceRegistry:
|
|||
full_name = self._resolve_full_name(service_name, plugin_name)
|
||||
return self._service_handlers.get(full_name) if full_name else None
|
||||
|
||||
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
|
||||
def list_services(
|
||||
self, plugin_name: Optional[str] = None, enabled_only: bool = False
|
||||
) -> Dict[str, PluginServiceInfo]:
|
||||
"""列出插件服务。"""
|
||||
services = self._services.copy()
|
||||
if plugin_name:
|
||||
|
|
@ -103,12 +108,33 @@ class PluginServiceRegistry:
|
|||
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
|
||||
return removed_count
|
||||
|
||||
async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
|
||||
async def call_service(
|
||||
self,
|
||||
service_name: str,
|
||||
*args: Any,
|
||||
plugin_name: Optional[str] = None,
|
||||
caller_plugin: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""调用插件服务(支持同步/异步handler)。"""
|
||||
service_info = self.get_service(service_name, plugin_name)
|
||||
if not service_info:
|
||||
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
|
||||
raise ValueError(f"插件服务未注册: {target_name}")
|
||||
|
||||
if (
|
||||
"." not in service_name
|
||||
and plugin_name is None
|
||||
and caller_plugin
|
||||
and service_info.plugin_name != caller_plugin
|
||||
):
|
||||
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
|
||||
|
||||
if not self._is_call_authorized(service_info, caller_plugin):
|
||||
raise PermissionError(
|
||||
f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}"
|
||||
)
|
||||
|
||||
if not service_info.enabled:
|
||||
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
|
||||
|
||||
|
|
@ -116,8 +142,93 @@ class PluginServiceRegistry:
|
|||
if not handler:
|
||||
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
|
||||
|
||||
self._validate_input_contract(service_info, args, kwargs)
|
||||
|
||||
result = handler(*args, **kwargs)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
resolved_result = await result if inspect.isawaitable(result) else result
|
||||
self._validate_output_contract(service_info, resolved_result)
|
||||
return resolved_result
|
||||
|
||||
def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool:
|
||||
"""检查服务调用权限。"""
|
||||
if caller_plugin is None:
|
||||
return service_info.public
|
||||
if caller_plugin == service_info.plugin_name:
|
||||
return True
|
||||
if service_info.public:
|
||||
return True
|
||||
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
|
||||
return "*" in allowed_callers or caller_plugin in allowed_callers
|
||||
|
||||
def _validate_input_contract(
|
||||
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> None:
|
||||
"""校验服务入参契约。"""
|
||||
schema = service_info.params_schema
|
||||
if not schema:
|
||||
return
|
||||
|
||||
properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
|
||||
is_invocation_schema = "args" in properties or "kwargs" in properties
|
||||
|
||||
if is_invocation_schema:
|
||||
payload = {"args": list(args), "kwargs": kwargs}
|
||||
self._validate_by_schema(payload, schema, path="params")
|
||||
return
|
||||
|
||||
if args:
|
||||
raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数")
|
||||
self._validate_by_schema(kwargs, schema, path="params")
|
||||
|
||||
def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None:
|
||||
"""校验服务返回值契约。"""
|
||||
if not service_info.return_schema:
|
||||
return
|
||||
self._validate_by_schema(value, service_info.return_schema, path="return")
|
||||
|
||||
def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None:
|
||||
"""基于简化JSON-Schema校验数据。"""
|
||||
expected_type = schema.get("type")
|
||||
if expected_type:
|
||||
self._validate_type(value, expected_type, path)
|
||||
|
||||
enum_values = schema.get("enum")
|
||||
if enum_values is not None and value not in enum_values:
|
||||
raise ValueError(f"{path} 不在枚举范围内: {value}")
|
||||
|
||||
if expected_type == "object":
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
for field in required:
|
||||
if field not in value:
|
||||
raise ValueError(f"{path}.{field} 为必填字段")
|
||||
|
||||
for field, field_value in value.items():
|
||||
if field in properties:
|
||||
self._validate_by_schema(field_value, properties[field], f"{path}.{field}")
|
||||
elif schema.get("additionalProperties", True) is False:
|
||||
raise ValueError(f"{path}.{field} 不允许额外字段")
|
||||
|
||||
if expected_type == "array":
|
||||
if item_schema := schema.get("items"):
|
||||
for index, item in enumerate(value):
|
||||
self._validate_by_schema(item, item_schema, f"{path}[{index}]")
|
||||
|
||||
def _validate_type(self, value: Any, expected_type: str, path: str) -> None:
|
||||
"""校验基础类型。"""
|
||||
type_checkers: Dict[str, Callable[[Any], bool]] = {
|
||||
"string": lambda item: isinstance(item, str),
|
||||
"number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool),
|
||||
"integer": lambda item: isinstance(item, int) and not isinstance(item, bool),
|
||||
"boolean": lambda item: isinstance(item, bool),
|
||||
"object": lambda item: isinstance(item, dict),
|
||||
"array": lambda item: isinstance(item, list),
|
||||
"null": lambda item: item is None,
|
||||
}
|
||||
checker = type_checkers.get(expected_type)
|
||||
if checker and not checker(value):
|
||||
raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}")
|
||||
|
||||
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
|
||||
"""解析服务全名。"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
import uuid
|
||||
import inspect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, MaiMessages
|
||||
|
|
@ -95,7 +96,9 @@ class WorkflowEngine:
|
|||
except Exception as e:
|
||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
||||
workflow_context.errors.append(f"{stage_key}: {e}")
|
||||
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
|
||||
)
|
||||
self._execution_history[workflow_context.trace_id]["status"] = "failed"
|
||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
||||
return (
|
||||
|
|
@ -144,11 +147,19 @@ class WorkflowEngine:
|
|||
|
||||
step_timing_key = f"{stage.value}:{step_info.full_name}"
|
||||
step_start = time.perf_counter()
|
||||
timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None
|
||||
|
||||
try:
|
||||
result = handler(context, message)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
coroutine = handler(context, message)
|
||||
result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine
|
||||
else:
|
||||
if timeout_seconds:
|
||||
result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds)
|
||||
else:
|
||||
result = handler(context, message)
|
||||
if inspect.isawaitable(result):
|
||||
result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result
|
||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||
|
||||
normalized_result = self._normalize_step_result(result)
|
||||
|
|
@ -165,10 +176,30 @@ class WorkflowEngine:
|
|||
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
|
||||
return normalized_result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||
timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms"
|
||||
context.errors.append(f"{step_info.full_name}: {timeout_message}")
|
||||
logger.error(
|
||||
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}"
|
||||
)
|
||||
return WorkflowStepResult(
|
||||
status="failed",
|
||||
return_message=timeout_message,
|
||||
diagnostics={
|
||||
"stage": stage.value,
|
||||
"step": step_info.full_name,
|
||||
"trace_id": context.trace_id,
|
||||
"error_code": WorkflowErrorCode.STEP_TIMEOUT.value,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||
context.errors.append(f"{step_info.full_name}: {e}")
|
||||
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
|
||||
)
|
||||
return WorkflowStepResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class PromptManager:
|
|||
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
||||
"""
|
||||
添加一个上下文构造函数
|
||||
|
||||
|
||||
Args:
|
||||
name (str): 上下文名称
|
||||
func (Callable[[str], str | Coroutine[Any, Any, str]]): 构造函数,接受 Prompt 名称作为参数,返回字符串或返回字符串的协程
|
||||
|
|
@ -144,7 +144,7 @@ class PromptManager:
|
|||
def get_prompt(self, prompt_name: str) -> Prompt:
|
||||
"""
|
||||
获取指定名称的 Prompt 实例的克隆
|
||||
|
||||
|
||||
Args:
|
||||
prompt_name (str): 要获取的 Prompt 名称
|
||||
Returns:
|
||||
|
|
@ -161,7 +161,7 @@ class PromptManager:
|
|||
async def render_prompt(self, prompt: Prompt) -> str:
|
||||
"""
|
||||
渲染一个 Prompt 实例
|
||||
|
||||
|
||||
Args:
|
||||
prompt (Prompt): 要渲染的 Prompt 实例
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
|
|
@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
|
|||
|
||||
class ChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
|
||||
chat_id: str
|
||||
plan_count: int
|
||||
latest_timestamp: float
|
||||
|
|
@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
|
|||
|
||||
class PlanLogSummary(BaseModel):
|
||||
"""规划日志摘要"""
|
||||
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
|
|
@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
|
|||
|
||||
class PlanLogDetail(BaseModel):
|
||||
"""规划日志详情"""
|
||||
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
|
|
@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
|
|||
|
||||
class PlannerOverview(BaseModel):
|
||||
"""规划器总览 - 轻量级统计"""
|
||||
|
||||
total_chats: int
|
||||
total_plans: int
|
||||
chats: List[ChatSummary]
|
||||
|
|
@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
|
|||
|
||||
class PaginatedChatLogs(BaseModel):
|
||||
"""分页的聊天日志列表"""
|
||||
|
||||
data: List[PlanLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
|
|||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
timestamp_str = filename.split("_")[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
|
|
@ -86,41 +92,39 @@ async def get_planner_overview():
|
|||
"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
|
||||
|
||||
|
||||
chats = []
|
||||
total_plans = 0
|
||||
|
||||
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
plan_count = len(json_files)
|
||||
total_plans += plan_count
|
||||
|
||||
|
||||
if plan_count == 0:
|
||||
continue
|
||||
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
|
||||
chats.append(
|
||||
ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name,
|
||||
)
|
||||
)
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return PlannerOverview(
|
||||
total_chats=len(chats),
|
||||
total_plans=total_plans,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
||||
|
|
@ -128,7 +132,7 @@ async def get_chat_plan_logs(
|
|||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||
):
|
||||
"""
|
||||
获取指定聊天的规划日志列表(分页)
|
||||
|
|
@ -137,73 +141,69 @@ async def get_chat_plan_logs(
|
|||
"""
|
||||
chat_dir = PLAN_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedChatLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
prompt = data.get("prompt", "")
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
page_files = json_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
actions = data.get('actions', [])
|
||||
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
|
||||
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else ''
|
||||
))
|
||||
reasoning = data.get("reasoning", "")
|
||||
actions = data.get("actions", [])
|
||||
action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
|
||||
logs.append(
|
||||
PlanLogSummary(
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
|
||||
llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else "",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedChatLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
logs.append(
|
||||
PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview="[读取失败]",
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
||||
|
|
@ -212,22 +212,23 @@ async def get_log_detail(chat_id: str, filename: str):
|
|||
log_file = PLAN_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return PlanLogDetail(**data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ========== 兼容旧接口 ==========
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_planner_stats():
|
||||
"""获取规划器统计信息 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
|
||||
|
||||
# 获取最近10条计划的摘要
|
||||
recent_plans = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
|
|
@ -236,17 +237,17 @@ async def get_planner_stats():
|
|||
recent_plans.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_plans = recent_plans[:10]
|
||||
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_plans": overview.total_plans,
|
||||
"avg_plan_time_ms": 0,
|
||||
"avg_llm_time_ms": 0,
|
||||
"recent_plans": recent_plans
|
||||
"recent_plans": recent_plans,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -258,44 +259,43 @@ async def get_chat_list():
|
|||
|
||||
|
||||
@router.get("/all-logs")
|
||||
async def get_all_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100)
|
||||
):
|
||||
async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
|
||||
"""获取所有规划日志 - 兼容旧接口"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
||||
|
||||
# 收集所有文件
|
||||
all_files = []
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if chat_dir.is_dir():
|
||||
for log_file in chat_dir.glob("*.json"):
|
||||
all_files.append((chat_dir.name, log_file))
|
||||
|
||||
|
||||
# 按时间戳排序
|
||||
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
||||
|
||||
|
||||
total = len(all_files)
|
||||
offset = (page - 1) * page_size
|
||||
page_files = all_files[offset:offset + page_size]
|
||||
|
||||
page_files = all_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for chat_id, log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
logs.append({
|
||||
"chat_id": data.get('chat_id', chat_id),
|
||||
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get('actions', [])),
|
||||
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
|
||||
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else ''
|
||||
})
|
||||
reasoning = data.get("reasoning", "")
|
||||
logs.append(
|
||||
{
|
||||
"chat_id": data.get("chat_id", chat_id),
|
||||
"timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get("actions", [])),
|
||||
"total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
|
||||
"llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else "",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {"data": logs, "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
return {"data": logs, "total": total, "page": page, "page_size": page_size}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
|
|
@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
|
|||
|
||||
class ReplierChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
|
||||
chat_id: str
|
||||
reply_count: int
|
||||
latest_timestamp: float
|
||||
|
|
@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
|
|||
|
||||
class ReplyLogSummary(BaseModel):
|
||||
"""回复日志摘要"""
|
||||
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
|
|
@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
|
|||
|
||||
class ReplyLogDetail(BaseModel):
|
||||
"""回复日志详情"""
|
||||
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
|
|
@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
|
|||
|
||||
class ReplierOverview(BaseModel):
|
||||
"""回复器总览 - 轻量级统计"""
|
||||
|
||||
total_chats: int
|
||||
total_replies: int
|
||||
chats: List[ReplierChatSummary]
|
||||
|
|
@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
|
|||
|
||||
class PaginatedReplyLogs(BaseModel):
|
||||
"""分页的回复日志列表"""
|
||||
|
||||
data: List[ReplyLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
|
|||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
timestamp_str = filename.split("_")[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
|
|
@ -89,41 +95,39 @@ async def get_replier_overview():
|
|||
"""
|
||||
if not REPLY_LOG_DIR.exists():
|
||||
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
|
||||
|
||||
|
||||
chats = []
|
||||
total_replies = 0
|
||||
|
||||
|
||||
for chat_dir in REPLY_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
reply_count = len(json_files)
|
||||
total_replies += reply_count
|
||||
|
||||
|
||||
if reply_count == 0:
|
||||
continue
|
||||
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
|
||||
chats.append(
|
||||
ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name,
|
||||
)
|
||||
)
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return ReplierOverview(
|
||||
total_chats=len(chats),
|
||||
total_replies=total_replies,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
||||
|
|
@ -131,7 +135,7 @@ async def get_chat_reply_logs(
|
|||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||
):
|
||||
"""
|
||||
获取指定聊天的回复日志列表(分页)
|
||||
|
|
@ -140,71 +144,67 @@ async def get_chat_reply_logs(
|
|||
"""
|
||||
chat_dir = REPLY_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedReplyLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
prompt = data.get("prompt", "")
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
page_files = json_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
output = data.get('output', '')
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get('model', ''),
|
||||
success=data.get('success', True),
|
||||
llm_ms=data.get('timing', {}).get('llm_ms', 0),
|
||||
overall_ms=data.get('timing', {}).get('overall_ms', 0),
|
||||
output_preview=output[:100] if output else ''
|
||||
))
|
||||
output = data.get("output", "")
|
||||
logs.append(
|
||||
ReplyLogSummary(
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get("model", ""),
|
||||
success=data.get("success", True),
|
||||
llm_ms=data.get("timing", {}).get("llm_ms", 0),
|
||||
overall_ms=data.get("timing", {}).get("overall_ms", 0),
|
||||
output_preview=output[:100] if output else "",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model='',
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedReplyLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
logs.append(
|
||||
ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model="",
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview="[读取失败]",
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
||||
|
|
@ -213,35 +213,36 @@ async def get_reply_log_detail(chat_id: str, filename: str):
|
|||
log_file = REPLY_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return ReplyLogDetail(
|
||||
type=data.get('type', 'reply'),
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', 0),
|
||||
prompt=data.get('prompt', ''),
|
||||
output=data.get('output', ''),
|
||||
processed_output=data.get('processed_output', []),
|
||||
model=data.get('model', ''),
|
||||
reasoning=data.get('reasoning', ''),
|
||||
think_level=data.get('think_level', 0),
|
||||
timing=data.get('timing', {}),
|
||||
error=data.get('error'),
|
||||
success=data.get('success', True)
|
||||
type=data.get("type", "reply"),
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", 0),
|
||||
prompt=data.get("prompt", ""),
|
||||
output=data.get("output", ""),
|
||||
processed_output=data.get("processed_output", []),
|
||||
model=data.get("model", ""),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
think_level=data.get("think_level", 0),
|
||||
timing=data.get("timing", {}),
|
||||
error=data.get("error"),
|
||||
success=data.get("success", True),
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ========== 兼容接口 ==========
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_replier_stats():
|
||||
"""获取回复器统计信息"""
|
||||
overview = await get_replier_overview()
|
||||
|
||||
|
||||
# 获取最近10条回复的摘要
|
||||
recent_replies = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
|
|
@ -250,15 +251,15 @@ async def get_replier_stats():
|
|||
recent_replies.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_replies = recent_replies[:10]
|
||||
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_replies": overview.total_replies,
|
||||
"recent_replies": recent_replies
|
||||
"recent_replies": recent_replies,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -266,4 +267,4 @@ async def get_replier_stats():
|
|||
async def get_replier_chat_list():
|
||||
"""获取所有聊天ID列表"""
|
||||
overview = await get_replier_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
from fastapi import Depends, Cookie, Header, Request, HTTPException
|
||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
|
||||
from fastapi import Depends, Cookie, Header, Request
|
||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit
|
||||
|
||||
|
||||
async def require_auth(
|
||||
|
|
|
|||
|
|
@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
|
|||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
|
||||
|
||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
|
|
@ -151,7 +152,7 @@ def _parse_allowed_ips(ip_string: str) -> list:
|
|||
ip_entry = ip_entry.strip() # 去除空格
|
||||
if not ip_entry:
|
||||
continue
|
||||
|
||||
|
||||
# 跳过注释行(以#开头)
|
||||
if ip_entry.startswith("#"):
|
||||
continue
|
||||
|
|
@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
|||
def _get_anti_crawler_config():
|
||||
"""获取防爬虫配置"""
|
||||
from src.config.config import global_config
|
||||
|
||||
return {
|
||||
'mode': global_config.webui.anti_crawler_mode,
|
||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
'trust_xff': global_config.webui.trust_xff
|
||||
"mode": global_config.webui.anti_crawler_mode,
|
||||
"allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
"trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
"trust_xff": global_config.webui.trust_xff,
|
||||
}
|
||||
|
||||
|
||||
# 初始化配置(将在模块加载时执行)
|
||||
_config = _get_anti_crawler_config()
|
||||
ANTI_CRAWLER_MODE = _config['mode']
|
||||
ALLOWED_IPS = _config['allowed_ips']
|
||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
||||
TRUST_XFF = _config['trust_xff']
|
||||
ANTI_CRAWLER_MODE = _config["mode"]
|
||||
ALLOWED_IPS = _config["allowed_ips"]
|
||||
TRUSTED_PROXIES = _config["trusted_proxies"]
|
||||
TRUST_XFF = _config["trust_xff"]
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
|||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.is_at) == True,
|
||||
col(Messages.is_at),
|
||||
)
|
||||
data.at_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
|
|
@ -342,7 +342,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
|||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.is_mentioned) == True,
|
||||
col(Messages.is_mentioned),
|
||||
)
|
||||
data.mentioned_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
|
|
@ -552,7 +552,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
# 1. 表情包之王 - 使用次数最多的表情包
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images).where(col(Images.is_registered) == True).order_by(desc(col(Images.query_count))).limit(5)
|
||||
select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
|
||||
)
|
||||
top_emojis = session.exec(statement).all()
|
||||
if top_emojis:
|
||||
|
|
@ -636,7 +636,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.is_picture) == True,
|
||||
col(Messages.is_picture),
|
||||
)
|
||||
data.image_processed_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
|
|
@ -781,12 +781,12 @@ async def get_achievements(year: int = 2025) -> AchievementData:
|
|||
# 1. 新学到的黑话数量
|
||||
# Jargon 表没有时间字段,统计全部已确认的黑话
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(col(Jargon.is_jargon) == True)
|
||||
statement = select(func.count()).where(col(Jargon.is_jargon))
|
||||
data.new_jargon_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 2. 代表性黑话示例
|
||||
with get_db_session() as session:
|
||||
statement = select(Jargon).where(col(Jargon.is_jargon) == True).order_by(desc(col(Jargon.count))).limit(5)
|
||||
statement = select(Jargon).where(col(Jargon.is_jargon)).order_by(desc(col(Jargon.count))).limit(5)
|
||||
jargon_samples = session.exec(statement).all()
|
||||
data.sample_jargons = [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -532,7 +532,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
|
|||
.select_from(Images)
|
||||
.where(
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
col(Images.is_registered) == True,
|
||||
col(Images.is_registered),
|
||||
)
|
||||
)
|
||||
banned_statement = (
|
||||
|
|
@ -540,7 +540,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
|
|||
.select_from(Images)
|
||||
.where(
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
col(Images.is_banned) == True,
|
||||
col(Images.is_banned),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1283,7 +1283,7 @@ async def preheat_thumbnail_cache(
|
|||
select(Images)
|
||||
.where(
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
col(Images.is_banned) == False,
|
||||
col(Images.is_banned).is_(False),
|
||||
)
|
||||
.order_by(col(Images.query_count).desc())
|
||||
.limit(limit * 2)
|
||||
|
|
|
|||
|
|
@ -315,15 +315,15 @@ async def get_jargon_stats():
|
|||
total = session.exec(select(fn.count()).select_from(Jargon)).one()
|
||||
|
||||
confirmed_jargon = session.exec(
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True)
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))
|
||||
).one()
|
||||
confirmed_not_jargon = session.exec(
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False)
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
|
||||
).one()
|
||||
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
|
||||
|
||||
complete_count = session.exec(
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True)
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))
|
||||
).one()
|
||||
|
||||
chat_count = session.exec(
|
||||
|
|
|
|||
|
|
@ -17,36 +17,36 @@ _paragraph_store_cache = None
|
|||
|
||||
def _get_paragraph_store():
|
||||
"""延迟加载段落 embedding store(只读模式,轻量级)
|
||||
|
||||
|
||||
Returns:
|
||||
EmbeddingStore | None: 如果配置启用则返回store,否则返回None
|
||||
"""
|
||||
# 检查配置是否启用
|
||||
if not global_config.webui.enable_paragraph_content:
|
||||
return None
|
||||
|
||||
|
||||
global _paragraph_store_cache
|
||||
if _paragraph_store_cache is not None:
|
||||
return _paragraph_store_cache
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.embedding_store import EmbeddingStore
|
||||
import os
|
||||
|
||||
|
||||
# 获取数据路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
|
||||
embedding_dir = os.path.join(root_path, "data/embedding")
|
||||
|
||||
|
||||
# 只加载段落 embedding store(轻量级)
|
||||
paragraph_store = EmbeddingStore(
|
||||
namespace="paragraph",
|
||||
dir_path=embedding_dir,
|
||||
max_workers=1, # 只读不需要多线程
|
||||
chunk_size=100
|
||||
chunk_size=100,
|
||||
)
|
||||
paragraph_store.load_from_file()
|
||||
|
||||
|
||||
_paragraph_store_cache = paragraph_store
|
||||
logger.info(f"成功加载段落 embedding store,包含 {len(paragraph_store.store)} 个段落")
|
||||
return paragraph_store
|
||||
|
|
@ -57,10 +57,10 @@ def _get_paragraph_store():
|
|||
|
||||
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
||||
"""从 embedding store 获取段落完整内容
|
||||
|
||||
|
||||
Args:
|
||||
node_id: 段落节点ID,格式为 'paragraph-{hash}'
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str | None, bool]: (段落完整内容或None, 是否启用了功能)
|
||||
"""
|
||||
|
|
@ -69,12 +69,12 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
|||
if paragraph_store is None:
|
||||
# 功能未启用
|
||||
return None, False
|
||||
|
||||
|
||||
# 从 store 中获取完整内容
|
||||
paragraph_item = paragraph_store.store.get(node_id)
|
||||
if paragraph_item is not None:
|
||||
# paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本
|
||||
content: str = getattr(paragraph_item, 'str', '')
|
||||
content: str = getattr(paragraph_item, "str", "")
|
||||
if content:
|
||||
return content, True
|
||||
return None, True
|
||||
|
|
@ -156,14 +156,18 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
|||
node_data = graph[node_id]
|
||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
|
||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
|
|
@ -245,14 +249,18 @@ async def get_knowledge_graph(
|
|||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
|
||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type_val == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||
|
|
@ -368,11 +376,15 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
|
|||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
|
||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
|
|
|||
|
|
@ -370,7 +370,7 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
|
|||
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(PersonInfo.id)).all())
|
||||
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known) == True)).all())
|
||||
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known))).all())
|
||||
unknown = total - known
|
||||
|
||||
# 按平台统计
|
||||
|
|
|
|||
|
|
@ -1762,7 +1762,7 @@ async def update_plugin_config_raw(
|
|||
try:
|
||||
tomlkit.loads(request.config)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 备份旧配置
|
||||
import shutil
|
||||
|
|
|
|||
|
|
@ -659,4 +659,4 @@ def get_git_mirror_service() -> GitMirrorService:
|
|||
global _git_mirror_service
|
||||
if _git_mirror_service is None:
|
||||
_git_mirror_service = GitMirrorService()
|
||||
return _git_mirror_service
|
||||
return _git_mirror_service
|
||||
|
|
|
|||
Loading…
Reference in New Issue