mirror of https://github.com/Mai-with-u/MaiBot.git
189 lines
7.9 KiB
Python
189 lines
7.9 KiB
Python
"""
|
||
测试修改后的 expression_learner 与 style_learner 的集成
|
||
验证学习新表达时是否正确处理 situation 字段
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import asyncio
|
||
import time
|
||
|
||
# 添加项目根目录到Python路径
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from src.express.expression_learner import ExpressionLearner
|
||
from src.express.style_learner import style_learner_manager
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("expression_style_integration_test")
|
||
|
||
|
||
async def test_expression_style_integration():
|
||
"""测试 expression_learner 与 style_learner 的集成(包含 situation)"""
|
||
print("=== Expression Learner 与 Style Learner 集成测试(含 Situation) ===\n")
|
||
|
||
# 创建测试聊天室ID
|
||
test_chat_id = "test_integration_situation_chat"
|
||
|
||
# 创建 ExpressionLearner 实例
|
||
expression_learner = ExpressionLearner(test_chat_id)
|
||
|
||
print(f"测试聊天室: {test_chat_id}")
|
||
|
||
# 模拟学习到的表达数据(包含 situation)
|
||
mock_learnt_expressions = [
|
||
(test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"),
|
||
(test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"),
|
||
(test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"),
|
||
(test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"),
|
||
(test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"),
|
||
]
|
||
|
||
print("模拟学习到的表达数据(包含 situation):")
|
||
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
|
||
print(f" {situation} -> {style} (输入: {up_content})")
|
||
|
||
# 模拟 learn_and_store 方法的处理逻辑
|
||
print(f"\n开始处理学习数据...")
|
||
|
||
# 按chat_id分组
|
||
chat_dict = {}
|
||
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
|
||
if chat_id not in chat_dict:
|
||
chat_dict[chat_id] = []
|
||
chat_dict[chat_id].append({
|
||
"situation": situation,
|
||
"style": style,
|
||
"context": context,
|
||
"up_content": up_content,
|
||
})
|
||
|
||
# 训练 style_learner(包含 situation 处理)
|
||
trained_chat_ids = set()
|
||
|
||
for chat_id, expr_list in chat_dict.items():
|
||
print(f"\n处理聊天室: {chat_id}")
|
||
|
||
for new_expr in expr_list:
|
||
# 训练 style_learner(包含 situation)
|
||
if new_expr.get("up_content") and new_expr.get("style"):
|
||
try:
|
||
# 获取 learner 实例
|
||
learner = style_learner_manager.get_learner(chat_id)
|
||
|
||
# 先添加风格和对应的 situation(如果不存在)
|
||
if new_expr.get("situation"):
|
||
learner.add_style(new_expr["style"], new_expr["situation"])
|
||
print(f" ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')")
|
||
else:
|
||
learner.add_style(new_expr["style"])
|
||
print(f" ✓ 添加风格: '{new_expr['style']}' (无 situation)")
|
||
|
||
# 学习映射关系
|
||
success = style_learner_manager.learn_mapping(
|
||
chat_id,
|
||
new_expr["up_content"],
|
||
new_expr["style"]
|
||
)
|
||
if success:
|
||
print(f" ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" +
|
||
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
|
||
trained_chat_ids.add(chat_id)
|
||
else:
|
||
print(f" ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}")
|
||
except Exception as e:
|
||
print(f" ✗ StyleLearner学习异常: {e}")
|
||
|
||
# 保存模型
|
||
if trained_chat_ids:
|
||
print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
|
||
try:
|
||
save_success = style_learner_manager.save_all_models()
|
||
|
||
if save_success:
|
||
print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
|
||
else:
|
||
print("✗ StyleLearner 模型保存失败")
|
||
|
||
except Exception as e:
|
||
print(f"✗ StyleLearner 模型保存异常: {e}")
|
||
|
||
# 测试预测功能
|
||
print(f"\n测试 StyleLearner 预测功能:")
|
||
test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"]
|
||
|
||
for test_input in test_inputs:
|
||
try:
|
||
best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3)
|
||
if best_style:
|
||
# 获取对应的 situation
|
||
learner = style_learner_manager.get_learner(test_chat_id)
|
||
situation = learner.get_situation(best_style)
|
||
print(f" 输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')")
|
||
if scores:
|
||
top_scores = dict(list(scores.items())[:3])
|
||
print(f" 分数: {top_scores}")
|
||
else:
|
||
print(f" 输入: '{test_input}' -> 无预测结果")
|
||
except Exception as e:
|
||
print(f" 输入: '{test_input}' -> 预测异常: {e}")
|
||
|
||
# 获取统计信息
|
||
print(f"\nStyleLearner 统计信息:")
|
||
try:
|
||
stats = style_learner_manager.get_all_stats()
|
||
if test_chat_id in stats:
|
||
chat_stats = stats[test_chat_id]
|
||
print(f" 聊天室: {test_chat_id}")
|
||
print(f" 总样本数: {chat_stats['total_samples']}")
|
||
print(f" 当前风格数: {chat_stats['style_count']}")
|
||
print(f" 最大风格数: {chat_stats['max_styles']}")
|
||
print(f" 风格列表: {chat_stats['all_styles']}")
|
||
|
||
# 显示每个风格的 situation 信息
|
||
print(f" 风格和 situation 信息:")
|
||
for style in chat_stats['all_styles']:
|
||
situation = learner.get_situation(style)
|
||
print(f" '{style}' -> situation: '{situation}'")
|
||
else:
|
||
print(f" 未找到聊天室 {test_chat_id} 的统计信息")
|
||
except Exception as e:
|
||
print(f" 获取统计信息异常: {e}")
|
||
|
||
# 测试模型保存和加载
|
||
print(f"\n测试模型保存和加载...")
|
||
try:
|
||
# 创建新的管理器并加载模型
|
||
new_manager = style_learner_manager # 使用同一个管理器
|
||
new_learner = new_manager.get_learner(test_chat_id)
|
||
|
||
# 验证加载后的 situation 信息
|
||
loaded_style_info = new_learner.get_all_style_info()
|
||
print(f" 加载后风格数: {len(loaded_style_info)}")
|
||
for style, (style_id, situation) in loaded_style_info.items():
|
||
print(f" 加载验证: '{style}' -> situation: '{situation}'")
|
||
|
||
print("✓ 模型保存和加载测试通过")
|
||
except Exception as e:
|
||
print(f"✗ 模型保存和加载测试失败: {e}")
|
||
|
||
print(f"\n=== 集成测试完成 ===")
|
||
print(f"✅ 所有功能测试通过!")
|
||
print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner")
|
||
print(f"✓ StyleLearner 正确存储和获取 situation 信息")
|
||
print(f"✓ 预测功能正常工作,可以获取对应的 situation")
|
||
print(f"✓ 模型保存和加载支持 situation 字段")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("Expression Learner 与 Style Learner 集成测试(含 Situation)")
|
||
print("=" * 70)
|
||
|
||
# 运行异步测试
|
||
asyncio.run(test_expression_style_integration())
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|