better:优化了表达模型的预测,和表达方式的学习逻辑

pull/1294/head
SengokuCola 2025-10-11 23:44:52 +08:00
parent 4a074ec374
commit d073a215e3
8 changed files with 113 additions and 410 deletions

View File

@ -185,13 +185,11 @@ class HeartFChatting:
question_probability = 0
if time.time() - self.last_active_time > 3600:
question_probability = 0.01
elif time.time() - self.last_active_time > 1200:
question_probability = 0.005
elif time.time() - self.last_active_time > 600:
question_probability = 0.001
else:
elif time.time() - self.last_active_time > 1200:
question_probability = 0.0003
else:
question_probability = 0.0001
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
@ -210,7 +208,7 @@ class HeartFChatting:
if question:
logger.info(f"{self.log_prefix} 问题: {question}")
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
await self._lift_question_reply(question,context,cycle_timers,thinking_id)
await self._lift_question_reply(question,context,thinking_id)
else:
logger.info(f"{self.log_prefix} 无问题")
# self.end_cycle(cycle_timers, thinking_id)
@ -550,8 +548,8 @@ class HeartFChatting:
traceback.print_exc()
return False, ""
async def _lift_question_reply(self, question: str, context: str, cycle_timers: Dict[str, float], thinking_id: str):
reason = f"在聊天中:\n{context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str):
reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
new_msg = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),

View File

@ -310,7 +310,6 @@ class Expression(BaseModel):
last_active_time = FloatField()
chat_id = TextField(index=True)
type = TextField()
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
class Meta:

View File

@ -1,11 +1,9 @@
import time
import random
import json
import os
import re
from datetime import datetime
import jieba
from typing import List, Dict, Optional, Any, Tuple
from typing import List, Optional, Tuple
import traceback
import difflib
from src.common.logger import get_logger
@ -148,7 +146,7 @@ class ExpressionLearner:
return True
async def trigger_learning_for_chat(self) -> bool:
async def trigger_learning_for_chat(self):
"""
为指定聊天流触发学习
@ -159,11 +157,10 @@ class ExpressionLearner:
bool: 是否成功触发学习
"""
if not self.should_trigger_learning():
return False
return
try:
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
# 学习语言风格
learnt_style = await self.learn_and_store(num=25)
@ -172,15 +169,13 @@ class ExpressionLearner:
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
return True
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
return False
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
return False
return
@ -188,127 +183,87 @@ class ExpressionLearner:
"""
学习并存储表达方式
"""
res = await self.learn_expression(num)
learnt_expressions = await self.learn_expression(num)
if res is None:
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
learnt_expressions = res
# 展示学到的表达方式
learnt_expressions_str = ""
for (
_chat_id,
situation,
style,
_context,
_up_content,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
current_time = time.time()
# 存储到数据库 Expression 表并训练 style_learner
has_new_expressions = False # 记录是否有新的表达方式
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
for (
chat_id,
situation,
style,
context,
up_content,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append(
{
"situation": situation,
"style": style,
"context": context,
"up_content": up_content,
}
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == self.chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
)
current_time = time.time()
# 存储到数据库 Expression 表并训练 style_learner
trained_chat_ids = set() # 记录已训练的聊天室
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == "style")
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
if query.exists():
# 表达方式完全相同,只更新时间戳
expr_obj = query.get()
expr_obj.last_active_time = current_time
expr_obj.save()
continue
else:
Expression.create(
situation=situation,
style=style,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time, # 手动设置创建日期
context=context,
up_content=up_content,
)
if query.exists():
expr_obj = query.get()
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.up_content = new_expr["up_content"]
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=new_expr["situation"],
style=new_expr["style"],
last_active_time=current_time,
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
up_content=new_expr["up_content"],
)
# 训练 style_learnerup_content 和 style 必定存在)
try:
# 获取 learner 实例
learner = style_learner_manager.get_learner(chat_id)
# 先添加风格和对应的 situation如果存在
if new_expr.get("situation"):
learner.add_style(new_expr["style"], new_expr["situation"])
else:
learner.add_style(new_expr["style"])
# 学习映射关系
success = style_learner_manager.learn_mapping(
chat_id,
new_expr["up_content"],
new_expr["style"]
)
if success:
logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" +
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
trained_chat_ids.add(chat_id)
else:
logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {chat_id} - {e}")
has_new_expressions = True
# 限制最大数量
# exprs = list(
# Expression.select()
# .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
# .order_by(Expression.last_active_time.asc())
# )
# if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除最久未活跃的多余表达方式
# for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
# expr.delete_instance()
# 保存训练好的 style_learner 模型
if trained_chat_ids:
# 训练 style_learnerup_content 和 style 必定存在)
try:
logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
save_success = style_learner_manager.save_all_models()
learner.add_style(style, situation)
# 学习映射关系
success = style_learner_manager.learn_mapping(
self.chat_id,
up_content,
style
)
if success:
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型
if has_new_expressions:
try:
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
save_success = learner.save(style_learner_manager.model_save_path)
if save_success:
logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
else:
logger.warning("StyleLearner 模型保存失败")
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
@ -415,15 +370,12 @@ class ExpressionLearner:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str], str]]]:
) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
num: 学习数量
"""
type_str = "语言风格"
prompt = "learn_style_prompt"
current_time = time.time()
# 获取上次学习之后的消息
@ -436,14 +388,14 @@ class ExpressionLearner:
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 转化成str
_chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
# 学习用
random_msg_str: str = await build_anonymous_messages(random_msg)
# 溯源用
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
prompt,
"learn_style_prompt",
chat_str=random_msg_str,
)
@ -453,20 +405,18 @@ class ExpressionLearner:
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
print(f"matched_expressions: {matched_expressions}")
# 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射
# 这里有待斟酌,需要进一步处理图片和表情包
bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content)
pic_pattern = r"\[picid:[^\]]+\]"
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
@ -479,7 +429,6 @@ class ExpressionLearner:
content = content.strip()
if content:
bare_lines.append((idx, content))
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
@ -503,11 +452,7 @@ class ExpressionLearner:
if not filtered_with_up:
return None
results: List[Tuple[str, str, str, str]] = []
for (situation, style, context, up_content) in filtered_with_up:
results.append((self.chat_id, situation, style, context, up_content))
return results
return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:

View File

@ -136,13 +136,13 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]:
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
"""
使用 style_learner 模型预测最合适的表达方式
Args:
chat_id: 聊天室ID
chat_info: 聊天内容信息
target_message: 目标消息内容
total_num: 需要预测的数量
Returns:
@ -152,10 +152,7 @@ class ExpressionSelector:
# 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id)
# 从聊天信息中提取关键内容作为预测输入
# 这里可以进一步优化,提取更合适的预测输入
prediction_input = self._extract_prediction_input(chat_info)
predicted_expressions = []
# 为每个相关的chat_id进行预测
@ -163,7 +160,7 @@ class ExpressionSelector:
try:
# 使用 style_learner 预测最合适的风格
best_style, scores = style_learner_manager.predict_style(
related_chat_id, prediction_input, top_k=total_num
related_chat_id, target_message, top_k=total_num
)
if best_style and scores:
@ -175,7 +172,6 @@ class ExpressionSelector:
# 从数据库查找对应的表达记录
expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id) &
(Expression.type == "style") &
(Expression.situation == situation) &
(Expression.style == best_style)
)
@ -188,11 +184,12 @@ class ExpressionSelector:
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": prediction_input
"prediction_input": target_message
})
else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
@ -208,39 +205,11 @@ class ExpressionSelector:
except Exception as e:
logger.error(f"模型预测表达方式失败: {e}")
# 如果预测失败,回退到随机选择
return self._fallback_random_expressions(chat_id, total_num)
return self._random_expressions(chat_id, total_num)
def _extract_prediction_input(self, chat_info: str) -> str:
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
从聊天信息中提取用于预测的关键内容
Args:
chat_info: 聊天内容信息
Returns:
str: 提取的预测输入
"""
try:
# 简单的提取策略:取最后几句话作为预测输入
lines = chat_info.strip().split('\n')
if not lines:
return ""
# 取最后3行作为预测输入
recent_lines = lines[-1:]
prediction_input = ' '.join(recent_lines).strip()
logger.info(f"提取预测输入: {prediction_input}")
return prediction_input
except Exception as e:
logger.warning(f"提取预测输入失败: {e}")
return ""
def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
回退到随机选择表达方式
随机选择表达方式
Args:
chat_id: 聊天室ID
@ -255,7 +224,7 @@ class ExpressionSelector:
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
(Expression.chat_id.in_(related_chat_ids))
)
style_exprs = [
@ -265,7 +234,6 @@ class ExpressionSelector:
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
@ -277,7 +245,7 @@ class ExpressionSelector:
else:
selected_style = []
logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
return selected_style
except Exception as e:
@ -315,7 +283,7 @@ class ExpressionSelector:
if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_model_only(chat_id, chat_info, max_num)
return await self._select_expressions_model_only(chat_id, target_message, max_num)
elif expression_mode == "classic":
# classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
@ -327,7 +295,7 @@ class ExpressionSelector:
async def _select_expressions_model_only(
self,
chat_id: str,
chat_info: str,
target_message: str,
max_num: int = 10,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
@ -335,7 +303,7 @@ class ExpressionSelector:
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
target_message: 目标消息内容
max_num: 最大选择数量
Returns:
@ -343,11 +311,7 @@ class ExpressionSelector:
"""
try:
# 使用模型预测最合适的表达方式
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2)
# 直接取前max_num个预测结果
selected_expressions = style_exprs[:max_num]
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time
@ -381,8 +345,8 @@ class ExpressionSelector:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 1. 使用模型预测最合适的表达方式
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20)
# 1. 使用随机抽样选择表达方式
style_exprs = self._random_expressions(chat_id, 20)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
@ -460,21 +424,6 @@ class ExpressionSelector:
logger.error(f"classic模式处理表达方式选择时出错: {e}")
return [], []
async def select_suitable_expressions_llm(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
使用LLM选择适合的表达方式保持向后兼容
注意此方法已被 select_suitable_expressions 替代建议使用新方法
"""
logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions")
return await self.select_suitable_expressions(chat_id, chat_info, max_num, target_message)
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
"""对一批表达方式更新last_active_time"""
if not expressions_to_update:
@ -482,19 +431,17 @@ class ExpressionSelector:
updates_by_key = {}
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
key = (source_id, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
for chat_id, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)

View File

@ -58,8 +58,18 @@ class ExpressorModel:
# 取最高分
if not scores:
return None, {}
best = max(scores.items(), key=lambda x: x[1])[0]
return best, scores
# 根据k参数限制返回的候选数量
if k is not None and k > 0:
# 按分数降序排序取前k个
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
limited_scores = dict(sorted_scores[:k])
best = sorted_scores[0][0] if sorted_scores else None
return best, limited_scores
else:
# 如果没有指定k返回所有分数
best = max(scores.items(), key=lambda x: x[1])[0]
return best, scores
def update_positive(self, text: str, cid: str):
"""更新正反馈学习"""

View File

@ -60,13 +60,8 @@ class QuestionMaker:
# 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
if not conflicts_with_zero:
if random.random() >= 0.01:
return None
# 以均匀概率选择一个(此时权重都等同于 0.05,无需再按权重)
chosen_conflict = random.choice(conflicts)
else:
# 权重规则raise_time == 0 -> 1.0raise_time >= 1 -> 0.05
if conflicts_with_zero:
# 权重规则raise_time == 0 -> 1.0raise_time >= 1 -> 0.01
weights = []
for conflict in conflicts:
current_raise_time = getattr(conflict, "raise_time", 0) or 0
@ -77,12 +72,9 @@ class QuestionMaker:
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
# 选中后,自增 raise_time 并保存
try:
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
chosen_conflict.save()
except Exception:
# 静默失败不影响流程
pass
return chosen_conflict

View File

@ -288,7 +288,7 @@ class ConflictTracker:
if existing_conflict:
# 检查raise_time是否大于3且没有答案
current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0
if current_raise_time > 1 and not existing_conflict.answer:
if current_raise_time > 0 and not existing_conflict.answer:
# 删除该条目
await self.delete_conflict(original_question, tracker.chat_id)
logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}")

View File

@ -1,188 +0,0 @@
"""
测试修改后的 expression_learner style_learner 的集成
验证学习新表达时是否正确处理 situation 字段
"""
import os
import sys
import asyncio
import time
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.express.expression_learner import ExpressionLearner
from src.express.style_learner import style_learner_manager
from src.common.logger import get_logger
logger = get_logger("expression_style_integration_test")
async def test_expression_style_integration():
"""测试 expression_learner 与 style_learner 的集成(包含 situation"""
print("=== Expression Learner 与 Style Learner 集成测试(含 Situation ===\n")
# 创建测试聊天室ID
test_chat_id = "test_integration_situation_chat"
# 创建 ExpressionLearner 实例
expression_learner = ExpressionLearner(test_chat_id)
print(f"测试聊天室: {test_chat_id}")
# 模拟学习到的表达数据(包含 situation
mock_learnt_expressions = [
(test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"),
(test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"),
(test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"),
(test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"),
(test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"),
]
print("模拟学习到的表达数据(包含 situation:")
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
print(f" {situation} -> {style} (输入: {up_content})")
# 模拟 learn_and_store 方法的处理逻辑
print(f"\n开始处理学习数据...")
# 按chat_id分组
chat_dict = {}
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({
"situation": situation,
"style": style,
"context": context,
"up_content": up_content,
})
# 训练 style_learner包含 situation 处理)
trained_chat_ids = set()
for chat_id, expr_list in chat_dict.items():
print(f"\n处理聊天室: {chat_id}")
for new_expr in expr_list:
# 训练 style_learner包含 situation
if new_expr.get("up_content") and new_expr.get("style"):
try:
# 获取 learner 实例
learner = style_learner_manager.get_learner(chat_id)
# 先添加风格和对应的 situation如果不存在
if new_expr.get("situation"):
learner.add_style(new_expr["style"], new_expr["situation"])
print(f" ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')")
else:
learner.add_style(new_expr["style"])
print(f" ✓ 添加风格: '{new_expr['style']}' (无 situation)")
# 学习映射关系
success = style_learner_manager.learn_mapping(
chat_id,
new_expr["up_content"],
new_expr["style"]
)
if success:
print(f" ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" +
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
trained_chat_ids.add(chat_id)
else:
print(f" ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}")
except Exception as e:
print(f" ✗ StyleLearner学习异常: {e}")
# 保存模型
if trained_chat_ids:
print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
try:
save_success = style_learner_manager.save_all_models()
if save_success:
print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
else:
print("✗ StyleLearner 模型保存失败")
except Exception as e:
print(f"✗ StyleLearner 模型保存异常: {e}")
# 测试预测功能
print(f"\n测试 StyleLearner 预测功能:")
test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"]
for test_input in test_inputs:
try:
best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3)
if best_style:
# 获取对应的 situation
learner = style_learner_manager.get_learner(test_chat_id)
situation = learner.get_situation(best_style)
print(f" 输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')")
if scores:
top_scores = dict(list(scores.items())[:3])
print(f" 分数: {top_scores}")
else:
print(f" 输入: '{test_input}' -> 无预测结果")
except Exception as e:
print(f" 输入: '{test_input}' -> 预测异常: {e}")
# 获取统计信息
print(f"\nStyleLearner 统计信息:")
try:
stats = style_learner_manager.get_all_stats()
if test_chat_id in stats:
chat_stats = stats[test_chat_id]
print(f" 聊天室: {test_chat_id}")
print(f" 总样本数: {chat_stats['total_samples']}")
print(f" 当前风格数: {chat_stats['style_count']}")
print(f" 最大风格数: {chat_stats['max_styles']}")
print(f" 风格列表: {chat_stats['all_styles']}")
# 显示每个风格的 situation 信息
print(f" 风格和 situation 信息:")
for style in chat_stats['all_styles']:
situation = learner.get_situation(style)
print(f" '{style}' -> situation: '{situation}'")
else:
print(f" 未找到聊天室 {test_chat_id} 的统计信息")
except Exception as e:
print(f" 获取统计信息异常: {e}")
# 测试模型保存和加载
print(f"\n测试模型保存和加载...")
try:
# 创建新的管理器并加载模型
new_manager = style_learner_manager # 使用同一个管理器
new_learner = new_manager.get_learner(test_chat_id)
# 验证加载后的 situation 信息
loaded_style_info = new_learner.get_all_style_info()
print(f" 加载后风格数: {len(loaded_style_info)}")
for style, (style_id, situation) in loaded_style_info.items():
print(f" 加载验证: '{style}' -> situation: '{situation}'")
print("✓ 模型保存和加载测试通过")
except Exception as e:
print(f"✗ 模型保存和加载测试失败: {e}")
print(f"\n=== 集成测试完成 ===")
print(f"✅ 所有功能测试通过!")
print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner")
print(f"✓ StyleLearner 正确存储和获取 situation 信息")
print(f"✓ 预测功能正常工作,可以获取对应的 situation")
print(f"✓ 模型保存和加载支持 situation 字段")
def main():
"""主函数"""
print("Expression Learner 与 Style Learner 集成测试(含 Situation")
print("=" * 70)
# 运行异步测试
asyncio.run(test_expression_style_integration())
if __name__ == "__main__":
main()