MaiBot/scripts/test_lpmm_retrieval.py

94 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncio
import os
import sys
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
# Force UTF-8 on the console streams so that console encoding errors do not
# break Embedding loading. Best-effort: each stream is handled on its own, so a
# failure while reconfiguring stdout does not prevent stderr from being tried.
for _stream in (sys.stdout, sys.stderr):
    try:
        if hasattr(_stream, "reconfigure"):
            _stream.reconfigure(encoding="utf-8")
    except Exception:
        # Deliberately ignored: a stream that cannot be switched to UTF-8
        # should not stop the script from running.
        pass

# Make the project root (the parent of this script's directory) importable so
# that the `src.*` imports resolve when the script is run directly. Skip the
# append if the path is already present to avoid duplicate sys.path entries.
_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _project_root not in sys.path:
    sys.path.append(_project_root)
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import lpmm_start_up
from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
# Logger for this script, obtained from the project's logging helper
# (src.common.logger.get_logger) under the name "test_lpmm_retrieval".
logger = get_logger("test_lpmm_retrieval")
# Retrieval smoke-test cases iterated by run_tests(). Each case is a dict with:
#   name            - label printed on the "[TEST]" line
#   query           - text passed to query_lpmm_knowledge()
#   expect_keywords - optional list of keywords looked for in the returned text
TEST_CASES: List[Dict[str, Any]] = [
    # NOTE(review): this case's name and expected keywords are about rolling
    # back a batch of knowledge, but the query asks what LPMM is — confirm the
    # intended query.
    dict(
        name="回滚一批知识",
        query="LPMM是什么?",
        expect_keywords=["哈希列表", "删除脚本", "OpenIE"],
    ),
    dict(
        name="调整 LPMM 检索参数",
        query="不同用词习惯带来的检索偏差该如何解决",
        expect_keywords=[
            "bot_config.toml",
            "lpmm_knowledge",
            "qa_paragraph_search_top_k",
        ],
    ),
]
def _classify_result(result: Any, expect_keywords: List[str]) -> Tuple[str, List[str]]:
    """Map a raw ``query_lpmm_knowledge`` result to a test status.

    Args:
        result: The value returned by ``query_lpmm_knowledge``. Only ``str``
            results are inspected; anything else is reported as ``"UNKNOWN"``.
        expect_keywords: Keywords expected to appear in a successful result.

    Returns:
        ``(status, hit_keywords)``. ``status`` is one of ``"ERROR"``,
        ``"NO_HIT"``, ``"PASS"``, ``"WARN"`` or ``"UNKNOWN"``.
        ``hit_keywords`` lists the expected keywords found in ``result``, in
        ``expect_keywords`` order.
    """
    if not isinstance(result, str):
        return "UNKNOWN", []
    # Marker substrings searched for in the returned text. Presumably these
    # match query_lpmm_knowledge's "not enabled" / "not initialised" /
    # "query failed" / "nothing found" replies; keep them in sync if those
    # messages change.
    if "未启用" in result or "未初始化" in result or "查询失败" in result:
        return "ERROR", []
    if "未找到与" in result:
        return "NO_HIT", []
    if not expect_keywords:
        return "PASS", []
    hit_keywords = [kw for kw in expect_keywords if kw in result]
    # A single matching keyword is enough for PASS; the individual hits are
    # returned so the caller can report which ones matched.
    return ("PASS" if hit_keywords else "WARN"), hit_keywords


async def run_tests(cases: Optional[List[Dict[str, Any]]] = None, limit: int = 3) -> None:
    """Run a simple smoke test of LPMM knowledge-base retrieval.

    Initialises the LPMM knowledge base, sends each case's query through
    ``query_lpmm_knowledge``, and prints the raw result plus a keyword check.
    If a query raises, that case is reported as ``ERROR`` and the remaining
    cases still run. A per-status summary is printed at the end.

    Args:
        cases: Test cases to run. Each is a dict with ``name`` and ``query`` and
            an optional ``expect_keywords`` list. Defaults to ``TEST_CASES``.
        limit: Value forwarded as ``limit=`` to ``query_lpmm_knowledge``.
    """
    if cases is None:
        cases = TEST_CASES

    if not global_config.lpmm_knowledge.enable:
        logger.warning("当前配置中 lpmm_knowledge.enable 为 False检索测试可能直接返回“未启用”。")

    logger.info("开始初始化 LPMM 知识库...")
    lpmm_start_up()
    logger.info("LPMM 知识库初始化完成,开始执行测试用例。")

    status_counts = Counter()
    for case in cases:
        name = case["name"]
        query = case["query"]
        # `or []` also tolerates an explicit None for expect_keywords.
        expect_keywords: List[str] = case.get("expect_keywords") or []

        print("\n" + "=" * 60)
        print(f"[TEST] {name}")
        print(f"[Q] {query}")

        try:
            result = await query_lpmm_knowledge(query, limit=limit)
        except Exception as exc:
            # Record the failure and keep going so that one broken query does
            # not hide the results of the remaining cases.
            logger.error(f"Case {name!r}: query_lpmm_knowledge raised {exc!r}")
            raw_output: Any = f"<exception> {exc!r}"
            status, hit_keywords = "ERROR", []
        else:
            raw_output = result
            status, hit_keywords = _classify_result(result, expect_keywords)

        print("\n[RAW RESULT]")
        print(raw_output)

        status_counts[status] += 1
        print("\n[CHECK]")
        print(f"Status: {status}")
        if expect_keywords:
            missing_keywords = [kw for kw in expect_keywords if kw not in hit_keywords]
            print(f"Expected keywords: {expect_keywords}")
            print(f"Hit keywords: {hit_keywords}")
            print(f"Missing keywords: {missing_keywords}")

    print("\n" + "=" * 60)
    summary = ", ".join(f"{label}={count}" for label, count in sorted(status_counts.items()))
    print(f"Summary: {summary or 'no cases run'}")
    print("LPMM 检索测试完成。请根据每条用例的 Status 和命中关键词判断检索效果是否符合预期。")
def main() -> None:
    """Command-line entry point: run the async retrieval checks to completion.

    ``asyncio.run`` creates a new event loop, runs the coroutine, and closes
    the loop when it finishes.
    """
    suite = run_tests()
    asyncio.run(suite)


if __name__ == "__main__":
    main()