mirror of https://github.com/Mai-with-u/MaiBot.git
94 lines
2.9 KiB
Python
94 lines
2.9 KiB
Python
import asyncio
|
||
import os
|
||
import sys
|
||
from typing import List, Dict, Any
|
||
|
||
# 强制使用 utf-8,避免控制台编码报错影响 Embedding 加载
|
||
try:
|
||
if hasattr(sys.stdout, "reconfigure"):
|
||
sys.stdout.reconfigure(encoding="utf-8")
|
||
if hasattr(sys.stderr, "reconfigure"):
|
||
sys.stderr.reconfigure(encoding="utf-8")
|
||
except Exception:
|
||
pass
|
||
|
||
# 确保能导入 src.*
|
||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config
|
||
from src.chat.knowledge import lpmm_start_up
|
||
from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
|
||
|
||
logger = get_logger("test_lpmm_retrieval")
|
||
|
||
|
||
TEST_CASES: List[Dict[str, Any]] = [
|
||
{
|
||
"name": "回滚一批知识",
|
||
"query": "LPMM是什么?",
|
||
"expect_keywords": ["哈希列表", "删除脚本", "OpenIE"],
|
||
},
|
||
{
|
||
"name": "调整 LPMM 检索参数",
|
||
"query": "不同用词习惯带来的检索偏差该如何解决",
|
||
"expect_keywords": ["bot_config.toml", "lpmm_knowledge", "qa_paragraph_search_top_k"],
|
||
},
|
||
]
|
||
|
||
|
||
async def run_tests() -> None:
|
||
"""简单测试 LPMM 知识库检索能力"""
|
||
if not global_config.lpmm_knowledge.enable:
|
||
logger.warning("当前配置中 lpmm_knowledge.enable 为 False,检索测试可能直接返回“未启用”。")
|
||
|
||
logger.info("开始初始化 LPMM 知识库...")
|
||
lpmm_start_up()
|
||
logger.info("LPMM 知识库初始化完成,开始执行测试用例。")
|
||
|
||
for case in TEST_CASES:
|
||
name = case["name"]
|
||
query = case["query"]
|
||
expect_keywords: List[str] = case.get("expect_keywords", [])
|
||
|
||
print("\n" + "=" * 60)
|
||
print(f"[TEST] {name}")
|
||
print(f"[Q] {query}")
|
||
|
||
result = await query_lpmm_knowledge(query, limit=3)
|
||
|
||
print("\n[RAW RESULT]")
|
||
print(result)
|
||
|
||
status = "UNKNOWN"
|
||
hit_keywords: List[str] = []
|
||
|
||
if isinstance(result, str):
|
||
if "未启用" in result or "未初始化" in result or "查询失败" in result:
|
||
status = "ERROR"
|
||
elif "未找到与" in result:
|
||
status = "NO_HIT"
|
||
else:
|
||
if expect_keywords:
|
||
hit_keywords = [kw for kw in expect_keywords if kw in result]
|
||
status = "PASS" if hit_keywords else "WARN"
|
||
else:
|
||
status = "PASS"
|
||
|
||
print("\n[CHECK]")
|
||
print(f"Status: {status}")
|
||
if expect_keywords:
|
||
print(f"Expected keywords: {expect_keywords}")
|
||
print(f"Hit keywords: {hit_keywords}")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("LPMM 检索测试完成。请根据每条用例的 Status 和命中关键词判断检索效果是否符合预期。")
|
||
|
||
|
||
def main() -> None:
|
||
asyncio.run(run_tests())
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|