MaiBot/test_expression_style_situa...

189 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
测试修改后的 expression_learner 与 style_learner 的集成
验证学习新表达时是否正确处理 situation 字段
"""
import os
import sys
import asyncio
import time
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.express.expression_learner import ExpressionLearner
from src.express.style_learner import style_learner_manager
from src.common.logger import get_logger
logger = get_logger("expression_style_integration_test")
async def test_expression_style_integration():
"""测试 expression_learner 与 style_learner 的集成(包含 situation"""
print("=== Expression Learner 与 Style Learner 集成测试(含 Situation ===\n")
# 创建测试聊天室ID
test_chat_id = "test_integration_situation_chat"
# 创建 ExpressionLearner 实例
expression_learner = ExpressionLearner(test_chat_id)
print(f"测试聊天室: {test_chat_id}")
# 模拟学习到的表达数据(包含 situation
mock_learnt_expressions = [
(test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"),
(test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"),
(test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"),
(test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"),
(test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"),
]
print("模拟学习到的表达数据(包含 situation:")
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
print(f" {situation} -> {style} (输入: {up_content})")
# 模拟 learn_and_store 方法的处理逻辑
print(f"\n开始处理学习数据...")
# 按chat_id分组
chat_dict = {}
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({
"situation": situation,
"style": style,
"context": context,
"up_content": up_content,
})
# 训练 style_learner包含 situation 处理)
trained_chat_ids = set()
for chat_id, expr_list in chat_dict.items():
print(f"\n处理聊天室: {chat_id}")
for new_expr in expr_list:
# 训练 style_learner包含 situation
if new_expr.get("up_content") and new_expr.get("style"):
try:
# 获取 learner 实例
learner = style_learner_manager.get_learner(chat_id)
# 先添加风格和对应的 situation如果不存在
if new_expr.get("situation"):
learner.add_style(new_expr["style"], new_expr["situation"])
print(f" ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')")
else:
learner.add_style(new_expr["style"])
print(f" ✓ 添加风格: '{new_expr['style']}' (无 situation)")
# 学习映射关系
success = style_learner_manager.learn_mapping(
chat_id,
new_expr["up_content"],
new_expr["style"]
)
if success:
print(f" ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" +
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
trained_chat_ids.add(chat_id)
else:
print(f" ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}")
except Exception as e:
print(f" ✗ StyleLearner学习异常: {e}")
# 保存模型
if trained_chat_ids:
print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
try:
save_success = style_learner_manager.save_all_models()
if save_success:
print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
else:
print("✗ StyleLearner 模型保存失败")
except Exception as e:
print(f"✗ StyleLearner 模型保存异常: {e}")
# 测试预测功能
print(f"\n测试 StyleLearner 预测功能:")
test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"]
for test_input in test_inputs:
try:
best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3)
if best_style:
# 获取对应的 situation
learner = style_learner_manager.get_learner(test_chat_id)
situation = learner.get_situation(best_style)
print(f" 输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')")
if scores:
top_scores = dict(list(scores.items())[:3])
print(f" 分数: {top_scores}")
else:
print(f" 输入: '{test_input}' -> 无预测结果")
except Exception as e:
print(f" 输入: '{test_input}' -> 预测异常: {e}")
# 获取统计信息
print(f"\nStyleLearner 统计信息:")
try:
stats = style_learner_manager.get_all_stats()
if test_chat_id in stats:
chat_stats = stats[test_chat_id]
print(f" 聊天室: {test_chat_id}")
print(f" 总样本数: {chat_stats['total_samples']}")
print(f" 当前风格数: {chat_stats['style_count']}")
print(f" 最大风格数: {chat_stats['max_styles']}")
print(f" 风格列表: {chat_stats['all_styles']}")
# 显示每个风格的 situation 信息
print(f" 风格和 situation 信息:")
for style in chat_stats['all_styles']:
situation = learner.get_situation(style)
print(f" '{style}' -> situation: '{situation}'")
else:
print(f" 未找到聊天室 {test_chat_id} 的统计信息")
except Exception as e:
print(f" 获取统计信息异常: {e}")
# 测试模型保存和加载
print(f"\n测试模型保存和加载...")
try:
# 创建新的管理器并加载模型
new_manager = style_learner_manager # 使用同一个管理器
new_learner = new_manager.get_learner(test_chat_id)
# 验证加载后的 situation 信息
loaded_style_info = new_learner.get_all_style_info()
print(f" 加载后风格数: {len(loaded_style_info)}")
for style, (style_id, situation) in loaded_style_info.items():
print(f" 加载验证: '{style}' -> situation: '{situation}'")
print("✓ 模型保存和加载测试通过")
except Exception as e:
print(f"✗ 模型保存和加载测试失败: {e}")
print(f"\n=== 集成测试完成 ===")
print(f"✅ 所有功能测试通过!")
print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner")
print(f"✓ StyleLearner 正确存储和获取 situation 信息")
print(f"✓ 预测功能正常工作,可以获取对应的 situation")
print(f"✓ 模型保存和加载支持 situation 字段")
def main():
"""主函数"""
print("Expression Learner 与 Style Learner 集成测试(含 Situation")
print("=" * 70)
# 运行异步测试
asyncio.run(test_expression_style_integration())
if __name__ == "__main__":
main()