Merge branch 'r-dev' of github.com:Mai-with-u/MaiBot into r-dev

pull/1496/head
UnCLAS-Prommer 2026-02-21 23:50:53 +08:00
commit 04a5bf3c6d
No known key found for this signature in database
91 changed files with 2110 additions and 1967 deletions

1
bot.py
View File

@ -50,6 +50,7 @@ print("警告Dev进入不稳定开发状态任何插件与WebUI均可能
print("\n\n\n\n\n")
print("-----------------------------------------")
def run_runner_process():
"""
Runner 进程逻辑作为守护进程运行负责启动和监控 Worker 进程

View File

@ -1,2 +1 @@
"""Core helpers for MCP Bridge Plugin."""

View File

@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
if not mcp_servers:
return ""
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -22,21 +22,24 @@ from typing import Any, Dict, List, Optional, Tuple
try:
from src.common.logger import get_logger
logger = get_logger("mcp_tool_chain")
except ImportError:
import logging
logger = logging.getLogger("mcp_tool_chain")
@dataclass
class ToolChainStep:
"""工具链步骤"""
tool_name: str # 要调用的工具名(如 mcp_server_tool
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
output_key: str = "" # 输出存储的键名,供后续步骤引用
description: str = "" # 步骤描述
optional: bool = False # 是否可选(失败时继续执行)
def to_dict(self) -> Dict[str, Any]:
return {
"tool_name": self.tool_name,
@ -45,7 +48,7 @@ class ToolChainStep:
"description": self.description,
"optional": self.optional,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep":
return cls(
@ -60,12 +63,13 @@ class ToolChainStep:
@dataclass
class ToolChainDefinition:
"""工具链定义"""
name: str # 工具链名称(将作为组合工具的名称)
description: str # 工具链描述(供 LLM 理解)
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述}
enabled: bool = True # 是否启用
def to_dict(self) -> Dict[str, Any]:
return {
"name": self.name,
@ -74,7 +78,7 @@ class ToolChainDefinition:
"input_params": self.input_params,
"enabled": self.enabled,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition":
steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])]
@ -90,12 +94,13 @@ class ToolChainDefinition:
@dataclass
class ChainExecutionResult:
"""工具链执行结果"""
success: bool
final_output: str # 最终输出(最后一个步骤的结果)
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
error: str = ""
total_duration_ms: float = 0.0
def to_summary(self) -> str:
"""生成执行摘要"""
lines = []
@ -103,7 +108,7 @@ class ChainExecutionResult:
status = "" if step.get("success") else ""
tool = step.get("tool_name", "unknown")
duration = step.get("duration_ms", 0)
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)")
lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
if not step.get("success") and step.get("error"):
lines.append(f" 错误: {step['error'][:50]}")
return "\n".join(lines)
@ -111,49 +116,49 @@ class ChainExecutionResult:
class ToolChainExecutor:
"""工具链执行器"""
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
def __init__(self, mcp_manager):
self._mcp_manager = mcp_manager
def _resolve_tool_key(self, tool_name: str) -> Optional[str]:
"""解析工具名,返回有效的 tool_key
支持:
- 直接使用 tool_key mcp_server_tool
- 使用注册后的工具名会自动转换 - . _
"""
all_tools = self._mcp_manager.all_tools
# 直接匹配
if tool_name in all_tools:
return tool_name
# 尝试转换后匹配(用户可能使用了注册后的名称)
normalized = tool_name.replace("-", "_").replace(".", "_")
if normalized in all_tools:
return normalized
# 尝试查找包含该名称的工具
for key in all_tools.keys():
if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"):
return key
return None
async def execute(
self,
chain: ToolChainDefinition,
input_args: Dict[str, Any],
) -> ChainExecutionResult:
"""执行工具链
Args:
chain: 工具链定义
input_args: 用户输入的参数
Returns:
ChainExecutionResult: 执行结果
"""
@ -164,15 +169,15 @@ class ToolChainExecutor:
"step": {}, # 各步骤输出,按 output_key 存储
"prev": "", # 上一步的输出
}
final_output = ""
# 验证必需的输入参数
missing_params = []
for param_name in chain.input_params.keys():
if param_name not in context["input"]:
missing_params.append(param_name)
if missing_params:
return ChainExecutionResult(
success=False,
@ -180,7 +185,7 @@ class ToolChainExecutor:
error=f"缺少必需参数: {', '.join(missing_params)}",
total_duration_ms=(time.time() - start_time) * 1000,
)
for i, step in enumerate(chain.steps):
step_start = time.time()
step_result = {
@ -191,96 +196,96 @@ class ToolChainExecutor:
"error": "",
"duration_ms": 0,
}
try:
# 替换参数中的变量
resolved_args = self._resolve_args(step.args_template, context)
step_result["resolved_args"] = resolved_args
# 解析工具名
tool_key = self._resolve_tool_key(step.tool_name)
if not tool_key:
step_result["error"] = f"工具 {step.tool_name} 不存在"
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在")
logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
if not step.optional:
step_results.append(step_result)
return ChainExecutionResult(
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在",
error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
total_duration_ms=(time.time() - start_time) * 1000,
)
step_results.append(step_result)
continue
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}")
logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
# 调用工具
result = await self._mcp_manager.call_tool(tool_key, resolved_args)
step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration
if result.success:
step_result["success"] = True
# 确保 content 不为 None
content = result.content if result.content is not None else ""
step_result["output"] = content
# 更新上下文
context["prev"] = content
if step.output_key:
context["step"][step.output_key] = content
final_output = content
content_preview = content[:100] if content else "(空)"
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...")
logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
else:
step_result["error"] = result.error or "未知错误"
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}")
logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
if not step.optional:
step_results.append(step_result)
return ChainExecutionResult(
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}",
error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
total_duration_ms=(time.time() - start_time) * 1000,
)
except Exception as e:
step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration
step_result["error"] = str(e)
logger.error(f"工具链步骤 {i+1} 异常: {e}")
logger.error(f"工具链步骤 {i + 1} 异常: {e}")
if not step.optional:
step_results.append(step_result)
return ChainExecutionResult(
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}",
error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
total_duration_ms=(time.time() - start_time) * 1000,
)
step_results.append(step_result)
total_duration = (time.time() - start_time) * 1000
return ChainExecutionResult(
success=True,
final_output=final_output,
step_results=step_results,
total_duration_ms=total_duration,
)
def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""解析参数模板,替换变量
支持的变量格式:
- ${input.param_name}: 用户输入的参数
- ${step.output_key}: 某个步骤的输出
@ -288,50 +293,48 @@ class ToolChainExecutor:
- ${prev.field}: 上一步输出JSON的某个字段
"""
resolved = {}
for key, value in args_template.items():
if isinstance(value, str):
resolved[key] = self._substitute_vars(value, context)
elif isinstance(value, dict):
resolved[key] = self._resolve_args(value, context)
elif isinstance(value, list):
resolved[key] = [
self._substitute_vars(v, context) if isinstance(v, str) else v
for v in value
]
resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
else:
resolved[key] = value
return resolved
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
"""替换字符串中的变量"""
def replacer(match):
var_path = match.group(1)
return self._get_var_value(var_path, context)
return self.VAR_PATTERN.sub(replacer, template)
def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str:
"""获取变量值
Args:
var_path: 变量路径 "input.query", "step.search_result", "prev", "prev.id"
context: 上下文
"""
parts = self._parse_var_path(var_path)
if not parts:
return ""
# 获取根对象
root = parts[0]
if root not in context:
logger.warning(f"变量 {var_path} 的根 '{root}' 不存在")
return ""
value = context[root]
# 遍历路径
for part in parts[1:]:
if isinstance(value, str):
@ -349,7 +352,7 @@ class ToolChainExecutor:
value = ""
else:
value = ""
# 确保返回字符串
if isinstance(value, (dict, list)):
return json.dumps(value, ensure_ascii=False)
@ -448,39 +451,39 @@ class ToolChainExecutor:
class ToolChainManager:
"""工具链管理器"""
_instance: Optional["ToolChainManager"] = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self._chains: Dict[str, ToolChainDefinition] = {}
self._executor: Optional[ToolChainExecutor] = None
def set_executor(self, mcp_manager) -> None:
    """Build a ``ToolChainExecutor`` around ``mcp_manager`` and store it.

    Any previously stored executor is replaced.
    """
    executor = ToolChainExecutor(mcp_manager)
    self._executor = executor
def add_chain(self, chain: ToolChainDefinition) -> bool:
    """Register ``chain`` under its name.

    A chain with an empty name is rejected and logged as an error. If the
    name is already registered, a warning is logged and the existing entry
    is overwritten.

    Returns:
        ``True`` when the chain was stored, ``False`` when it was rejected.
    """
    name = chain.name
    if not name:
        logger.error("工具链名称不能为空")
        return False

    already_registered = name in self._chains
    if already_registered:
        logger.warning(f"工具链 {name} 已存在,将被覆盖")

    self._chains[name] = chain
    logger.info(f"已添加工具链: {name} ({len(chain.steps)} 个步骤)")
    return True
def remove_chain(self, name: str) -> bool:
"""移除工具链"""
if name in self._chains:
@ -488,19 +491,19 @@ class ToolChainManager:
logger.info(f"已移除工具链: {name}")
return True
return False
def get_chain(self, name: str) -> Optional[ToolChainDefinition]:
"""获取工具链"""
return self._chains.get(name)
def get_all_chains(self) -> Dict[str, ToolChainDefinition]:
"""获取所有工具链"""
return self._chains.copy()
def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]:
"""获取所有启用的工具链"""
return {name: chain for name, chain in self._chains.items() if chain.enabled}
async def execute_chain(
self,
chain_name: str,
@ -514,64 +517,64 @@ class ToolChainManager:
final_output="",
error=f"工具链 {chain_name} 不存在",
)
if not chain.enabled:
return ChainExecutionResult(
success=False,
final_output="",
error=f"工具链 {chain_name} 已禁用",
)
if not self._executor:
return ChainExecutionResult(
success=False,
final_output="",
error="工具链执行器未初始化",
)
return await self._executor.execute(chain, input_args)
def load_from_json(self, json_str: str) -> Tuple[int, List[str]]:
"""从 JSON 字符串加载工具链配置
Returns:
(成功加载数量, 错误列表)
"""
errors = []
loaded = 0
try:
data = json.loads(json_str) if json_str.strip() else []
except json.JSONDecodeError as e:
return 0, [f"JSON 解析失败: {e}"]
if not isinstance(data, list):
data = [data]
for i, item in enumerate(data):
try:
chain = ToolChainDefinition.from_dict(item)
if not chain.name:
errors.append(f"{i+1} 个工具链缺少名称")
errors.append(f"{i + 1} 个工具链缺少名称")
continue
if not chain.steps:
errors.append(f"工具链 {chain.name} 没有步骤")
continue
self.add_chain(chain)
loaded += 1
except Exception as e:
errors.append(f"{i+1} 个工具链解析失败: {e}")
errors.append(f"{i + 1} 个工具链解析失败: {e}")
return loaded, errors
def export_to_json(self, pretty: bool = True) -> str:
"""导出所有工具链为 JSON"""
chains_data = [chain.to_dict() for chain in self._chains.values()]
if pretty:
return json.dumps(chains_data, ensure_ascii=False, indent=2)
return json.dumps(chains_data, ensure_ascii=False)
def clear(self) -> None:
"""清空所有工具链"""
self._chains.clear()

View File

@ -238,7 +238,7 @@ class TestCommand(BaseCommand):
chat_stream=self.message.chat_stream,
reply_reason=reply_reason,
enable_chinese_typo=False,
extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"",
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
)
if result_status:
# 发送生成的回复

View File

@ -46,6 +46,7 @@ def patch_attrdoc_post_init():
config_base_module.logger = logging.getLogger("config_base_test_logger")
class SimpleClass(ConfigBase):
a: int = 1
b: str = "test"
@ -282,7 +283,7 @@ class TestConfigBase:
True,
"ConfigBase is not Hashable",
id="listset-validation-set-configbase-element_reject",
)
),
],
)
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
@ -340,7 +341,7 @@ class TestConfigBase:
False,
None,
id="dict-validation-happy-configbase-value",
)
),
],
)
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
@ -353,13 +354,11 @@ class TestConfigBase:
field_name = "mapping"
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_dict_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
dummy._validate_dict_type(annotation, field_name)
@ -392,7 +391,7 @@ class TestConfigBase:
# Assert
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
def test_discourage_any_usage_suppressed_warning(self, caplog):
class Sample(ConfigBase):
_validate_any: bool = False

View File

@ -4,7 +4,6 @@ import importlib
import pytest
from pathlib import Path
import importlib.util
import asyncio
class DummyLogger:
@ -71,6 +70,7 @@ class DummyLLMRequest:
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
return ("dummy description", {})
class DummySelect:
def __init__(self, *a, **k):
pass
@ -81,6 +81,7 @@ class DummySelect:
def limit(self, n):
return self
@pytest.fixture(autouse=True)
def patch_external_dependencies(monkeypatch):
# Provide dummy implementations as modules so that importing image_manager is safe
@ -103,11 +104,11 @@ def patch_external_dependencies(monkeypatch):
# Patch MaiImage data model
data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage)
monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod)
# Patch SQLModel select function
sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect())
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
# Patch config values used at import-time
cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style"))
model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm"))
@ -134,7 +135,7 @@ def _load_image_manager_module(tmp_path=None):
if tmp_path is not None:
tmpdir = Path(tmp_path)
tmpdir.mkdir(parents=True, exist_ok=True)
setattr(mod, "IMAGE_DIR", tmpdir)
mod.IMAGE_DIR = tmpdir
except Exception:
pass
return mod
@ -197,4 +198,3 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
# cleanup should run without error
mgr.cleanup_invalid_descriptions_in_db()

View File

@ -1,5 +1,3 @@
import pytest
from src.config.official_configs import ChatConfig
from src.config.config import Config
from src.webui.config_schema import ConfigSchemaGenerator

View File

@ -387,7 +387,7 @@ def test_auth_required_list(client):
"""测试未认证访问列表端点401"""
# Without mock_token_verify fixture
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
response = client.get("/emoji/list")
client.get("/emoji/list")
# verify_auth_token 返回 False 会触发 HTTPException
# 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现
# 这里假设它抛出 401
@ -397,7 +397,7 @@ def test_auth_required_update(client, sample_emojis):
"""测试未认证访问更新端点401"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
emoji_id = sample_emojis[0].id
response = client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
# Should be unauthorized

View File

@ -1,6 +1,5 @@
"""Expression routes pytest tests"""
from datetime import datetime
from typing import Generator
from unittest.mock import MagicMock
@ -12,7 +11,6 @@ from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, select
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
def create_test_app() -> FastAPI:

View File

@ -19,7 +19,7 @@ from typing import Dict, List, Set, Tuple
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.logger import get_logger
from src.common.logger import get_logger # noqa: E402
logger = get_logger("evaluation_stats_analyzer")
@ -38,10 +38,10 @@ def parse_datetime(dt_str: str) -> datetime | None:
def analyze_single_file(file_path: str) -> Dict:
"""
分析单个JSON文件的统计信息
Args:
file_path: JSON文件路径
Returns:
统计信息字典
"""
@ -65,40 +65,40 @@ def analyze_single_file(file_path: str) -> Dict:
"has_reason": False,
"reason_count": 0,
}
try:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
# 基本信息
stats["last_updated"] = data.get("last_updated")
stats["total_count"] = data.get("total_count", 0)
results = data.get("manual_results", [])
stats["actual_count"] = len(results)
if not results:
return stats
# 统计通过/不通过
suitable_count = sum(1 for r in results if r.get("suitable") is True)
unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
stats["suitable_count"] = suitable_count
stats["unsuitable_count"] = unsuitable_count
stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0
# 统计唯一的(situation, style)对
pairs: Set[Tuple[str, str]] = set()
for r in results:
if "situation" in r and "style" in r:
pairs.add((r["situation"], r["style"]))
stats["unique_pairs"] = len(pairs)
# 统计评估者
for r in results:
evaluator = r.get("evaluator", "unknown")
stats["evaluators"][evaluator] += 1
# 统计评估时间
evaluation_dates = []
for r in results:
@ -107,7 +107,7 @@ def analyze_single_file(file_path: str) -> Dict:
dt = parse_datetime(evaluated_at)
if dt:
evaluation_dates.append(dt)
stats["evaluation_dates"] = evaluation_dates
if evaluation_dates:
min_date = min(evaluation_dates)
@ -115,18 +115,18 @@ def analyze_single_file(file_path: str) -> Dict:
stats["date_range"] = {
"start": min_date.isoformat(),
"end": max_date.isoformat(),
"duration_days": (max_date - min_date).days + 1
"duration_days": (max_date - min_date).days + 1,
}
# 检查字段存在性
stats["has_expression_id"] = any("expression_id" in r for r in results)
stats["has_reason"] = any(r.get("reason") for r in results)
stats["reason_count"] = sum(1 for r in results if r.get("reason"))
except Exception as e:
stats["error"] = str(e)
logger.error(f"分析文件 {file_name} 时出错: {e}")
return stats
@ -136,57 +136,57 @@ def print_file_stats(stats: Dict, index: int = None):
print(f"\n{'=' * 80}")
print(f"{prefix}文件: {stats['file_name']}")
print(f"{'=' * 80}")
if stats["error"]:
print(f"✗ 错误: {stats['error']}")
return
print(f"文件路径: {stats['file_path']}")
print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")
if stats["last_updated"]:
print(f"最后更新: {stats['last_updated']}")
print("\n【记录统计】")
print(f" 文件中的 total_count: {stats['total_count']}")
print(f" 实际记录数: {stats['actual_count']}")
if stats['total_count'] != stats['actual_count']:
diff = stats['total_count'] - stats['actual_count']
if stats["total_count"] != stats["actual_count"]:
diff = stats["total_count"] - stats["actual_count"]
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
print("\n【评估结果统计】")
print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")
print("\n【唯一性统计】")
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}")
if stats['actual_count'] > 0:
duplicate_count = stats['actual_count'] - stats['unique_pairs']
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["actual_count"] > 0:
duplicate_count = stats["actual_count"] - stats["unique_pairs"]
duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
print("\n【评估者统计】")
if stats['evaluators']:
for evaluator, count in stats['evaluators'].most_common():
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["evaluators"]:
for evaluator, count in stats["evaluators"].most_common():
rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
else:
print(" 无评估者信息")
print("\n【时间统计】")
if stats['date_range']:
if stats["date_range"]:
print(f" 最早评估时间: {stats['date_range']['start']}")
print(f" 最晚评估时间: {stats['date_range']['end']}")
print(f" 评估时间跨度: {stats['date_range']['duration_days']}")
else:
print(" 无时间信息")
print("\n【字段统计】")
print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}")
print(f" 包含 reason: {'' if stats['has_reason'] else ''}")
if stats['has_reason']:
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["has_reason"]:
rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
@ -195,35 +195,35 @@ def print_summary(all_stats: List[Dict]):
print(f"\n{'=' * 80}")
print("汇总统计")
print(f"{'=' * 80}")
total_files = len(all_stats)
valid_files = [s for s in all_stats if not s.get("error")]
error_files = [s for s in all_stats if s.get("error")]
print("\n【文件统计】")
print(f" 总文件数: {total_files}")
print(f" 成功解析: {len(valid_files)}")
print(f" 解析失败: {len(error_files)}")
if error_files:
print("\n 失败文件列表:")
for stats in error_files:
print(f" - {stats['file_name']}: {stats['error']}")
if not valid_files:
print("\n没有成功解析的文件")
return
# 汇总记录统计
total_records = sum(s['actual_count'] for s in valid_files)
total_suitable = sum(s['suitable_count'] for s in valid_files)
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
total_records = sum(s["actual_count"] for s in valid_files)
total_suitable = sum(s["suitable_count"] for s in valid_files)
total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
total_unique_pairs = set()
# 收集所有唯一的(situation, style)对
for stats in valid_files:
try:
with open(stats['file_path'], "r", encoding="utf-8") as f:
with open(stats["file_path"], "r", encoding="utf-8") as f:
data = json.load(f)
results = data.get("manual_results", [])
for r in results:
@ -231,23 +231,31 @@ def print_summary(all_stats: List[Dict]):
total_unique_pairs.add((r["situation"], r["style"]))
except Exception:
pass
print("\n【记录汇总】")
print(f" 总记录数: {total_records:,}")
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
print(
f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
if total_records > 0
else " 通过: 0 条"
)
print(
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
if total_records > 0
else " 不通过: 0 条"
)
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}")
if total_records > 0:
duplicate_count = total_records - len(total_unique_pairs)
duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0
print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")
# 汇总评估者统计
all_evaluators = Counter()
for stats in valid_files:
all_evaluators.update(stats['evaluators'])
all_evaluators.update(stats["evaluators"])
print("\n【评估者汇总】")
if all_evaluators:
for evaluator, count in all_evaluators.most_common():
@ -255,12 +263,12 @@ def print_summary(all_stats: List[Dict]):
print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)")
else:
print(" 无评估者信息")
# 汇总时间范围
all_dates = []
for stats in valid_files:
all_dates.extend(stats['evaluation_dates'])
all_dates.extend(stats["evaluation_dates"])
if all_dates:
min_date = min(all_dates)
max_date = max(all_dates)
@ -268,9 +276,9 @@ def print_summary(all_stats: List[Dict]):
print(f" 最早评估时间: {min_date.isoformat()}")
print(f" 最晚评估时间: {max_date.isoformat()}")
print(f" 总时间跨度: {(max_date - min_date).days + 1}")
# 文件大小汇总
total_size = sum(s['file_size'] for s in valid_files)
total_size = sum(s["file_size"] for s in valid_files)
avg_size = total_size / len(valid_files) if valid_files else 0
print("\n【文件大小汇总】")
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
@ -282,35 +290,35 @@ def main():
logger.info("=" * 80)
logger.info("开始分析评估结果统计信息")
logger.info("=" * 80)
if not os.path.exists(TEMP_DIR):
print(f"\n✗ 错误未找到temp目录: {TEMP_DIR}")
logger.error(f"未找到temp目录: {TEMP_DIR}")
return
# 查找所有JSON文件
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
if not json_files:
print(f"\n✗ 错误temp目录下未找到JSON文件: {TEMP_DIR}")
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
return
json_files.sort() # 按文件名排序
print(f"\n找到 {len(json_files)} 个JSON文件")
print("=" * 80)
# 分析每个文件
all_stats = []
for i, json_file in enumerate(json_files, 1):
stats = analyze_single_file(json_file)
all_stats.append(stats)
print_file_stats(stats, index=i)
# 打印汇总统计
print_summary(all_stats)
print(f"\n{'=' * 80}")
print("分析完成")
print(f"{'=' * 80}")
@ -318,5 +326,3 @@ def main():
if __name__ == "__main__":
main()

View File

@ -171,7 +171,9 @@ def main():
sys.exit(1)
if not args.raw_index:
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
logger.info(
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
)
sys.exit(1)
# 解析索引列表1-based

View File

@ -22,11 +22,11 @@ from collections import defaultdict
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression
from src.common.database.database import db
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.database.database_model import Expression # noqa: E402
from src.common.database.database import db # noqa: E402
from src.common.logger import get_logger # noqa: E402
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
logger = get_logger("expression_evaluator_count_analysis_llm")
@ -38,13 +38,13 @@ COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
"""
加载已有的评估结果
Returns:
(已有结果列表, 已评估的项目(situation, style)元组集合)
"""
if not os.path.exists(COUNT_ANALYSIS_FILE):
return [], set()
try:
with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
@ -61,22 +61,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
def save_results(evaluation_results: List[Dict]):
"""
保存评估结果到文件
Args:
evaluation_results: 评估结果列表
"""
try:
os.makedirs(TEMP_DIR, exist_ok=True)
data = {
"last_updated": datetime.now().isoformat(),
"total_count": len(evaluation_results),
"evaluation_results": evaluation_results
"evaluation_results": evaluation_results,
}
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
except Exception as e:
@ -84,70 +84,70 @@ def save_results(evaluation_results: List[Dict]):
print(f"\n✗ 保存评估结果失败: {e}")
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
    """Pick the expressions to evaluate in this run.

    All not-yet-evaluated expressions with ``count > 1`` are taken, plus a
    random sample of ``count == 1`` expressions twice that size (or all of
    them when fewer are available). The combined list is shuffled.

    Args:
        evaluated_pairs: ``(situation, style)`` pairs that were already
            evaluated and must be left out; ``None`` means none.

    Returns:
        The selected expressions, or an empty list when there is nothing to
        evaluate or when any exception occurs (the error is logged).
    """
    seen_pairs = set() if evaluated_pairs is None else evaluated_pairs

    try:
        # Fetch every expression record.
        candidates = list(Expression.select())
        if not candidates:
            logger.warning("数据库中没有表达方式记录")
            return []

        # Drop the ones that already have an evaluation.
        pending = [item for item in candidates if (item.situation, item.style) not in seen_pairs]
        if not pending:
            logger.warning("所有项目都已评估完成")
            return []

        # Split by occurrence count.
        singles = [item for item in pending if item.count == 1]
        repeats = [item for item in pending if item.count > 1]
        logger.info(f"未评估项目中count=1的有{len(singles)}count>1的有{len(repeats)}")

        # Every repeated expression is kept.
        chosen_repeats = repeats.copy()

        # Aim for twice as many single-occurrence expressions, capped by supply.
        singles_wanted = len(chosen_repeats) * 2
        if len(singles) < singles_wanted:
            logger.warning(
                f"count=1的项目只有{len(singles)}条,少于需要的{singles_wanted}条,将选择全部{len(singles)}"
            )
            singles_wanted = len(singles)

        # Sample only when there is something to draw.
        if singles and singles_wanted > 0:
            chosen_singles = random.sample(singles, singles_wanted)
        else:
            chosen_singles = []

        selected = chosen_repeats + chosen_singles
        random.shuffle(selected)  # randomize evaluation order

        logger.info(
            f"已选择{len(selected)}条表达方式count>1的有{len(chosen_repeats)}全部count=1的有{len(chosen_singles)}2倍"
        )
        return selected

    except Exception as e:
        logger.error(f"选择表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
@ -155,11 +155,11 @@ def select_expressions_for_evaluation(
def create_evaluation_prompt(situation: str, style: str) -> str:
"""
创建评估提示词
Args:
situation: 情境
style: 风格
Returns:
评估提示词
"""
@ -181,34 +181,32 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
return prompt
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
"""
执行单次LLM评估
Args:
situation: 情境
style: 风格
llm: LLM请求实例
Returns:
(suitable, reason, error) 元组如果出错则 suitable Falseerror 包含错误信息
"""
try:
prompt = create_evaluation_prompt(situation, style)
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt,
temperature=0.6,
max_tokens=1024
prompt=prompt, temperature=0.6, max_tokens=1024
)
logger.debug(f"LLM响应: {response}")
# 解析JSON响应
try:
evaluation = json.loads(response)
@ -218,13 +216,13 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
evaluation = json.loads(json_match.group())
else:
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由")
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None
except Exception as e:
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
return False, f"评估过程出错: {str(e)}", str(e)
@ -233,23 +231,25 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
"""
使用LLM评估单个表达方式
Args:
expression: 表达方式对象
llm: LLM请求实例
Returns:
评估结果字典
"""
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
logger.info(
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
)
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
if error:
suitable = False
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
return {
"situation": expression.situation,
"style": expression.style,
@ -258,28 +258,28 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
"reason": reason,
"error": error,
"evaluator": "llm",
"evaluated_at": datetime.now().isoformat()
"evaluated_at": datetime.now().isoformat(),
}
def perform_statistical_analysis(evaluation_results: List[Dict]):
"""
对评估结果进行统计分析
Args:
evaluation_results: 评估结果列表
"""
if not evaluation_results:
print("\n没有评估结果可供分析")
return
print("\n" + "=" * 60)
print("统计分析结果")
print("=" * 60)
# 按count分组统计
count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
for result in evaluation_results:
count = result.get("count", 1)
suitable = result.get("suitable", False)
@ -288,7 +288,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
count_groups[count]["suitable"] += 1
else:
count_groups[count]["unsuitable"] += 1
# 显示每个count的统计
print("\n【按count分组统计】")
print("-" * 60)
@ -298,21 +298,21 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
suitable = group["suitable"]
unsuitable = group["unsuitable"]
pass_rate = (suitable / total * 100) if total > 0 else 0
print(f"Count = {count}:")
print(f" 总数: {total}")
print(f" 通过: {suitable} ({pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
print()
# 比较count=1和count>1
count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
for result in evaluation_results:
count = result.get("count", 1)
suitable = result.get("suitable", False)
if count == 1:
count_eq1_group["total"] += 1
if suitable:
@ -325,34 +325,34 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
count_gt1_group["suitable"] += 1
else:
count_gt1_group["unsuitable"] += 1
print("\n【Count=1 vs Count>1 对比】")
print("-" * 60)
eq1_total = count_eq1_group["total"]
eq1_suitable = count_eq1_group["suitable"]
eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
gt1_total = count_gt1_group["total"]
gt1_suitable = count_gt1_group["suitable"]
gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
print("Count = 1:")
print(f" 总数: {eq1_total}")
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
print()
print("Count > 1:")
print(f" 总数: {gt1_total}")
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
print()
# 进行卡方检验简化版使用2x2列联表
if eq1_total > 0 and gt1_total > 0:
print("【统计显著性检验】")
print("-" * 60)
# 构建2x2列联表
# 通过 不通过
# count=1 a b
@ -361,7 +361,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
b = eq1_total - eq1_suitable
c = gt1_suitable
d = gt1_total - gt1_suitable
# 计算卡方统计量简化版使用Pearson卡方检验
n = eq1_total + gt1_total
if n > 0:
@ -370,13 +370,13 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
e_b = (eq1_total * (b + d)) / n
e_c = (gt1_total * (a + c)) / n
e_d = (gt1_total * (b + d)) / n
# 检查期望频数是否足够大(卡方检验要求每个期望频数>=5
min_expected = min(e_a, e_b, e_c, e_d)
if min_expected < 5:
print("警告期望频数小于5卡方检验可能不准确")
print("建议使用Fisher精确检验")
# 计算卡方值
chi_square = 0
if e_a > 0:
@ -387,26 +387,26 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
chi_square += ((c - e_c) ** 2) / e_c
if e_d > 0:
chi_square += ((d - e_d) ** 2) / e_d
# 自由度 = (行数-1) * (列数-1) = 1
df = 1
# 临界值(α=0.05
chi_square_critical_005 = 3.841
chi_square_critical_001 = 6.635
print(f"卡方统计量: {chi_square:.4f}")
print(f"自由度: {df}")
print(f"临界值 (α=0.05): {chi_square_critical_005}")
print(f"临界值 (α=0.01): {chi_square_critical_001}")
if chi_square >= chi_square_critical_001:
print("结论: 在α=0.01水平下count=1和count>1的合格率存在显著差异p<0.01")
elif chi_square >= chi_square_critical_005:
print("结论: 在α=0.05水平下count=1和count>1的合格率存在显著差异p<0.05")
else:
print("结论: 在α=0.05水平下count=1和count>1的合格率不存在显著差异p≥0.05")
# 计算差异大小
diff = abs(eq1_pass_rate - gt1_pass_rate)
print(f"\n合格率差异: {diff:.2f}%")
@ -420,16 +420,16 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
print("数据不足,无法进行统计检验")
else:
print("数据不足无法进行count=1和count>1的对比分析")
# 保存统计分析结果
analysis_result = {
"analysis_time": datetime.now().isoformat(),
"count_groups": {str(k): v for k, v in count_groups.items()},
"count_eq1": count_eq1_group,
"count_gt1": count_gt1_group,
"total_evaluated": len(evaluation_results)
"total_evaluated": len(evaluation_results),
}
try:
analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
with open(analysis_file, "w", encoding="utf-8") as f:
@ -444,7 +444,7 @@ async def main():
logger.info("=" * 60)
logger.info("开始表达方式按count分组的LLM评估和统计分析")
logger.info("=" * 60)
# 初始化数据库连接
try:
db.connect(reuse_if_open=True)
@ -452,97 +452,95 @@ async def main():
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return
# 加载已有评估结果
existing_results, evaluated_pairs = load_existing_results()
evaluation_results = existing_results.copy()
if evaluated_pairs:
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
print(f"已评估项目数: {len(evaluated_pairs)}")
# 检查是否需要继续评估检查是否还有未评估的count>1项目
# 先查询未评估的count>1项目数量
try:
all_expressions = list(Expression.select())
unevaluated_count_gt1 = [
expr for expr in all_expressions
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
]
has_unevaluated = len(unevaluated_count_gt1) > 0
except Exception as e:
logger.error(f"查询未评估项目失败: {e}")
has_unevaluated = False
if has_unevaluated:
print("\n" + "=" * 60)
print("开始LLM评估")
print("=" * 60)
print("评估结果会自动保存到文件\n")
# 创建LLM实例
print("创建LLM实例...")
try:
llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_count_analysis_llm"
request_type="expression_evaluator_count_analysis_llm",
)
print("✓ LLM实例创建成功\n")
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
import traceback
logger.error(traceback.format_exc())
print(f"\n✗ 创建LLM实例失败: {e}")
db.close()
return
# 选择需要评估的表达方式选择所有count>1的项目然后选择两倍数量的count=1的项目
expressions = select_expressions_for_evaluation(
evaluated_pairs=evaluated_pairs
)
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
if not expressions:
print("\n没有可评估的项目")
else:
print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)}")
print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)}\n")
batch_results = []
for i, expression in enumerate(expressions, 1):
print(f"LLM评估进度: {i}/{len(expressions)}")
print(f" Situation: {expression.situation}")
print(f" Style: {expression.style}")
print(f" Count: {expression.count}")
llm_result = await llm_evaluate_expression(expression, llm)
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
if llm_result.get('error'):
if llm_result.get("error"):
print(f" 错误: {llm_result['error']}")
print()
batch_results.append(llm_result)
# 使用 (situation, style) 作为唯一标识
evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
# 添加延迟以避免API限流
await asyncio.sleep(0.3)
# 将当前批次结果添加到总结果中
evaluation_results.extend(batch_results)
# 保存结果
save_results(evaluation_results)
else:
print(f"\n所有count>1的项目都已评估完成已有 {len(evaluation_results)} 条评估结果")
# 进行统计分析
if len(evaluation_results) > 0:
perform_statistical_analysis(evaluation_results)
else:
print("\n没有评估结果可供分析")
# 关闭数据库连接
try:
db.close()
@ -553,4 +551,3 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())

View File

@ -20,9 +20,9 @@ from typing import List, Dict, Set, Tuple
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
from src.common.logger import get_logger # noqa: E402
logger = get_logger("expression_evaluator_llm")
@ -33,7 +33,7 @@ TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
def load_manual_results() -> List[Dict]:
"""
加载人工评估结果自动读取temp目录下所有JSON文件并合并
Returns:
人工评估结果列表已去重
"""
@ -42,62 +42,62 @@ def load_manual_results() -> List[Dict]:
print("\n✗ 错误未找到temp目录")
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
return []
# 查找所有JSON文件
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
if not json_files:
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
print("\n✗ 错误temp目录下未找到JSON文件")
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
return []
logger.info(f"找到 {len(json_files)} 个JSON文件")
print(f"\n找到 {len(json_files)} 个JSON文件:")
for json_file in json_files:
print(f" - {os.path.basename(json_file)}")
# 读取并合并所有JSON文件
all_results = []
seen_pairs: Set[Tuple[str, str]] = set() # 用于去重
for json_file in json_files:
try:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
results = data.get("manual_results", [])
# 去重:使用(situation, style)作为唯一标识
for result in results:
if "situation" not in result or "style" not in result:
logger.warning(f"跳过无效数据(缺少必要字段): {result}")
continue
pair = (result["situation"], result["style"])
if pair not in seen_pairs:
seen_pairs.add(pair)
all_results.append(result)
logger.info(f"{os.path.basename(json_file)} 加载了 {len(results)} 条结果")
except Exception as e:
logger.error(f"加载文件 {json_file} 失败: {e}")
print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
continue
logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")
return all_results
def create_evaluation_prompt(situation: str, style: str) -> str:
"""
创建评估提示词
Args:
situation: 情境
style: 风格
Returns:
评估提示词
"""
@ -119,51 +119,50 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
return prompt
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """
    Run a single LLM evaluation for one (situation, style) expression.

    Args:
        situation: Situation text of the expression.
        style: Style text of the expression.
        llm: LLM request instance used to generate the evaluation.

    Returns:
        A ``(suitable, reason, error)`` tuple. When anything goes wrong,
        ``suitable`` is ``False``, ``reason`` describes the failure and
        ``error`` carries the exception message; otherwise ``error`` is ``None``.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        # Only the response text is used here; the remaining elements are
        # unpacked (to keep the shape check) but otherwise ignored.
        response, (_reasoning, _model_name, _) = await llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )

        logger.debug(f"LLM响应: {response}")

        # Parse the JSON response. If the model wrapped the JSON in extra text,
        # fall back to the first flat {...} object that mentions "suitable".
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            import re

            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e

        # Normalise the verdict to a real bool: a model that answers with the
        # string "false" must not be treated as a pass just because it is truthy.
        raw_suitable = evaluation.get("suitable", False)
        if isinstance(raw_suitable, str):
            suitable = raw_suitable.strip().lower() in ("true", "yes", "1")
        else:
            suitable = bool(raw_suitable)
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as e:
        # Best-effort evaluation: report the failure to the caller instead of raising.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
@ -172,68 +171,68 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
    """
    Evaluate a single expression with the LLM and package the outcome.

    Args:
        situation: Situation text of the expression.
        style: Style text of the expression.
        llm: LLM request instance.

    Returns:
        A dict with the keys ``situation``, ``style``, ``suitable``,
        ``reason``, ``error`` and ``evaluator`` (always ``"llm"``).
    """
    logger.info(f"开始评估表达方式: situation={situation}, style={style}")

    suitable, reason, error = await _single_llm_evaluation(situation, style, llm)

    # A failed evaluation must never be recorded as a pass.
    if error:
        suitable = False

    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")

    return {
        "situation": situation,
        "style": style,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
    }
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
"""
对比人工评估和LLM评估的结果
Args:
manual_results: 人工评估结果列表
llm_results: LLM评估结果列表
method_name: 评估方法名称用于标识
Returns:
对比分析结果字典
"""
# 按(situation, style)建立映射
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
total = len(manual_results)
matched = 0
true_positives = 0
true_negatives = 0
false_positives = 0
false_negatives = 0
for manual_result in manual_results:
pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair)
if llm_result is None:
continue
manual_suitable = manual_result["suitable"]
llm_suitable = llm_result["suitable"]
if manual_suitable == llm_suitable:
matched += 1
if manual_suitable and llm_suitable:
true_positives += 1
elif not manual_suitable and not llm_suitable:
@ -242,30 +241,36 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
false_positives += 1
elif manual_suitable and not llm_suitable:
false_negatives += 1
accuracy = (matched / total * 100) if total > 0 else 0
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
precision = (
(true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
)
recall = (
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
)
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
specificity = (
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
)
# 计算人工效标的不合适率
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0
# 计算经过LLM删除后剩余项目中的不合适率
# 在所有项目中移除LLM判定为不合适的项目后剩下的项目 = TP + FPLLM判定为合适的项目
# 在这些剩下的项目中,按人工评定的不合适项目 = FP人工认为不合适但LLM认为合适
llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数保留的项目
llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0
# 两者百分比相减评估LLM评定修正后的不合适率是否有降低
rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
random_baseline = 50.0
accuracy_above_random = accuracy - random_baseline
accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0
return {
"method": method_name,
"total": total,
@ -283,29 +288,29 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
"specificity": specificity,
"manual_unsuitable_rate": manual_unsuitable_rate,
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
"rate_difference": rate_difference
"rate_difference": rate_difference,
}
async def main(count: int | None = None):
"""
主函数
Args:
count: 随机选取的数据条数如果为None则使用全部数据
"""
logger.info("=" * 60)
logger.info("开始表达方式LLM评估")
logger.info("=" * 60)
# 1. 加载人工评估结果
print("\n步骤1: 加载人工评估结果")
manual_results = load_manual_results()
if not manual_results:
return
print(f"成功加载 {len(manual_results)} 条人工评估结果")
# 如果指定了数量,随机选择指定数量的数据
if count is not None:
if count <= 0:
@ -317,7 +322,7 @@ async def main(count: int | None = None):
random.seed() # 使用系统时间作为随机种子
manual_results = random.sample(manual_results, count)
print(f"随机选取 {len(manual_results)} 条数据进行评估")
# 验证数据完整性
valid_manual_results = []
for r in manual_results:
@ -325,62 +330,58 @@ async def main(count: int | None = None):
valid_manual_results.append(r)
else:
logger.warning(f"跳过无效数据: {r}")
if len(valid_manual_results) != len(manual_results):
print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
print(f"有效数据: {len(valid_manual_results)}")
# 2. 创建LLM实例并评估
print("\n步骤2: 创建LLM实例")
try:
llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_llm"
)
llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
import traceback
logger.error(traceback.format_exc())
return
print("\n步骤3: 开始LLM评估")
llm_results = []
for i, manual_result in enumerate(valid_manual_results, 1):
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
llm_results.append(await evaluate_expression_llm(
manual_result["situation"],
manual_result["style"],
llm
))
llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
await asyncio.sleep(0.3)
# 5. 输出FP和FN项目在评估结果之前
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
# 5.1 输出FP项目人工评估不通过但LLM误判为通过
print("\n" + "=" * 60)
print("人工评估不通过但LLM误判为通过的项目FP - False Positive")
print("=" * 60)
fp_items = []
for manual_result in valid_manual_results:
pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair)
if llm_result is None:
continue
# 人工评估不通过但LLM评估通过FP情况
if not manual_result["suitable"] and llm_result["suitable"]:
fp_items.append({
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error")
})
fp_items.append(
{
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error"),
}
)
if fp_items:
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
for idx, item in enumerate(fp_items, 1):
@ -389,36 +390,38 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}")
print("人工评估: 不通过 ❌")
print("LLM评估: 通过 ✅ (误判)")
if item.get('llm_error'):
if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}")
print()
else:
print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过")
# 5.2 输出FN项目人工评估通过但LLM误判为不通过
print("\n" + "=" * 60)
print("人工评估通过但LLM误判为不通过的项目FN - False Negative")
print("=" * 60)
fn_items = []
for manual_result in valid_manual_results:
pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair)
if llm_result is None:
continue
# 人工评估通过但LLM评估不通过FN情况
if manual_result["suitable"] and not llm_result["suitable"]:
fn_items.append({
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error")
})
fn_items.append(
{
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error"),
}
)
if fn_items:
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
for idx, item in enumerate(fn_items, 1):
@ -427,33 +430,41 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}")
print("人工评估: 通过 ✅")
print("LLM评估: 不通过 ❌ (误删)")
if item.get('llm_error'):
if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}")
print()
else:
print("\n✓ 没有误删项目所有人工评估通过的项目都被LLM正确识别为通过")
# 6. 对比分析并输出结果
comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
print("\n" + "=" * 60)
print("评估结果(以人工评估为标准)")
print("=" * 60)
# 详细评估结果(核心指标优先)
print(f"\n--- {comparison['method']} ---")
print(f" 总数: {comparison['total']}")
print()
# print(" 【核心能力指标】")
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}")
print(
f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
print()
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}")
print(
f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
)
print(
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
print()
print(" 【其他指标】")
@ -464,12 +475,18 @@ async def main(count: int | None = None):
print()
print(" 【不合适率分析】")
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
print(
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
)
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
print()
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
print(f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(
f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
)
print()
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
@ -480,21 +497,22 @@ async def main(count: int | None = None):
print(f" TN (正确识别为不合适): {comparison['true_negatives']}")
print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
# 7. 保存结果到JSON文件
output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
try:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
json.dump({
"manual_results": valid_manual_results,
"llm_results": llm_results,
"comparison": comparison
}, f, ensure_ascii=False, indent=2)
json.dump(
{"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
f,
ensure_ascii=False,
indent=2,
)
logger.info(f"\n评估结果已保存到: {output_file}")
except Exception as e:
logger.warning(f"保存结果到文件失败: {e}")
print("\n" + "=" * 60)
print("评估完成")
print("=" * 60)
@ -509,15 +527,9 @@ if __name__ == "__main__":
python evaluate_expressions_llm_v6.py # 使用全部数据
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
"""
""",
)
parser.add_argument(
"-n", "--count",
type=int,
default=None,
help="随机选取的数据条数(默认:使用全部数据)"
)
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
args = parser.parse_args()
asyncio.run(main(count=args.count))

View File

@ -18,9 +18,9 @@ from datetime import datetime
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression
from src.common.database.database import db
from src.common.logger import get_logger
from src.common.database.database_model import Expression # noqa: E402
from src.common.database.database import db # noqa: E402
from src.common.logger import get_logger # noqa: E402
logger = get_logger("expression_evaluator_manual")
@ -32,13 +32,13 @@ MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
"""
加载已有的评估结果
Returns:
(已有结果列表, 已评估的项目(situation, style)元组集合)
"""
if not os.path.exists(MANUAL_EVAL_FILE):
return [], set()
try:
with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
@ -55,22 +55,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
def save_results(manual_results: List[Dict]):
"""
保存评估结果到文件
Args:
manual_results: 评估结果列表
"""
try:
os.makedirs(TEMP_DIR, exist_ok=True)
data = {
"last_updated": datetime.now().isoformat(),
"total_count": len(manual_results),
"manual_results": manual_results
"manual_results": manual_results,
}
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
except Exception as e:
@ -81,45 +81,43 @@ def save_results(manual_results: List[Dict]):
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
    """
    Fetch expressions that have not been evaluated yet.

    Args:
        evaluated_pairs: Set of ``(situation, style)`` tuples already evaluated.
        batch_size: Maximum number of expressions to return per call.

    Returns:
        All remaining unevaluated expressions when there are at most
        ``batch_size`` of them, otherwise a random sample of ``batch_size``.
        An empty list is returned when nothing is left or the lookup fails.
    """
    try:
        # Load every expression record.
        all_expressions = list(Expression.select())

        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []

        # An expression counts as evaluated only when both situation and style match.
        unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]

        if not unevaluated:
            logger.info("所有项目都已评估完成")
            return []

        # Fewer items left than requested: hand back everything.
        if len(unevaluated) <= batch_size:
            logger.info(f"剩余 {len(unevaluated)} 条未评估项目,全部返回")
            return unevaluated

        # Otherwise pick a random batch of the requested size.
        selected = random.sample(unevaluated, batch_size)
        logger.info(f"{len(unevaluated)} 条未评估项目中随机选择了 {len(selected)}")
        return selected

    except Exception as e:
        # Best-effort: log the failure with its traceback and return no work.
        logger.error(f"获取未评估表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
@ -127,12 +125,12 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
"""
人工评估单个表达方式
Args:
expression: 表达方式对象
index: 当前索引从1开始
total: 总数
Returns:
评估结果字典如果用户退出则返回 None
"""
@ -146,38 +144,38 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
print(" 输入 'n''no''0' 表示不合适(不通过)")
print(" 输入 'q''quit' 退出评估")
print(" 输入 's''skip' 跳过当前项目")
while True:
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
if user_input in ['q', 'quit']:
if user_input in ["q", "quit"]:
print("退出评估")
return None
if user_input in ['s', 'skip']:
if user_input in ["s", "skip"]:
print("跳过当前项目")
return "skip"
if user_input in ['y', 'yes', '1', '', '通过']:
if user_input in ["y", "yes", "1", "", "通过"]:
suitable = True
break
elif user_input in ['n', 'no', '0', '', '不通过']:
elif user_input in ["n", "no", "0", "", "不通过"]:
suitable = False
break
else:
print("输入无效,请重新输入 (y/n/q/s)")
result = {
"situation": expression.situation,
"style": expression.style,
"suitable": suitable,
"reason": None,
"evaluator": "manual",
"evaluated_at": datetime.now().isoformat()
"evaluated_at": datetime.now().isoformat(),
}
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
return result
@ -186,7 +184,7 @@ def main():
logger.info("=" * 60)
logger.info("开始表达方式人工评估")
logger.info("=" * 60)
# 初始化数据库连接
try:
db.connect(reuse_if_open=True)
@ -194,41 +192,41 @@ def main():
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return
# 加载已有评估结果
existing_results, evaluated_pairs = load_existing_results()
manual_results = existing_results.copy()
if evaluated_pairs:
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
print(f"已评估项目数: {len(evaluated_pairs)}")
print("\n" + "=" * 60)
print("开始人工评估")
print("=" * 60)
print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
print("评估结果会自动保存到文件\n")
batch_size = 10
batch_count = 0
while True:
# 获取未评估的项目
expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)
if not expressions:
print("\n" + "=" * 60)
print("所有项目都已评估完成!")
print("=" * 60)
break
batch_count += 1
print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")
batch_results = []
for i, expression in enumerate(expressions, 1):
manual_result = manual_evaluate_expression(expression, i, len(expressions))
if manual_result is None:
# 用户退出
print("\n评估已中断")
@ -237,34 +235,34 @@ def main():
manual_results.extend(batch_results)
save_results(manual_results)
return
if manual_result == "skip":
# 跳过当前项目
continue
batch_results.append(manual_result)
# 使用 (situation, style) 作为唯一标识
evaluated_pairs.add((manual_result["situation"], manual_result["style"]))
# 将当前批次结果添加到总结果中
manual_results.extend(batch_results)
# 保存结果
save_results(manual_results)
print(f"\n当前批次完成,已评估总数: {len(manual_results)}")
# 询问是否继续
while True:
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
if continue_input in ['y', 'yes', '1', '', '继续']:
if continue_input in ["y", "yes", "1", "", "继续"]:
break
elif continue_input in ['n', 'no', '0', '', '退出']:
elif continue_input in ["n", "no", "0", "", "退出"]:
print("\n评估结束")
return
else:
print("输入无效,请重新输入 (y/n)")
# 关闭数据库连接
try:
db.close()
@ -275,4 +273,3 @@ def main():
if __name__ == "__main__":
main()

View File

@ -134,9 +134,7 @@ def handle_import_openie(
# 在非交互模式下,不再询问用户,而是直接报错终止
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
if non_interactive:
logger.error(
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
)
logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
sys.exit(1)
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
user_choice = input().strip().lower()
@ -189,9 +187,7 @@ def handle_import_openie(
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
# 新增确认提示
if non_interactive:
logger.warning(
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
)
logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
else:
print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
def main(argv: Optional[list[str]] = None) -> None:
"""主函数 - 解析参数并运行异步主流程。"""
parser = argparse.ArgumentParser(
description=(
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
"将其导入到 LPMM 的向量库与知识图中。"
)
description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
)
parser.add_argument(
"--non-interactive",

View File

@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
ensure_dirs() # 确保目录存在
# 新增用户确认提示
if non_interactive:
logger.warning(
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
)
logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
else:
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

View File

@ -1,6 +1,5 @@
import os
import sys
from typing import Set
# 保证可以导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -32,7 +31,6 @@ def main() -> None:
# KG 统计
nodes = kg.graph.get_node_list()
edges = kg.graph.get_edge_list()
node_set: Set[str] = set(nodes)
para_nodes = [n for n in nodes if n.startswith("paragraph-")]
ent_nodes = [n for n in nodes if n.startswith("entity-")]
@ -68,4 +66,3 @@ def main() -> None:
if __name__ == "__main__":
main()

View File

@ -29,6 +29,7 @@ except ImportError as e:
logger = get_logger("lpmm_interactive_manager")
async def interactive_add():
"""交互式导入知识"""
print("\n" + "=" * 40)
@ -38,7 +39,7 @@ async def interactive_add():
print(" - 支持多段落,段落间请保留空行。")
print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。")
print("-" * 40)
lines = []
while True:
try:
@ -48,7 +49,7 @@ async def interactive_add():
lines.append(line)
except EOFError:
break
text = "\n".join(lines).strip()
if not text:
print("\n[!] 内容为空,操作已取消。")
@ -58,7 +59,7 @@ async def interactive_add():
try:
# 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.add_content(text)
if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}")
print(f" 实际新增段落数: {result.get('count', 0)}")
@ -68,6 +69,7 @@ async def interactive_add():
print(f"\n[×] 发生异常: {e}")
logger.error(f"add_content 异常: {e}", exc_info=True)
async def interactive_delete():
"""交互式删除知识"""
print("\n" + "=" * 40)
@ -77,10 +79,10 @@ async def interactive_delete():
print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)")
print(" 2. 完整文段匹配(删除完全匹配的段落)")
print("-" * 40)
mode = input("请选择删除模式 (1/2): ").strip()
exact_match = False
if mode == "2":
exact_match = True
print("\n[完整文段匹配模式]")
@ -102,14 +104,18 @@ async def interactive_delete():
print("\n[!] 无效选择,默认使用关键词模糊匹配模式。")
print("\n[关键词模糊匹配模式]")
keyword = input("请输入匹配关键词: ").strip()
if not keyword:
print("\n[!] 输入为空,操作已取消。")
return
print("-" * 40)
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower()
if confirm != 'y':
confirm = (
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
.strip()
.lower()
)
if confirm != "y":
print("\n[!] 已取消删除操作。")
return
@ -117,7 +123,7 @@ async def interactive_delete():
try:
# 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.delete(keyword, exact_match=exact_match)
if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}")
print(f" 删除条数: {result.get('deleted_count', 0)}")
@ -129,6 +135,7 @@ async def interactive_delete():
print(f"\n[×] 发生异常: {e}")
logger.error(f"delete 异常: {e}", exc_info=True)
async def interactive_clear():
"""交互式清空知识库"""
print("\n" + "=" * 40)
@ -141,40 +148,45 @@ async def interactive_clear():
print(" - 整个知识图谱")
print(" - 此操作不可恢复!")
print("-" * 40)
# 双重确认
confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip()
if confirm1 != "YES":
print("\n[!] 已取消清空操作。")
return
print("\n" + "=" * 40)
confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip()
if confirm2 != "CLEAR":
print("\n[!] 已取消清空操作。")
return
print("\n[进度] 正在清空知识库...")
try:
# 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.clear_all()
if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}")
stats = result.get("stats", {})
before = stats.get("before", {})
after = stats.get("after", {})
print("\n[统计信息]")
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}")
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}")
print(
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
)
print(
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
)
else:
print(f"\n[×] 失败:{result['message']}")
except Exception as e:
print(f"\n[×] 发生异常: {e}")
logger.error(f"clear_all 异常: {e}", exc_info=True)
async def interactive_search():
"""交互式查询知识"""
print("\n" + "=" * 40)
@ -182,25 +194,25 @@ async def interactive_search():
print("=" * 40)
print("说明:输入查询问题或关键词,系统会返回相关的知识段落。")
print("-" * 40)
# 确保 LPMM 已初始化
if not global_config.lpmm_knowledge.enable:
print("\n[!] 警告LPMM 知识库在配置中未启用。")
return
try:
lpmm_start_up()
except Exception as e:
print(f"\n[!] LPMM 初始化失败: {e}")
logger.error(f"LPMM 初始化失败: {e}", exc_info=True)
return
query = input("请输入查询问题或关键词: ").strip()
if not query:
print("\n[!] 查询内容为空,操作已取消。")
return
# 询问返回条数
print("-" * 40)
limit_str = input("希望返回的相关知识条数默认3直接回车使用默认值: ").strip()
@ -210,11 +222,11 @@ async def interactive_search():
except ValueError:
limit = 3
print("[!] 输入无效,使用默认值 3。")
print("\n[进度] 正在查询知识库...")
try:
result = await query_lpmm_knowledge(query, limit=limit)
print("\n" + "=" * 60)
print("[查询结果]")
print("=" * 60)
@ -224,6 +236,7 @@ async def interactive_search():
print(f"\n[×] 查询失败: {e}")
logger.error(f"查询异常: {e}", exc_info=True)
async def main():
"""主循环"""
while True:
@ -236,9 +249,9 @@ async def main():
print("║ 4. 清空知识库 (Clear All) ⚠️ ║")
print("║ 0. 退出 (Exit) ║")
print("" + "" * 38 + "")
choice = input("请选择操作编号: ").strip()
if choice == "1":
await interactive_add()
elif choice == "2":
@ -253,6 +266,7 @@ async def main():
else:
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
if __name__ == "__main__":
try:
# 运行主循环
@ -262,4 +276,3 @@ if __name__ == "__main__":
except Exception as e:
print(f"\n[!] 程序运行出错: {e}")
logger.error(f"Main loop 异常: {e}", exc_info=True)

View File

@ -21,18 +21,18 @@ PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
if PROJECT_ROOT not in sys.path:
sys.path.append(PROJECT_ROOT)
from src.common.logger import get_logger # type: ignore
from src.config.config import global_config, model_config # type: ignore
from src.common.logger import get_logger # type: ignore # noqa: E402
from src.config.config import global_config, model_config # type: ignore # noqa: E402
# 引入各功能脚本的入口函数
from import_openie import main as import_openie_main # type: ignore
from info_extraction import main as info_extraction_main # type: ignore
from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore
from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore
from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore
from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore
from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore
from raw_data_preprocessor import load_raw_data # type: ignore
from import_openie import main as import_openie_main # type: ignore # noqa: E402
from info_extraction import main as info_extraction_main # type: ignore # noqa: E402
from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore # noqa: E402
from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore # noqa: E402
from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore # noqa: E402
from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore # noqa: E402
from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore # noqa: E402
from raw_data_preprocessor import load_raw_data # type: ignore # noqa: E402
logger = get_logger("lpmm_manager")
@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
txt_files = list(raw_dir.glob("*.txt"))
if not txt_files:
msg = (
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
"info_extraction 可能立即退出或无数据可处理。"
)
msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件info_extraction 可能立即退出或无数据可处理。"
print(msg)
if non_interactive:
logger.error(
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
)
logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
return False
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
return cont == "y"
@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
json_files = list(openie_dir.glob("*.json"))
if not json_files:
msg = (
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
"import_openie 可能会因为找不到批次而失败。"
)
msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件import_openie 可能会因为找不到批次而失败。"
print(msg)
if non_interactive:
logger.error(
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
)
logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
return False
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
return cont == "y"
@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
try:
if not getattr(global_config.lpmm_knowledge, "enable", False):
print(
"[WARN] 当前配置 lpmm_knowledge.enable = false"
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
)
print("[WARN] 当前配置 lpmm_knowledge.enable = false刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
except Exception:
# 配置异常时不阻断主流程,仅忽略提示
pass
@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
if action == "prepare_raw":
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
sha_list, raw_data = load_raw_data()
print(
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
elif action == "info_extract":
if not _check_before_info_extract("--non-interactive" in extra_args):
print("已根据用户选择,取消执行信息提取。")
@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
logger.info("开始 full_import预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
sha_list, raw_data = load_raw_data()
print(
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
non_interactive = "--non-interactive" in extra_args
if not _check_before_info_extract(non_interactive):
print("已根据用户选择,取消 full_import信息提取阶段被取消")
@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
)
# 快速选项:按推荐方式清理所有相关实体/关系
quick_all = input(
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
).strip().lower()
quick_all = (
input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
)
if quick_all in ("", "y", "yes"):
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
else:
@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
def _interactive_build_batch_inspect_args() -> List[str]:
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
path = _interactive_choose_openie_file(
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
)
path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
if not path:
return []
return ["--openie-file", path]
@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
def _interactive_build_test_args() -> List[str]:
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
print(
"\n[TEST] 你可以:\n"
"- 直接回车使用内置的默认测试用例;\n"
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
)
print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
if not query:
return []
@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
new_dim = input(
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
).strip()
new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
if new_dim and not new_dim.isdigit():
print("输入的维度不是纯数字,已取消操作。")
return
@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
if __name__ == "__main__":
main()

View File

@ -28,53 +28,55 @@ from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
try:
# 先导入 prompt_builder检查 prompt 是否已经初始化
from src.chat.utils.prompt_builder import global_prompt_manager
# 检查 memory_retrieval 相关的 prompt 是否已经注册
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
module_name = "src.memory_system.memory_retrieval"
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name]
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
if hasattr(existing_module, "init_memory_retrieval_prompt"):
return (
existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question,
existing_module._process_single_question,
)
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules:
existing_module = sys.modules[module_name]
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
# 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name]
# 清理可能相关的部分初始化模块
keys_to_remove = []
for key in sys.modules.keys():
if key.startswith('src.memory_system.') and key != 'src.memory_system':
if key.startswith("src.memory_system.") and key != "src.memory_system":
keys_to_remove.append(key)
for key in keys_to_remove:
try:
del sys.modules[key]
except KeyError:
pass
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
try:
# 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config
import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
@ -89,11 +91,11 @@ def _import_memory_retrieval():
pass # 如果导入失败,继续
except Exception as e:
logger.warning(f"预加载依赖模块时出现警告: {e}")
# 现在尝试导入 memory_retrieval
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
memory_retrieval_module = importlib.import_module(module_name)
return (
memory_retrieval_module.init_memory_retrieval_prompt,
memory_retrieval_module._react_agent_solve_question,
@ -126,16 +128,16 @@ def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStrea
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
"""获取从指定时间开始的token使用情况
Args:
start_time: 开始时间戳
Returns:
包含token使用统计的字典
"""
try:
start_datetime = datetime.fromtimestamp(start_time)
# 查询从开始时间到现在的所有memory相关的token使用记录
records = (
LLMUsage.select()
@ -150,21 +152,21 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
)
.order_by(LLMUsage.timestamp.asc())
)
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
total_cost = 0.0
request_count = 0
model_usage = {} # 按模型统计
for record in records:
total_prompt_tokens += record.prompt_tokens or 0
total_completion_tokens += record.completion_tokens or 0
total_tokens += record.total_tokens or 0
total_cost += record.cost or 0.0
request_count += 1
# 按模型统计
model_name = record.model_name or "unknown"
if model_name not in model_usage:
@ -180,7 +182,7 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
model_usage[model_name]["cost"] += record.cost or 0.0
model_usage[model_name]["request_count"] += 1
return {
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
@ -205,25 +207,25 @@ def format_thinking_steps(thinking_steps: list) -> str:
"""格式化思考步骤为可读字符串"""
if not thinking_steps:
return "无思考步骤"
lines = []
for step in thinking_steps:
iteration = step.get("iteration", "?")
thought = step.get("thought", "")
actions = step.get("actions", [])
observations = step.get("observations", [])
lines.append(f"\n--- 迭代 {iteration} ---")
if thought:
lines.append(f"思考: {thought[:200]}...")
if actions:
lines.append("行动:")
for action in actions:
action_type = action.get("action_type", "unknown")
action_params = action.get("action_params", {})
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
if observations:
lines.append("观察:")
for obs in observations:
@ -231,7 +233,7 @@ def format_thinking_steps(thinking_steps: list) -> str:
if len(str(obs)) > 200:
obs_str += "..."
lines.append(f" - {obs_str}")
return "\n".join(lines)
@ -242,31 +244,32 @@ async def test_memory_retrieval(
max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
"""测试记忆检索功能
Args:
question: 要查询的问题
chat_id: 聊天ID
context: 上下文信息
max_iterations: 最大迭代次数
Returns:
包含测试结果的字典
"""
print("\n" + "=" * 80)
print(f"[测试] 记忆检索测试")
print("[测试] 记忆检索测试")
print(f"[问题] {question}")
print("=" * 80)
# 记录开始时间
start_time = time.time()
# 延迟导入并初始化记忆检索prompt这会自动加载 global_config
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
try:
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
# 检查 prompt 是否已经初始化,避免重复初始化
from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt()
else:
@ -274,24 +277,24 @@ async def test_memory_retrieval(
except Exception as e:
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
raise
# 获取 global_config此时应该已经加载
from src.config.config import global_config
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
timeout = global_config.memory.agent_timeout_seconds
print(f"\n[配置]")
print("\n[配置]")
print(f" 最大迭代次数: {max_iterations}")
print(f" 超时时间: {timeout}")
print(f" 聊天ID: {chat_id}")
# 执行检索
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
@ -299,14 +302,14 @@ async def test_memory_retrieval(
timeout=timeout,
initial_info="",
)
# 记录结束时间
end_time = time.time()
elapsed_time = end_time - start_time
# 获取token使用情况
token_usage = get_token_usage_since(start_time)
# 构建结果
result = {
"question": question,
@ -318,41 +321,41 @@ async def test_memory_retrieval(
"iteration_count": len(thinking_steps),
"token_usage": token_usage,
}
# 输出结果
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
print(f"\n[结果]")
print("\n[结果]")
print(f" 是否找到答案: {'' if found_answer else ''}")
if found_answer and answer:
print(f" 答案: {answer}")
else:
print(f" 答案: (未找到答案)")
print(" 答案: (未找到答案)")
print(f" 是否超时: {'' if is_timeout else ''}")
print(f" 迭代次数: {len(thinking_steps)}")
print(f" 总耗时: {elapsed_time:.2f}")
print(f"\n[Token使用情况]")
print("\n[Token使用情况]")
print(f" 总请求数: {token_usage['request_count']}")
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
print(f" 总Tokens: {token_usage['total_tokens']:,}")
print(f" 总成本: ${token_usage['total_cost']:.6f}")
if token_usage['model_usage']:
print(f"\n[按模型统计]")
for model_name, usage in token_usage['model_usage'].items():
if token_usage["model_usage"]:
print("\n[按模型统计]")
for model_name, usage in token_usage["model_usage"].items():
print(f" {model_name}:")
print(f" 请求数: {usage['request_count']}")
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
print(f" Completion Tokens: {usage['completion_tokens']:,}")
print(f" 总Tokens: {usage['total_tokens']:,}")
print(f" 成本: ${usage['cost']:.6f}")
print(f"\n[迭代详情]")
print("\n[迭代详情]")
print(format_thinking_steps(thinking_steps))
print("\n" + "=" * 80)
return result
@ -375,12 +378,12 @@ def main() -> None:
"-o",
help="将结果保存到JSON文件可选",
)
args = parser.parse_args()
# 初始化日志(使用较低的详细程度,避免输出过多日志)
initialize_logging(verbose=False)
# 交互式输入问题
print("\n" + "=" * 80)
print("记忆检索测试工具")
@ -389,7 +392,7 @@ def main() -> None:
if not question:
print("错误: 问题不能为空")
return
# 交互式输入最大迭代次数
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
max_iterations = None
@ -402,7 +405,7 @@ def main() -> None:
except ValueError:
print("警告: 无效的迭代次数,将使用配置默认值")
max_iterations = None
# 连接数据库
try:
db.connect(reuse_if_open=True)
@ -410,7 +413,7 @@ def main() -> None:
logger.error(f"数据库连接失败: {e}")
print(f"错误: 数据库连接失败: {e}")
return
# 运行测试
try:
result = asyncio.run(
@ -421,7 +424,7 @@ def main() -> None:
max_iterations=max_iterations,
)
)
# 如果指定了输出文件,保存结果
if args.output:
# 将thinking_steps转换为可序列化的格式
@ -429,7 +432,7 @@ def main() -> None:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(output_result, f, ensure_ascii=False, indent=2)
print(f"\n[结果已保存] {args.output}")
except KeyboardInterrupt:
print("\n\n[中断] 用户中断测试")
except Exception as e:
@ -444,4 +447,3 @@ def main() -> None:
if __name__ == "__main__":
main()

View File

@ -455,6 +455,7 @@ class ExpressionSelector:
expr_obj.save()
logger.debug("表达方式激活: 更新last_active_time in db")
try:
expression_selector = ExpressionSelector()
except Exception as e:

View File

@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
logger = get_logger("jargon")
class JargonExplainer:
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""

View File

@ -60,31 +60,31 @@ def calculate_style_similarity(style1: str, style2: str) -> float:
"""
计算两个 style 的相似度返回0-1之间的值
在计算前会移除"使用""句式"这两个词参考 expression_similarity_analysis.py
Args:
style1: 第一个 style
style2: 第二个 style
Returns:
float: 相似度值范围0-1
"""
if not style1 or not style2:
return 0.0
# 移除"使用"和"句式"这两个词
def remove_ignored_words(text: str) -> str:
"""移除需要忽略的词"""
text = text.replace("使用", "")
text = text.replace("句式", "")
return text.strip()
cleaned_style1 = remove_ignored_words(style1)
cleaned_style2 = remove_ignored_words(style2)
# 如果清理后文本为空返回0
if not cleaned_style1 or not cleaned_style2:
return 0.0
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
@ -495,4 +495,4 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
if content and source_id:
jargon_entries.append((content, source_id))
return expressions, jargon_entries
return expressions, jargon_entries

View File

@ -2,7 +2,6 @@ import time
import asyncio
from typing import List, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression
@ -119,9 +118,7 @@ class MessageRecorder:
# 触发 expression_learner 和 jargon_miner 的处理
if self.enable_expression_learning:
asyncio.create_task(
self._trigger_expression_learning(messages)
)
asyncio.create_task(self._trigger_expression_learning(messages))
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
@ -130,9 +127,7 @@ class MessageRecorder:
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self, messages: List[Any]
) -> None:
async def _trigger_expression_learning(self, messages: List[Any]) -> None:
"""
触发 expression 学习使用指定的消息列表

View File

@ -1,5 +1,5 @@
import time
from typing import Tuple, Optional, Dict, Any # 增加了 Optional
from typing import Tuple, Optional # 增加了 Optional
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@ -120,7 +120,7 @@ class ActionPlanner:
def _get_personality_prompt(self) -> str:
"""获取个性提示信息"""
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (
global_config.personality.states
@ -128,7 +128,7 @@ class ActionPlanner:
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};"
@ -170,13 +170,10 @@ class ActionPlanner:
)
break
else:
logger.debug(
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
)
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
except Exception as e:
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
# --- 获取超时提示信息 ---
# (这部分逻辑不变)
timeout_context = ""

View File

@ -112,10 +112,10 @@ class Conversation:
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
"platform": msg.user_info.platform if msg.user_info else "",
}
},
}
initial_messages_dict.append(msg_dict)
# 将加载的消息填充到 ObservationInfo 的 chat_history
self.observation_info.chat_history = initial_messages_dict
self.observation_info.chat_history_str = chat_talking_prompt + "\n"

View File

@ -66,9 +66,9 @@ class DirectMessageSender:
# 发送消息(直接调用底层 API
from src.chat.message_receive.uni_message_sender import _send_message
sent = await _send_message(message, show_log=True)
if sent:
# 存储消息
await self.storage.store_message(message, chat_stream)

View File

@ -5,7 +5,7 @@ from src.common.logger import get_logger
from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType, Notification
from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
from src.common.data_models.database_data_model import DatabaseMessages
import traceback # 导入 traceback 用于调试
logger = get_logger("observation_info")
@ -13,15 +13,15 @@ logger = get_logger("observation_info")
def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages:
"""Convert PFC dict format to DatabaseMessages object
Args:
msg_dict: Message in PFC dict format with nested user_info
Returns:
DatabaseMessages object compatible with build_readable_messages()
"""
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
return DatabaseMessages(
message_id=msg_dict.get("message_id", ""),
time=msg_dict.get("time", 0.0),

View File

@ -42,9 +42,7 @@ class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
)
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname
@ -60,7 +58,7 @@ class GoalAnalyzer:
def _get_personality_prompt(self) -> str:
"""获取个性提示信息"""
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (
global_config.personality.states
@ -68,7 +66,7 @@ class GoalAnalyzer:
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};"

View File

@ -1,13 +1,11 @@
from typing import List, Tuple, Dict, Any
from src.common.logger import get_logger
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
# from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.message_receive.message import Message
from src.config.config import model_config
from src.chat.knowledge import qa_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
logger = get_logger("knowledge_fetcher")
@ -16,9 +14,7 @@ class KnowledgeFetcher:
"""知识调取器"""
def __init__(self, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils
)
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
self.private_name = private_name
def _lpmm_get_knowledge(self, query: str) -> str:
@ -50,13 +46,7 @@ class KnowledgeFetcher:
Returns:
Tuple[str, str]: (获取的知识, 知识来源)
"""
db_messages = [dict_to_database_message(m) for m in chat_history]
chat_history_text = build_readable_messages(
db_messages,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
)
_ = chat_history
# NOTE: Hippocampus memory system was redesigned in v0.12.2
# The old get_memory_from_text API no longer exists
@ -64,7 +54,7 @@ class KnowledgeFetcher:
# TODO: Integrate with new memory system if needed
knowledge_text = ""
sources_text = "无记忆匹配" # 默认值
# # 从记忆中获取相关知识 (DISABLED - old Hippocampus API)
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
# text=f"{query}\n{chat_history_text}",

View File

@ -14,10 +14,7 @@ class ReplyChecker:
"""回复检查器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="reply_check"
)
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname
self.private_name = private_name
@ -27,7 +24,7 @@ class ReplyChecker:
def _get_personality_prompt(self) -> str:
"""获取个性提示信息"""
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (
global_config.personality.states
@ -35,7 +32,7 @@ class ReplyChecker:
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};"

View File

@ -99,7 +99,7 @@ class ReplyGenerator:
def _get_personality_prompt(self) -> str:
"""获取个性提示信息"""
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (
global_config.personality.states
@ -107,7 +107,7 @@ class ReplyGenerator:
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};"

View File

@ -704,10 +704,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断
try:
await asyncio.wait_for(
self._new_message_event.wait(),
timeout=wait_seconds
)
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
# 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError:
@ -731,7 +728,9 @@ class BrainChatting:
# 使用默认等待时间
wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)")
logger.info(
f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds} 秒(可被新消息打断)"
)
# 清除事件状态,准备等待新消息
self._new_message_event.clear()
@ -749,10 +748,7 @@ class BrainChatting:
# 等待指定时间,但可被新消息打断
try:
await asyncio.wait_for(
self._new_message_event.wait(),
timeout=wait_seconds
)
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
# 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError:

View File

@ -431,15 +431,21 @@ class BrainPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
return extracted_reasoning, [
ActionPlannerInfo(
action_type="complete_talk",
reasoning=extracted_reasoning,
action_data={},
action_message=None,
available_actions=available_actions,
)
], llm_content, llm_reasoning, llm_duration_ms
return (
extracted_reasoning,
[
ActionPlannerInfo(
action_type="complete_talk",
reasoning=extracted_reasoning,
action_data={},
action_message=None,
available_actions=available_actions,
)
],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应
if llm_content:

View File

@ -105,7 +105,7 @@ class EmbeddingStore:
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
self.dirty = False # 标记是否有新增数据需要重建索引
# 多线程配置参数验证和设置

View File

@ -1,7 +1,7 @@
import asyncio
import json
import time
from typing import List, Union, Dict, Any
from typing import List, Union
from .global_logger import logger
from . import prompt_template
@ -192,17 +192,15 @@ class IEProcess:
results = []
total = len(paragraphs)
for i, pg in enumerate(paragraphs, start=1):
# 打印进度日志,让用户知道没有卡死
logger.info(f"[IEProcess] 正在处理第 {i}/{total} 段文本 (长度: {len(pg)})...")
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
try:
entities, triples = await asyncio.to_thread(
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
)
entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
if entities is not None:
results.append(

View File

@ -395,8 +395,7 @@ class KGManager:
appear_cnt = self.ent_appear_cnt.get(ent_hash)
if not appear_cnt or appear_cnt <= 0:
logger.debug(
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0"
f"将使用 1.0 作为默认出现次数参与权重计算"
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0将使用 1.0 作为默认出现次数参与权重计算"
)
appear_cnt = 1.0
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)

View File

@ -11,31 +11,30 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
logger = get_logger("LPMM-Plugin-API")
class LPMMOperations:
"""
LPMM 内部操作接口
封装了 LPMM 的核心操作供插件系统 API 或其他内部组件调用
"""
def __init__(self):
self._initialized = False
async def _run_cancellable_executor(
self, func: Callable, *args, **kwargs
) -> Any:
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
"""
在线程池中执行可取消的同步操作
当任务被取消时 Ctrl+C会立即响应并抛出 CancelledError
注意线程池中的操作可能仍在运行但协程会立即返回不会阻塞主进程
Args:
func: 要执行的同步函数
*args: 函数的位置参数
**kwargs: 函数的关键字参数
Returns:
函数的返回值
Raises:
asyncio.CancelledError: 当任务被取消时
"""
@ -51,42 +50,42 @@ class LPMMOperations:
# 如果全局没初始化,尝试初始化
if not global_config.lpmm_knowledge.enable:
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
lpmm_start_up()
qa_mgr = get_qa_manager()
if qa_mgr is None:
raise RuntimeError("无法获取 LPMM QAManager请检查 LPMM 是否已正确安装和配置。")
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
async def add_content(self, text: str, auto_split: bool = True) -> dict:
"""
向知识库添加新内容
Args:
text: 原始文本
auto_split: 是否自动按双换行符分割段落
- True: 自动分割默认支持多段文本用双换行分隔
- False: 不分割将整个文本作为完整一段处理
Returns:
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 分段处理
if auto_split:
# 自动按双换行符分割
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
else:
# 不分割,作为完整一段
text_stripped = text.strip()
if not text_stripped:
return {"status": "error", "message": "文本内容为空"}
paragraphs = [text_stripped]
if not paragraphs:
return {"status": "error", "message": "文本内容为空"}
@ -94,14 +93,16 @@ class LPMMOperations:
from src.chat.knowledge.ie_process import IEProcess
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
llm_ner = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
ie_process = IEProcess(llm_ner, llm_rdf)
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
extracted_docs = await ie_process.process_paragraphs(paragraphs)
# 3. 构造并导入数据
# 这里我们手动实现导入逻辑,不依赖外部脚本
# a. 准备段落
@ -115,7 +116,7 @@ class LPMMOperations:
# store_new_data_set 期望的格式raw_paragraphs 的键是段落hash不带前缀值是段落文本
new_raw_paragraphs = {}
new_triple_list_data = {}
for pg_hash, passage in raw_paragraphs.items():
key = f"paragraph-{pg_hash}"
if key not in embed_mgr.stored_pg_hashes:
@ -128,26 +129,22 @@ class LPMMOperations:
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
# store_new_data_set 会自动处理嵌入生成和存储
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
await self._run_cancellable_executor(
embed_mgr.store_new_data_set,
new_raw_paragraphs,
new_triple_list_data
)
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
# 3. 构建知识图谱只需要三元组数据和embedding_manager
await self._run_cancellable_executor(
kg_mgr.build_kg,
new_triple_list_data,
embed_mgr
)
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
# 4. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
return {
"status": "success",
"count": len(new_raw_paragraphs),
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 导入操作被用户中断")
return {"status": "cancelled", "message": "导入操作已被用户中断"}
@ -158,11 +155,11 @@ class LPMMOperations:
async def search(self, query: str, top_k: int = 3) -> List[str]:
"""
检索知识库
Args:
query: 查询问题
top_k: 返回最相关的条目数
Returns:
List[str]: 相关文段列表
"""
@ -179,21 +176,21 @@ class LPMMOperations:
async def delete(self, keyword: str, exact_match: bool = False) -> dict:
"""
根据关键词或完整文段删除知识库内容
Args:
keyword: 匹配关键词或完整文段
exact_match: 是否使用完整文段匹配True=完全匹配False=关键词模糊匹配
Returns:
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 查找匹配的段落
to_delete_keys = []
to_delete_hashes = []
for key, item in embed_mgr.paragraphs_embedding_store.store.items():
if exact_match:
# 完整文段匹配
@ -205,29 +202,25 @@ class LPMMOperations:
if keyword in item.str:
to_delete_keys.append(key)
to_delete_hashes.append(key.replace("paragraph-", "", 1))
if not to_delete_keys:
match_type = "完整文段" if exact_match else "关键词"
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
# 2. 执行删除
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# a. 从向量库删除
deleted_count, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items,
to_delete_keys
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
)
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
# b. 从知识图谱删除
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs,
to_delete_hashes,
ent_hashes=None,
remove_orphan_entities=True
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
@ -235,9 +228,13 @@ class LPMMOperations:
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
match_type = "完整文段" if exact_match else "关键词"
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
return {
"status": "success",
"deleted_count": deleted_count,
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 删除操作被用户中断")
@ -249,13 +246,13 @@ class LPMMOperations:
async def clear_all(self) -> dict:
"""
清空整个LPMM知识库删除所有段落实体关系和知识图谱数据
Returns:
dict: {"status": "success/error", "message": "描述", "stats": {...}}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 记录清空前的统计信息
before_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
@ -264,40 +261,37 @@ class LPMMOperations:
"kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()),
}
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# 1. 清空所有向量库
# 获取所有keys
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
# 删除所有段落向量
para_deleted, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items,
para_keys
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
)
embed_mgr.stored_pg_hashes.clear()
# 删除所有实体向量
if ent_keys:
ent_deleted, _ = await self._run_cancellable_executor(
embed_mgr.entities_embedding_store.delete_items,
ent_keys
embed_mgr.entities_embedding_store.delete_items, ent_keys
)
else:
ent_deleted = 0
# 删除所有关系向量
if rel_keys:
rel_deleted, _ = await self._run_cancellable_executor(
embed_mgr.relation_embedding_store.delete_items,
rel_keys
embed_mgr.relation_embedding_store.delete_items, rel_keys
)
else:
rel_deleted = 0
# 2. 清空所有 embedding store 的索引和映射
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
def _clear_embedding_indices():
@ -310,7 +304,7 @@ class LPMMOperations:
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
# 清空实体索引
embed_mgr.entities_embedding_store.faiss_index = None
embed_mgr.entities_embedding_store.idx2hash = None
@ -320,7 +314,7 @@ class LPMMOperations:
os.remove(embed_mgr.entities_embedding_store.index_file_path)
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
# 清空关系索引
embed_mgr.relation_embedding_store.faiss_index = None
embed_mgr.relation_embedding_store.idx2hash = None
@ -330,9 +324,9 @@ class LPMMOperations:
os.remove(embed_mgr.relation_embedding_store.index_file_path)
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
await self._run_cancellable_executor(_clear_embedding_indices)
# 3. 清空知识图谱
# 获取所有段落hash
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
@ -341,24 +335,22 @@ class LPMMOperations:
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs,
all_pg_hashes,
ent_hashes=None,
remove_orphan_entities=True
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
# 完全清空KG创建新的空图无论是否有段落hash都要执行
from quick_algo import di_graph
kg_mgr.graph = di_graph.DiGraph()
kg_mgr.stored_paragraph_hashes.clear()
kg_mgr.ent_appear_cnt.clear()
# 4. 保存所有数据此时所有store都是空的索引也是None
# 注意即使store为空save_to_file也会保存空的DataFrame这是正确的
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
after_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
"entities": len(embed_mgr.entities_embedding_store.store),
@ -366,14 +358,14 @@ class LPMMOperations:
"kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()),
}
return {
"status": "success",
"message": f"已成功清空LPMM知识库删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
"stats": {
"before": before_stats,
"after": after_stats,
}
},
}
except asyncio.CancelledError:
@ -383,6 +375,6 @@ class LPMMOperations:
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
# 内部使用的单例
lpmm_ops = LPMMOperations()

View File

@ -136,4 +136,3 @@ class PlanReplyLogger:
return str(value)
# Fallback to string for other complex types
return str(value)

View File

@ -85,17 +85,17 @@ class ChatBot:
async def _create_pfc_chat(self, message: MessageRecv):
"""创建或获取PFC对话实例
Args:
message: 消息对象
"""
try:
chat_id = str(message.chat_stream.stream_id)
private_name = str(message.message_info.user_info.user_nickname)
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
except Exception as e:
logger.error(f"创建PFC聊天失败: {e}")
logger.error(traceback.format_exc())

View File

@ -96,7 +96,7 @@ class Message(MessageBase):
if processed_text:
return f"{global_config.bot.nickname}: {processed_text}"
return None
tasks = [process_forward_node(node_dict) for node_dict in segment.data]
results = await asyncio.gather(*tasks, return_exceptions=True)
segments_text = []

View File

@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 如果未开启 API Server直接跳过 Fallback
if not global_config.maim_message.enable_api_server:
logger.debug(f"[API Server Fallback] API Server未开启跳过fallback")
logger.debug("[API Server Fallback] API Server未开启跳过fallback")
if legacy_exception:
raise legacy_exception
return False
@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
extra_server = getattr(global_api, "extra_server", None)
if not extra_server:
logger.warning(f"[API Server Fallback] extra_server不存在")
logger.warning("[API Server Fallback] extra_server不存在")
if legacy_exception:
raise legacy_exception
return False
if not extra_server.is_running():
logger.warning(f"[API Server Fallback] extra_server未运行")
logger.warning("[API Server Fallback] extra_server未运行")
if legacy_exception:
raise legacy_exception
return False
@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
)
# 直接调用 Server 的 send_message 接口,它会自动处理路由
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...")
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
results = await extra_server.send_message(api_message)
logger.debug(f"[API Server Fallback] 发送结果: {results}")

View File

@ -35,6 +35,7 @@ logger = get_logger("planner")
install(extra_lines=3)
class ActionPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id
@ -48,7 +49,7 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
# 黑话缓存:使用 OrderedDict 实现 LRU最多缓存10个
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
self.unknown_words_cache_limit = 10
@ -111,20 +112,29 @@ class ActionPlanner:
# 替换 [picid:xxx] 为 [图片:描述]
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(pic_match: re.Match) -> str:
pic_id = pic_match.group(1)
description = translate_pid_to_description(pic_id)
return f"[图片:{description}]"
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
platform = (
getattr(message, "user_info", None)
and message.user_info.platform
or getattr(message, "chat_info", None)
and message.chat_info.platform
or "qq"
)
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
# 替换单独的 <用户名:用户ID> 格式replace_user_references 已处理回复<和@<格式)
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
# 这里匹配到的应该都是单独的格式
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
def replace_user_ref(user_match: re.Match) -> str:
user_name = user_match.group(1)
user_id = user_match.group(2)
@ -137,6 +147,7 @@ class ActionPlanner:
except Exception:
# 如果解析失败,使用原始昵称
return user_name
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
@ -165,7 +176,7 @@ class ActionPlanner:
else:
reasoning = "未提供原因"
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
# 非no_reply动作需要target_message_id
target_message = None
@ -244,7 +255,7 @@ class ActionPlanner:
def _update_unknown_words_cache(self, new_words: List[str]) -> None:
"""
更新黑话缓存将新的黑话加入缓存
Args:
new_words: 新提取的黑话列表
"""
@ -254,7 +265,7 @@ class ActionPlanner:
word = word.strip()
if not word:
continue
# 如果已存在移到末尾LRU
if word in self.unknown_words_cache:
self.unknown_words_cache.move_to_end(word)
@ -269,10 +280,10 @@ class ActionPlanner:
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
"""
合并新提取的黑话和缓存中的黑话
Args:
new_words: 新提取的黑话列表可能为None
Returns:
合并后的黑话列表去重
"""
@ -284,31 +295,29 @@ class ActionPlanner:
word = word.strip()
if word:
cleaned_new_words.append(word)
# 获取缓存中的黑话列表
cached_words = list(self.unknown_words_cache.keys())
# 合并并去重(保留顺序:新提取的在前,缓存的在后)
merged_words: List[str] = []
seen = set()
# 先添加新提取的
for word in cleaned_new_words:
if word not in seen:
merged_words.append(word)
seen.add(word)
# 再添加缓存的(如果不在新提取的列表中)
for word in cached_words:
if word not in seen:
merged_words.append(word)
seen.add(word)
return merged_words
def _process_unknown_words_cache(
self, actions: List[ActionPlannerInfo]
) -> None:
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
"""
处理黑话缓存逻辑
1. 检查是否有 reply action 提取了 unknown_words
@ -316,7 +325,7 @@ class ActionPlanner:
3. 如果缓存数量大于5移除最老的2个
4. 对于每个 reply action合并缓存和新提取的黑话
5. 更新缓存
Args:
actions: 解析后的动作列表
"""
@ -330,7 +339,7 @@ class ActionPlanner:
removed_count += 1
if removed_count > 0:
logger.debug(f"{self.log_prefix}缓存数量大于5移除最老的{removed_count}个缓存")
# 检查是否有 reply action 提取了 unknown_words
has_extracted_unknown_words = False
for action in actions:
@ -340,22 +349,22 @@ class ActionPlanner:
if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0:
has_extracted_unknown_words = True
break
# 如果当前 plan 的 reply 没有提取移除最老的1个
if not has_extracted_unknown_words:
if len(self.unknown_words_cache) > 0:
self.unknown_words_cache.popitem(last=False)
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")
# 对于每个 reply action合并缓存和新提取的黑话
for action in actions:
if action.action_type == "reply":
action_data = action.action_data or {}
new_words = action_data.get("unknown_words")
# 合并新提取的和缓存的黑话列表
merged_words = self._merge_unknown_words_with_cache(new_words)
# 更新 action_data
if merged_words:
action_data["unknown_words"] = merged_words
@ -366,7 +375,7 @@ class ActionPlanner:
else:
# 如果没有合并后的黑话,移除 unknown_words 字段
action_data.pop("unknown_words", None)
# 更新缓存(将新提取的黑话加入缓存)
if new_words:
self._update_unknown_words_cache(new_words)
@ -442,15 +451,19 @@ class ActionPlanner:
# 检查是否已经有回复该消息的 action
has_reply_to_force_message = False
for action in actions:
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
if (
action.action_type == "reply"
and action.action_message
and action.action_message.message_id == force_reply_message.message_id
):
has_reply_to_force_message = True
break
# 如果没有回复该消息,强制添加回复 action
if not has_reply_to_force_message:
# 移除所有 no_reply action如果有
actions = [a for a in actions if a.action_type != "no_reply"]
# 创建强制回复 action
available_actions_dict = dict(current_available_actions)
force_reply_action = ActionPlannerInfo(
@ -577,10 +590,11 @@ class ActionPlanner:
if global_config.chat.think_mode == "classic":
reply_action_example = ""
if global_config.chat.llm_quote:
reply_action_example += "5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
"5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
)
reply_action_example += (
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"]'
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
)
if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"'
@ -590,7 +604,9 @@ class ActionPlanner:
"5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n"
)
if global_config.chat.llm_quote:
reply_action_example += "6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
"6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
)
reply_action_example += (
'{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", '
@ -741,15 +757,21 @@ class ActionPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return f"LLM 请求失败,模型出现问题: {req_e}", [
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
action_data={},
action_message=None,
available_actions=available_actions,
)
], llm_content, llm_reasoning, llm_duration_ms
return (
f"LLM 请求失败,模型出现问题: {req_e}",
[
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
action_data={},
action_message=None,
available_actions=available_actions,
)
],
llm_content,
llm_reasoning,
llm_duration_ms,
)
# 解析LLM响应
extracted_reasoning = ""

View File

@ -1071,7 +1071,6 @@ class DefaultReplyer:
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style
multi_styles = global_config.personality.multiple_reply_style

View File

@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
)
from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.plugin_system.base.component_types import ActionInfo, EventType
@ -807,7 +808,7 @@ class PrivateReplyer:
reply_style = global_config.personality.reply_style
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(platform, user_id):
prompt_template = prompt_manager.get_prompt("private_replyer_self")
prompt_template.add_context("target", target)

View File

@ -519,7 +519,7 @@ def _build_readable_messages_internal(
output_lines: List[str] = []
prev_timestamp: Optional[float] = None
for timestamp, name, content, is_action in detailed_message:
for timestamp, name, content, _is_action in detailed_message:
# 检查是否需要插入长时间间隔提示
if long_time_notice and prev_timestamp is not None:
time_diff = timestamp - prev_timestamp

View File

@ -5,6 +5,7 @@ from src.common.logger import get_logger
logger = get_logger("common_utils")
class TempMethodsExpression:
"""用于临时存放一些方法的类"""

View File

@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
from . import BaseDatabaseDataModel
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
self.session_id = session_id
@ -33,4 +34,4 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
platform=self.platform,
user_id=self.user_id,
group_id=self.group_id,
)
)

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple, Union

View File

@ -221,5 +221,7 @@ if not supports_truecolor():
CONVERTED_MODULE_COLORS[name] = escape_str
else:
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold)
CONVERTED_MODULE_COLORS[name] = escape_str
escape_str = rgb_pair_to_ansi_truecolor(
hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
)
CONVERTED_MODULE_COLORS[name] = escape_str

View File

@ -9,6 +9,7 @@ from .server import get_global_server
global_api = None
def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例"""
global global_api
@ -80,12 +81,12 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
return False
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
# 3. Setup Message Bridge
# Initialize refined route map if not exists
if not hasattr(global_api, "platform_map"):
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
@ -108,7 +109,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
if platform:
global_api.platform_map[platform] = api_key # type: ignore
global_api.platform_map[platform] = api_key # type: ignore
api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
except Exception as e:
api_logger.warning(f"Failed to update platform map: {e}")
@ -117,21 +118,21 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
if "raw_message" not in msg_dict:
msg_dict["raw_message"] = None
await global_api.process_message(msg_dict) # type: ignore
await global_api.process_message(msg_dict) # type: ignore
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
# 3.5. Register custom message handlers (bridge to Legacy handlers)
# message_id_echo: handles message ID echo from adapters
# 兼容新旧两个版本的 maim_message:
# - 旧版: handler(payload)
# - 新版: handler(payload, metadata)
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
# Bridge to the Legacy custom handler registered in main.py
try:
# The Legacy handler expects the payload format directly
if hasattr(global_api, "_custom_message_handlers"):
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
if handler:
await handler(payload)
api_logger.debug(f"Processed message_id_echo: {payload}")
@ -140,7 +141,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
except Exception as e:
api_logger.warning(f"Failed to process message_id_echo: {e}")
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
# 4. Initialize Server
extra_server = WebSocketServer(config=server_config)
@ -167,7 +168,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
global_api.stop = patched_stop
# Attach for reference
global_api.extra_server = extra_server # type: ignore # 这是什么
global_api.extra_server = extra_server # type: ignore # 这是什么
except ImportError:
get_logger("maim_message").error(

View File

@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
logger = get_logger("file_utils")
class FileUtils:
@staticmethod
def save_binary_to_file(file_path: Path, data: bytes):
@ -35,7 +36,7 @@ class FileUtils:
except Exception as e:
logger.error(f"保存文件 {file_path} 失败: {e}")
raise e
@staticmethod
def get_file_path_by_hash(data_hash: str) -> Path:
"""
@ -52,4 +53,4 @@ class FileUtils:
if binary_data := session.exec(statement).first():
return Path(binary_data.full_path)
else:
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")

View File

@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
reason = ",".join(reasons)
return MigrationResult(data=data, migrated=migrated_any, reason=reason)

View File

@ -86,8 +86,8 @@ def init_dream_tools(chat_id: str) -> None:
finish_maintenance = make_finish_maintenance(chat_id)
search_jargon = make_search_jargon(chat_id)
delete_jargon = make_delete_jargon(chat_id)
update_jargon = make_update_jargon(chat_id)
_delete_jargon = make_delete_jargon(chat_id)
_update_jargon = make_update_jargon(chat_id)
_dream_tool_registry.register_tool(
DreamTool(

View File

@ -54,8 +54,6 @@ async def generate_dream_summary(
) -> None:
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
try:
# 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {}
for msg in conversation_messages:

View File

@ -4,4 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数 dream_agent.init_dream_tools 统一注册
"""

View File

@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}"
return create_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}"
return delete_chat_history

View File

@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}"
return delete_jargon

View File

@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg
return finish_maintenance

View File

@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail

View File

@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}"
return search_chat_history

View File

@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}"
return update_chat_history

View File

@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}"
return update_jargon

View File

@ -458,8 +458,8 @@ def _default_normal_response_parser(
if not isinstance(arguments, dict):
# 此时为了调试方便,建议打印出 arguments 的类型
raise RespParseException(
resp,
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
resp,
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
)
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
except json.JSONDecodeError as e:

View File

@ -2,7 +2,7 @@ import time
import json
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
from typing import List, Dict, Any, Optional, Tuple, Callable
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager
@ -34,7 +34,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
try:
with get_db_session() as session:
statement = select(ThinkingQuestion).where(
(ThinkingQuestion.found_answer == False)
col(ThinkingQuestion.found_answer).is_(False)
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
)
records = session.exec(statement).all()
@ -786,8 +786,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
str: 格式化的查询历史字符串
"""
try:
current_time = time.time()
start_time = current_time - time_window_seconds
_current_time = time.time()
with get_db_session() as session:
statement = (
@ -838,15 +837,14 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
List[str]: 格式化的答案列表每个元素格式为 "问题xxx\n答案xxx"
"""
try:
current_time = time.time()
start_time = current_time - time_window_seconds
_current_time = time.time()
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
with get_db_session() as session:
statement = (
select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id)
.where(col(ThinkingQuestion.found_answer) == True)
.where(col(ThinkingQuestion.found_answer))
.where(col(ThinkingQuestion.answer).is_not(None))
.where(col(ThinkingQuestion.answer) != "")
.order_by(col(ThinkingQuestion.updated_timestamp).desc())

View File

@ -105,25 +105,27 @@ async def search_chat_history(
# 检查参数
if not keyword and not participant and not start_time and not end_time:
return "未指定查询参数需要提供keyword、participant、start_time或end_time之一"
# 解析时间参数
start_timestamp = None
end_timestamp = None
if start_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
start_timestamp = parse_datetime_to_timestamp(start_time)
except ValueError as e:
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
if end_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
end_timestamp = parse_datetime_to_timestamp(end_time)
except ValueError as e:
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
# 验证时间范围
if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
return "开始时间不能晚于结束时间"
@ -158,23 +160,20 @@ async def search_chat_history(
f"search_chat_history 当前聊天流在黑名单中强制使用本地查询chat_id={chat_id}, keyword={keyword}, participant={participant}"
)
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 添加时间过滤条件
if start_timestamp is not None and end_timestamp is not None:
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
query = query.where(
(
(ChatHistory.start_time >= start_timestamp)
& (ChatHistory.start_time <= end_timestamp)
(ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
) # 记录开始时间在查询时间段内
| (
(ChatHistory.end_time >= start_timestamp)
& (ChatHistory.end_time <= end_timestamp)
(ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
) # 记录结束时间在查询时间段内
| (
(ChatHistory.start_time <= start_timestamp)
& (ChatHistory.end_time >= end_timestamp)
(ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
) # 记录完全包含查询时间段
)
logger.debug(
@ -302,7 +301,7 @@ async def search_chat_history(
time_desc = f"时间<='{end_str}'"
if time_desc:
conditions.append(time_desc)
if conditions:
conditions_str = "".join(conditions)
return f"未找到满足条件({conditions_str})的聊天记录"

View File

@ -30,7 +30,7 @@ async def query_words(chat_id: str, words: str) -> str:
if separator in words:
words_list = [w.strip() for w in words.split(separator) if w.strip()]
break
# 如果没有找到分隔符,整个字符串作为一个词语
if not words_list:
words_list = [words.strip()]
@ -76,4 +76,3 @@ def register_tool():
],
execute_func=query_words,
)

View File

@ -123,7 +123,7 @@ async def generate_reply(
# 如果 reply_time_point 未传入,设置为当前时间戳
if reply_time_point is None:
reply_time_point = time.time()
# 获取回复器
logger.debug("[GeneratorAPI] 开始生成回复")
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)

View File

@ -39,6 +39,18 @@ def unregister_service(service_name: str, plugin_name: Optional[str] = None) ->
return plugin_service_registry.unregister_service(service_name, plugin_name)
async def call_service(service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
async def call_service(
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务。"""
return await plugin_service_registry.call_service(service_name, *args, plugin_name=plugin_name, **kwargs)
return await plugin_service_registry.call_service(
service_name,
*args,
plugin_name=plugin_name,
caller_plugin=caller_plugin,
**kwargs,
)

View File

@ -558,7 +558,9 @@ class PluginBase(ABC):
if version_spec:
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
if not is_ok:
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
)
return False
if min_version or max_version:

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict
from typing import Any, Dict, List
@dataclass
@ -11,6 +11,8 @@ class PluginServiceInfo:
version: str = "1.0.0"
description: str = ""
enabled: bool = True
public: bool = False
allowed_callers: List[str] = field(default_factory=list)
params_schema: Dict[str, Any] = field(default_factory=dict)
return_schema: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)

View File

@ -274,6 +274,23 @@ class ComponentRegistry:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False
async def remove_components_by_plugin(self, plugin_name: str) -> int:
    """Remove every component registered by the given plugin.

    Args:
        plugin_name: Name of the plugin whose components should be removed.

    Returns:
        The number of components for which ``remove_component`` reported success.
    """
    # Snapshot the (name, type) pairs up front so the removal loop below never
    # iterates ``self._components`` directly (remove_component presumably
    # mutates that mapping -- verify against its implementation).
    targets = [
        (component_info.name, component_info.component_type)
        for component_info in self._components.values()
        if component_info.plugin_name == plugin_name
    ]

    removed_count = 0
    for component_name, component_type in targets:
        # A falsy result is simply not counted; the remaining targets are still attempted.
        if await self.remove_component(component_name, component_type, plugin_name):
            removed_count += 1

    # Only log when something was actually removed.
    if removed_count:
        logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}")
    return removed_count
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
@ -734,9 +751,7 @@ class ComponentRegistry:
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
"workflow_steps": workflow_step_count,
"enabled_workflow_steps": enabled_workflow_step_count,
"workflow_steps_by_stage": {
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
},
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
}

View File

@ -200,13 +200,43 @@ class PluginManager:
"""
重载插件模块
"""
old_instance = self.loaded_plugins.get(plugin_name)
if not old_instance:
logger.warning(f"插件 {plugin_name} 未加载,无法重载")
return False
if not await self.remove_registered_plugin(plugin_name):
return False
if not self.load_registered_plugin_classes(plugin_name)[0]:
logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例")
rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance)
if rollback_ok:
logger.info(f"插件 {plugin_name} 已回滚到旧版本实例")
else:
logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用")
return False
logger.debug(f"插件 {plugin_name} 重载成功")
return True
async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool:
    """Re-register the previous plugin instance after a failed reload.

    Args:
        plugin_name: Name of the plugin whose reload failed.
        old_instance: The plugin instance that was loaded before the reload attempt.

    Returns:
        True when the old instance was registered again and stored back in
        ``self.loaded_plugins``; False when re-registration failed or any step raised.
    """
    try:
        # Clear whatever is currently registered under this plugin name
        # (components, plugin registry entry, services) before re-registering.
        await component_registry.remove_components_by_plugin(plugin_name)
        component_registry.remove_plugin_registry(plugin_name)
        plugin_service_registry.remove_services_by_plugin(plugin_name)

        if not old_instance.register_plugin():
            logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败")
            return False

        self.loaded_plugins[plugin_name] = old_instance
        return True
    except Exception as e:
        # Rollback is best-effort: report the failure to the caller instead of raising.
        logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True)
        return False
def rescan_plugin_directory(self) -> Tuple[int, int]:
"""
重新扫描插件根目录
@ -399,7 +429,9 @@ class PluginManager:
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
"""根据依赖图计算加载顺序,并检测循环依赖。"""
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
indegree: Dict[str, int] = {
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
for plugin_name, dependencies in dependency_graph.items():

View File

@ -26,6 +26,9 @@ class PluginServiceRegistry:
if "." in service_info.plugin_name:
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]:
logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}")
return False
full_name = service_info.full_name
if full_name in self._services:
@ -52,7 +55,9 @@ class PluginServiceRegistry:
full_name = self._resolve_full_name(service_name, plugin_name)
return self._service_handlers.get(full_name) if full_name else None
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
def list_services(
self, plugin_name: Optional[str] = None, enabled_only: bool = False
) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
services = self._services.copy()
if plugin_name:
@ -103,12 +108,33 @@ class PluginServiceRegistry:
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
return removed_count
async def call_service(self, service_name: str, *args: Any, plugin_name: Optional[str] = None, **kwargs: Any) -> Any:
async def call_service(
self,
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务(支持同步/异步handler"""
service_info = self.get_service(service_name, plugin_name)
if not service_info:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}")
if (
"." not in service_name
and plugin_name is None
and caller_plugin
and service_info.plugin_name != caller_plugin
):
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
if not self._is_call_authorized(service_info, caller_plugin):
raise PermissionError(
f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}"
)
if not service_info.enabled:
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
@ -116,8 +142,93 @@ class PluginServiceRegistry:
if not handler:
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
self._validate_input_contract(service_info, args, kwargs)
result = handler(*args, **kwargs)
return await result if inspect.isawaitable(result) else result
resolved_result = await result if inspect.isawaitable(result) else result
self._validate_output_contract(service_info, resolved_result)
return resolved_result
def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool:
    """Decide whether ``caller_plugin`` may invoke the given service.

    Rules are applied in this order:
        1. No caller identity (``None``): allowed only if the service is public.
        2. The caller is the plugin that owns the service: allowed.
        3. The service is public: allowed.
        4. Otherwise: allowed only if ``allowed_callers`` contains the caller
           name or the wildcard ``"*"``.

    Args:
        service_info: Metadata of the service being called.
        caller_plugin: Name of the calling plugin, or None when unknown.

    Returns:
        True if the call is permitted, False otherwise.
    """
    if caller_plugin is None:
        return service_info.public
    if caller_plugin == service_info.plugin_name:
        return True
    if service_info.public:
        return True

    # Allow-list entries are compared after stripping whitespace; blank entries are ignored.
    allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
    return "*" in allowed_callers or caller_plugin in allowed_callers
def _validate_input_contract(
    self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
    """Validate the call arguments against the service's ``params_schema``.

    Two schema shapes are recognised:
        * Invocation schema -- its ``properties`` mention ``args`` and/or
          ``kwargs``. The payload ``{"args": [...], "kwargs": {...}}`` is validated.
        * Plain schema -- anything else. Only ``kwargs`` is validated and
          positional arguments are rejected.

    An empty/missing schema disables validation entirely.

    Args:
        service_info: Metadata of the service being called.
        args: Positional arguments supplied by the caller.
        kwargs: Keyword arguments supplied by the caller.

    Raises:
        ValueError: If positional arguments are passed to a service whose
            schema describes keyword arguments only.
        Exception: Whatever ``self._validate_by_schema`` raises on a mismatch.
    """
    schema = service_info.params_schema
    if not schema:
        return

    properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
    is_invocation_schema = "args" in properties or "kwargs" in properties
    if is_invocation_schema:
        # The schema describes the whole call, so validate args and kwargs together.
        payload = {"args": list(args), "kwargs": kwargs}
        self._validate_by_schema(payload, schema, path="params")
        return

    if args:
        raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数")
    self._validate_by_schema(kwargs, schema, path="params")
def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None:
    """Validate a service's return value against its ``return_schema``.

    Does nothing when the service declares no (or an empty) return schema.

    Args:
        service_info: Metadata of the service that produced ``value``.
        value: The resolved return value of the service handler.

    Raises:
        Exception: Whatever ``self._validate_by_schema`` raises on a mismatch.
    """
    if not service_info.return_schema:
        return
    self._validate_by_schema(value, service_info.return_schema, path="return")
def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None:
    """Validate ``value`` against a simplified JSON-Schema fragment.

    Supported keywords: ``type``, ``enum``, and -- only when ``type`` is
    exactly ``"object"`` or ``"array"`` -- ``properties``, ``required``,
    ``additionalProperties`` and ``items``. Nested schemas are validated
    recursively.

    Args:
        value: The data to validate.
        schema: The schema fragment describing ``value``.
        path: Dotted/indexed location of ``value``, used in error messages.

    Raises:
        ValueError: On an enum mismatch, a missing required field, or an
            extra field when ``additionalProperties`` is ``False``.
        Exception: Whatever ``self._validate_type`` raises on a type mismatch.
    """
    expected_type = schema.get("type")
    if expected_type:
        self._validate_type(value, expected_type, path)

    enum_values = schema.get("enum")
    if enum_values is not None and value not in enum_values:
        raise ValueError(f"{path} 不在枚举范围内: {value}")

    if expected_type == "object":
        properties = schema.get("properties", {})
        required = schema.get("required", [])
        for field in required:
            if field not in value:
                raise ValueError(f"{path}.{field} 为必填字段")
        for field, field_value in value.items():
            if field in properties:
                self._validate_by_schema(field_value, properties[field], f"{path}.{field}")
            # Undeclared fields are accepted unless the schema forbids them explicitly.
            elif schema.get("additionalProperties", True) is False:
                raise ValueError(f"{path}.{field} 不允许额外字段")

    if expected_type == "array":
        # Without an ``items`` schema the elements are not inspected.
        if item_schema := schema.get("items"):
            for index, item in enumerate(value):
                self._validate_by_schema(item, item_schema, f"{path}[{index}]")
def _validate_type(self, value: Any, expected_type: str, path: str) -> None:
"""校验基础类型。"""
type_checkers: Dict[str, Callable[[Any], bool]] = {
"string": lambda item: isinstance(item, str),
"number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool),
"integer": lambda item: isinstance(item, int) and not isinstance(item, bool),
"boolean": lambda item: isinstance(item, bool),
"object": lambda item: isinstance(item, dict),
"array": lambda item: isinstance(item, list),
"null": lambda item: item is None,
}
checker = type_checkers.get(expected_type)
if checker and not checker(value):
raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}")
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
"""解析服务全名。"""

View File

@ -1,7 +1,8 @@
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import asyncio
import inspect
import time
import uuid
import inspect
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, MaiMessages
@ -95,7 +96,9 @@ class WorkflowEngine:
except Exception as e:
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
workflow_context.errors.append(f"{stage_key}: {e}")
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
logger.error(
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
)
self._execution_history[workflow_context.trace_id]["status"] = "failed"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
@ -144,11 +147,19 @@ class WorkflowEngine:
step_timing_key = f"{stage.value}:{step_info.full_name}"
step_start = time.perf_counter()
timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None
try:
result = handler(context, message)
if inspect.isawaitable(result):
result = await result
if inspect.iscoroutinefunction(handler):
coroutine = handler(context, message)
result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine
else:
if timeout_seconds:
result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds)
else:
result = handler(context, message)
if inspect.isawaitable(result):
result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result
context.timings[step_timing_key] = time.perf_counter() - step_start
normalized_result = self._normalize_step_result(result)
@ -165,10 +176,30 @@ class WorkflowEngine:
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
return normalized_result
except asyncio.TimeoutError:
context.timings[step_timing_key] = time.perf_counter() - step_start
timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms"
context.errors.append(f"{step_info.full_name}: {timeout_message}")
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}"
)
return WorkflowStepResult(
status="failed",
return_message=timeout_message,
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.STEP_TIMEOUT.value,
},
)
except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}")
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
)
return WorkflowStepResult(
status="failed",
return_message=str(e),

View File

@ -117,7 +117,7 @@ class PromptManager:
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
"""
添加一个上下文构造函数
Args:
name (str): 上下文名称
func (Callable[[str], str | Coroutine[Any, Any, str]]): 构造函数接受 Prompt 名称作为参数返回字符串或返回字符串的协程
@ -144,7 +144,7 @@ class PromptManager:
def get_prompt(self, prompt_name: str) -> Prompt:
"""
获取指定名称的 Prompt 实例的克隆
Args:
prompt_name (str): 要获取的 Prompt 名称
Returns:
@ -161,7 +161,7 @@ class PromptManager:
async def render_prompt(self, prompt: Prompt) -> str:
"""
渲染一个 Prompt 实例
Args:
prompt (Prompt): 要渲染的 Prompt 实例
Returns:

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
class ChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
plan_count: int
latest_timestamp: float
@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
class PlanLogSummary(BaseModel):
"""规划日志摘要"""
chat_id: str
timestamp: float
filename: str
@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
class PlanLogDetail(BaseModel):
"""规划日志详情"""
type: str
chat_id: str
timestamp: float
@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
class PlannerOverview(BaseModel):
"""规划器总览 - 轻量级统计"""
total_chats: int
total_plans: int
chats: List[ChatSummary]
@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
class PaginatedChatLogs(BaseModel):
"""分页的聊天日志列表"""
data: List[PlanLogSummary]
total: int
page: int
@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
@ -86,41 +92,39 @@ async def get_planner_overview():
"""
if not PLAN_LOG_DIR.exists():
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
chats = []
total_plans = 0
for chat_dir in PLAN_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# 只统计json文件数量
json_files = list(chat_dir.glob("*.json"))
plan_count = len(json_files)
total_plans += plan_count
if plan_count == 0:
continue
# 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ChatSummary(
chat_id=chat_dir.name,
plan_count=plan_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
chats.append(
ChatSummary(
chat_id=chat_dir.name,
plan_count=plan_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name,
)
)
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return PlannerOverview(
total_chats=len(chats),
total_plans=total_plans,
chats=chats
)
return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
@ -128,7 +132,7 @@ async def get_chat_plan_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
):
"""
获取指定聊天的规划日志列表分页
@ -137,73 +141,69 @@ async def get_chat_plan_logs(
"""
chat_dir = PLAN_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedChatLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
prompt = data.get('prompt', '')
prompt = data.get("prompt", "")
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
page_files = json_files[offset : offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
actions = data.get('actions', [])
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
logs.append(PlanLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
action_count=len(actions),
action_types=action_types,
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
reasoning_preview=reasoning[:100] if reasoning else ''
))
reasoning = data.get("reasoning", "")
actions = data.get("actions", [])
action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
logs.append(
PlanLogSummary(
chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
action_count=len(actions),
action_types=action_types,
total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
reasoning_preview=reasoning[:100] if reasoning else "",
)
)
except Exception:
# 文件读取失败时使用文件名信息
logs.append(PlanLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
action_count=0,
action_types=[],
total_plan_ms=0,
llm_duration_ms=0,
reasoning_preview='[读取失败]'
))
return PaginatedChatLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
logs.append(
PlanLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
action_count=0,
action_types=[],
total_plan_ms=0,
llm_duration_ms=0,
reasoning_preview="[读取失败]",
)
)
return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
@ -212,22 +212,23 @@ async def get_log_detail(chat_id: str, filename: str):
log_file = PLAN_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
return PlanLogDetail(**data)
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e
# ========== 兼容旧接口 ==========
@router.get("/stats")
async def get_planner_stats():
"""获取规划器统计信息 - 兼容旧接口"""
overview = await get_planner_overview()
# 获取最近10条计划的摘要
recent_plans = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取
@ -236,17 +237,17 @@ async def get_planner_stats():
recent_plans.extend(chat_logs.data)
except Exception:
continue
# 按时间排序取前10
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
recent_plans = recent_plans[:10]
return {
"total_chats": overview.total_chats,
"total_plans": overview.total_plans,
"avg_plan_time_ms": 0,
"avg_llm_time_ms": 0,
"recent_plans": recent_plans
"recent_plans": recent_plans,
}
@ -258,44 +259,43 @@ async def get_chat_list():
@router.get("/all-logs")
async def get_all_logs(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100)
):
async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
"""获取所有规划日志 - 兼容旧接口"""
if not PLAN_LOG_DIR.exists():
return {"data": [], "total": 0, "page": page, "page_size": page_size}
# 收集所有文件
all_files = []
for chat_dir in PLAN_LOG_DIR.iterdir():
if chat_dir.is_dir():
for log_file in chat_dir.glob("*.json"):
all_files.append((chat_dir.name, log_file))
# 按时间戳排序
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
total = len(all_files)
offset = (page - 1) * page_size
page_files = all_files[offset:offset + page_size]
page_files = all_files[offset : offset + page_size]
logs = []
for chat_id, log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
logs.append({
"chat_id": data.get('chat_id', chat_id),
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name,
"action_count": len(data.get('actions', [])),
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
"reasoning_preview": reasoning[:100] if reasoning else ''
})
reasoning = data.get("reasoning", "")
logs.append(
{
"chat_id": data.get("chat_id", chat_id),
"timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name,
"action_count": len(data.get("actions", [])),
"total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
"llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
"reasoning_preview": reasoning[:100] if reasoning else "",
}
)
except Exception:
continue
return {"data": logs, "total": total, "page": page, "page_size": page_size}
return {"data": logs, "total": total, "page": page, "page_size": page_size}

View File

@ -7,6 +7,7 @@
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
class ReplierChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
reply_count: int
latest_timestamp: float
@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
class ReplyLogSummary(BaseModel):
"""回复日志摘要"""
chat_id: str
timestamp: float
filename: str
@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
class ReplyLogDetail(BaseModel):
"""回复日志详情"""
type: str
chat_id: str
timestamp: float
@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
class ReplierOverview(BaseModel):
"""回复器总览 - 轻量级统计"""
total_chats: int
total_replies: int
chats: List[ReplierChatSummary]
@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
class PaginatedReplyLogs(BaseModel):
"""分页的回复日志列表"""
data: List[ReplyLogSummary]
total: int
page: int
@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
timestamp_str = filename.split("_")[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
@ -89,41 +95,39 @@ async def get_replier_overview():
"""
if not REPLY_LOG_DIR.exists():
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
chats = []
total_replies = 0
for chat_dir in REPLY_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# 只统计json文件数量
json_files = list(chat_dir.glob("*.json"))
reply_count = len(json_files)
total_replies += reply_count
if reply_count == 0:
continue
# 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ReplierChatSummary(
chat_id=chat_dir.name,
reply_count=reply_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
chats.append(
ReplierChatSummary(
chat_id=chat_dir.name,
reply_count=reply_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name,
)
)
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return ReplierOverview(
total_chats=len(chats),
total_replies=total_replies,
chats=chats
)
return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
@ -131,7 +135,7 @@ async def get_chat_reply_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
):
"""
获取指定聊天的回复日志列表分页
@ -140,71 +144,67 @@ async def get_chat_reply_logs(
"""
chat_dir = REPLY_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedReplyLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
prompt = data.get('prompt', '')
prompt = data.get("prompt", "")
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
page_files = json_files[offset : offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
output = data.get('output', '')
logs.append(ReplyLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
model=data.get('model', ''),
success=data.get('success', True),
llm_ms=data.get('timing', {}).get('llm_ms', 0),
overall_ms=data.get('timing', {}).get('overall_ms', 0),
output_preview=output[:100] if output else ''
))
output = data.get("output", "")
logs.append(
ReplyLogSummary(
chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
model=data.get("model", ""),
success=data.get("success", True),
llm_ms=data.get("timing", {}).get("llm_ms", 0),
overall_ms=data.get("timing", {}).get("overall_ms", 0),
output_preview=output[:100] if output else "",
)
)
except Exception:
# 文件读取失败时使用文件名信息
logs.append(ReplyLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
model='',
success=False,
llm_ms=0,
overall_ms=0,
output_preview='[读取失败]'
))
return PaginatedReplyLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
logs.append(
ReplyLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
model="",
success=False,
llm_ms=0,
overall_ms=0,
output_preview="[读取失败]",
)
)
return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
@ -213,35 +213,36 @@ async def get_reply_log_detail(chat_id: str, filename: str):
log_file = REPLY_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
with open(log_file, "r", encoding="utf-8") as f:
data = json.load(f)
return ReplyLogDetail(
type=data.get('type', 'reply'),
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', 0),
prompt=data.get('prompt', ''),
output=data.get('output', ''),
processed_output=data.get('processed_output', []),
model=data.get('model', ''),
reasoning=data.get('reasoning', ''),
think_level=data.get('think_level', 0),
timing=data.get('timing', {}),
error=data.get('error'),
success=data.get('success', True)
type=data.get("type", "reply"),
chat_id=data.get("chat_id", chat_id),
timestamp=data.get("timestamp", 0),
prompt=data.get("prompt", ""),
output=data.get("output", ""),
processed_output=data.get("processed_output", []),
model=data.get("model", ""),
reasoning=data.get("reasoning", ""),
think_level=data.get("think_level", 0),
timing=data.get("timing", {}),
error=data.get("error"),
success=data.get("success", True),
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}") from e
# ========== 兼容接口 ==========
@router.get("/stats")
async def get_replier_stats():
"""获取回复器统计信息"""
overview = await get_replier_overview()
# 获取最近10条回复的摘要
recent_replies = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取
@ -250,15 +251,15 @@ async def get_replier_stats():
recent_replies.extend(chat_logs.data)
except Exception:
continue
# 按时间排序取前10
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
recent_replies = recent_replies[:10]
return {
"total_chats": overview.total_chats,
"total_replies": overview.total_replies,
"recent_replies": recent_replies
"recent_replies": recent_replies,
}
@ -266,4 +267,4 @@ async def get_replier_stats():
async def get_replier_chat_list():
"""获取所有聊天ID列表"""
overview = await get_replier_overview()
return [chat.chat_id for chat in overview.chats]
return [chat.chat_id for chat in overview.chats]

View File

@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Depends, Cookie, Header, Request, HTTPException
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
from fastapi import Depends, Cookie, Header, Request
from .core import get_current_token, get_token_manager, check_auth_rate_limit
async def require_auth(

View File

@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
# loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
# IP白名单配置从配置文件读取逗号分隔
# 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100
@ -151,7 +152,7 @@ def _parse_allowed_ips(ip_string: str) -> list:
ip_entry = ip_entry.strip() # 去除空格
if not ip_entry:
continue
# 跳过注释行(以#开头)
if ip_entry.startswith("#"):
continue
@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
def _get_anti_crawler_config():
"""获取防爬虫配置"""
from src.config.config import global_config
return {
'mode': global_config.webui.anti_crawler_mode,
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
'trust_xff': global_config.webui.trust_xff
"mode": global_config.webui.anti_crawler_mode,
"allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
"trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
"trust_xff": global_config.webui.trust_xff,
}
# 初始化配置(将在模块加载时执行)
_config = _get_anti_crawler_config()
ANTI_CRAWLER_MODE = _config['mode']
ALLOWED_IPS = _config['allowed_ips']
TRUSTED_PROXIES = _config['trusted_proxies']
TRUST_XFF = _config['trust_xff']
ANTI_CRAWLER_MODE = _config["mode"]
ALLOWED_IPS = _config["allowed_ips"]
TRUSTED_PROXIES = _config["trusted_proxies"]
TRUST_XFF = _config["trust_xff"]
def _get_mode_config(mode: str) -> dict:

View File

@ -333,7 +333,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_at) == True,
col(Messages.is_at),
)
data.at_count = int(session.exec(statement).first() or 0)
@ -342,7 +342,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_mentioned) == True,
col(Messages.is_mentioned),
)
data.mentioned_count = int(session.exec(statement).first() or 0)
@ -552,7 +552,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
# 1. 表情包之王 - 使用次数最多的表情包
with get_db_session() as session:
statement = (
select(Images).where(col(Images.is_registered) == True).order_by(desc(col(Images.query_count))).limit(5)
select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
)
top_emojis = session.exec(statement).all()
if top_emojis:
@ -636,7 +636,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_picture) == True,
col(Messages.is_picture),
)
data.image_processed_count = int(session.exec(statement).first() or 0)
@ -781,12 +781,12 @@ async def get_achievements(year: int = 2025) -> AchievementData:
# 1. 新学到的黑话数量
# Jargon 表没有时间字段,统计全部已确认的黑话
with get_db_session() as session:
statement = select(func.count()).where(col(Jargon.is_jargon) == True)
statement = select(func.count()).where(col(Jargon.is_jargon))
data.new_jargon_count = int(session.exec(statement).first() or 0)
# 2. 代表性黑话示例
with get_db_session() as session:
statement = select(Jargon).where(col(Jargon.is_jargon) == True).order_by(desc(col(Jargon.count))).limit(5)
statement = select(Jargon).where(col(Jargon.is_jargon)).order_by(desc(col(Jargon.count))).limit(5)
jargon_samples = session.exec(statement).all()
data.sample_jargons = [
{

View File

@ -532,7 +532,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
.select_from(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_registered) == True,
col(Images.is_registered),
)
)
banned_statement = (
@ -540,7 +540,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
.select_from(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_banned) == True,
col(Images.is_banned),
)
)
@ -1283,7 +1283,7 @@ async def preheat_thumbnail_cache(
select(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_banned) == False,
col(Images.is_banned).is_(False),
)
.order_by(col(Images.query_count).desc())
.limit(limit * 2)

View File

@ -315,15 +315,15 @@ async def get_jargon_stats():
total = session.exec(select(fn.count()).select_from(Jargon)).one()
confirmed_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True)
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))
).one()
confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False)
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
complete_count = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True)
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))
).one()
chat_count = session.exec(

View File

@ -17,36 +17,36 @@ _paragraph_store_cache = None
def _get_paragraph_store():
"""延迟加载段落 embedding store只读模式轻量级
Returns:
EmbeddingStore | None: 如果配置启用则返回store否则返回None
"""
# 检查配置是否启用
if not global_config.webui.enable_paragraph_content:
return None
global _paragraph_store_cache
if _paragraph_store_cache is not None:
return _paragraph_store_cache
try:
from src.chat.knowledge.embedding_store import EmbeddingStore
import os
# 获取数据路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
embedding_dir = os.path.join(root_path, "data/embedding")
# 只加载段落 embedding store轻量级
paragraph_store = EmbeddingStore(
namespace="paragraph",
dir_path=embedding_dir,
max_workers=1, # 只读不需要多线程
chunk_size=100
chunk_size=100,
)
paragraph_store.load_from_file()
_paragraph_store_cache = paragraph_store
logger.info(f"成功加载段落 embedding store包含 {len(paragraph_store.store)} 个段落")
return paragraph_store
@ -57,10 +57,10 @@ def _get_paragraph_store():
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
"""从 embedding store 获取段落完整内容
Args:
node_id: 段落节点ID格式为 'paragraph-{hash}'
Returns:
tuple[str | None, bool]: (段落完整内容或None, 是否启用了功能)
"""
@ -69,12 +69,12 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
if paragraph_store is None:
# 功能未启用
return None, False
# 从 store 中获取完整内容
paragraph_item = paragraph_store.store.get(node_id)
if paragraph_item is not None:
# paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本
content: str = getattr(paragraph_item, 'str', '')
content: str = getattr(paragraph_item, "str", "")
if content:
return content, True
return None, True
@ -156,14 +156,18 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else:
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
@ -245,14 +249,18 @@ async def get_knowledge_graph(
try:
node_data = graph[node_id]
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type_val == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else:
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
@ -368,11 +376,15 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
try:
node_data = graph[node_id]
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
content = (
full_content
if full_content is not None
else (node_data["content"] if "content" in node_data else node_id)
)
else:
content = node_data["content"] if "content" in node_data else node_id

View File

@ -370,7 +370,7 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
with get_db_session() as session:
total = len(session.exec(select(PersonInfo.id)).all())
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known) == True)).all())
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known))).all())
unknown = total - known
# 按平台统计

View File

@ -1762,7 +1762,7 @@ async def update_plugin_config_raw(
try:
tomlkit.loads(request.config)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 备份旧配置
import shutil

View File

@ -659,4 +659,4 @@ def get_git_mirror_service() -> GitMirrorService:
global _git_mirror_service
if _git_mirror_service is None:
_git_mirror_service = GitMirrorService()
return _git_mirror_service
return _git_mirror_service