mirror of https://github.com/Mai-with-u/MaiBot.git
better:优化了表达模型的预测,和表达方式的学习逻辑
parent
4a074ec374
commit
d073a215e3
|
|
@ -185,13 +185,11 @@ class HeartFChatting:
|
|||
|
||||
question_probability = 0
|
||||
if time.time() - self.last_active_time > 3600:
|
||||
question_probability = 0.01
|
||||
elif time.time() - self.last_active_time > 1200:
|
||||
question_probability = 0.005
|
||||
elif time.time() - self.last_active_time > 600:
|
||||
question_probability = 0.001
|
||||
else:
|
||||
elif time.time() - self.last_active_time > 1200:
|
||||
question_probability = 0.0003
|
||||
else:
|
||||
question_probability = 0.0001
|
||||
|
||||
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
|
||||
|
||||
|
|
@ -210,7 +208,7 @@ class HeartFChatting:
|
|||
if question:
|
||||
logger.info(f"{self.log_prefix} 问题: {question}")
|
||||
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
|
||||
await self._lift_question_reply(question,context,cycle_timers,thinking_id)
|
||||
await self._lift_question_reply(question,context,thinking_id)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 无问题")
|
||||
# self.end_cycle(cycle_timers, thinking_id)
|
||||
|
|
@ -550,8 +548,8 @@ class HeartFChatting:
|
|||
traceback.print_exc()
|
||||
return False, ""
|
||||
|
||||
async def _lift_question_reply(self, question: str, context: str, cycle_timers: Dict[str, float], thinking_id: str):
|
||||
reason = f"在聊天中:\n{context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
|
||||
async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str):
|
||||
reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
|
||||
new_msg = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
|
|
|
|||
|
|
@ -310,7 +310,6 @@ class Expression(BaseModel):
|
|||
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
|
||||
|
||||
class Meta:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
import jieba
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
import difflib
|
||||
from src.common.logger import get_logger
|
||||
|
|
@ -148,7 +146,7 @@ class ExpressionLearner:
|
|||
|
||||
return True
|
||||
|
||||
async def trigger_learning_for_chat(self) -> bool:
|
||||
async def trigger_learning_for_chat(self):
|
||||
"""
|
||||
为指定聊天流触发学习
|
||||
|
||||
|
|
@ -159,11 +157,10 @@ class ExpressionLearner:
|
|||
bool: 是否成功触发学习
|
||||
"""
|
||||
if not self.should_trigger_learning():
|
||||
return False
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
||||
|
||||
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
|
||||
# 学习语言风格
|
||||
learnt_style = await self.learn_and_store(num=25)
|
||||
|
||||
|
|
@ -172,15 +169,13 @@ class ExpressionLearner:
|
|||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
return
|
||||
|
||||
|
||||
|
||||
|
|
@ -188,127 +183,87 @@ class ExpressionLearner:
|
|||
"""
|
||||
学习并存储表达方式
|
||||
"""
|
||||
res = await self.learn_expression(num)
|
||||
learnt_expressions = await self.learn_expression(num)
|
||||
|
||||
if res is None:
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
learnt_expressions = res
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
_chat_id,
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
_up_content,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
has_new_expressions = False # 记录是否有新的表达方式
|
||||
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
|
||||
|
||||
for (
|
||||
chat_id,
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
up_content,
|
||||
) in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append(
|
||||
{
|
||||
"situation": situation,
|
||||
"style": style,
|
||||
"context": context,
|
||||
"up_content": up_content,
|
||||
}
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == self.chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
trained_chat_ids = set() # 记录已训练的聊天室
|
||||
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == "style")
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
if query.exists():
|
||||
# 表达方式完全相同,只更新时间戳
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
continue
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.context = new_expr["context"]
|
||||
expr_obj.up_content = new_expr["up_content"]
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=new_expr["context"],
|
||||
up_content=new_expr["up_content"],
|
||||
)
|
||||
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
# 获取 learner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
# 先添加风格和对应的 situation(如果存在)
|
||||
if new_expr.get("situation"):
|
||||
learner.add_style(new_expr["style"], new_expr["situation"])
|
||||
else:
|
||||
learner.add_style(new_expr["style"])
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
chat_id,
|
||||
new_expr["up_content"],
|
||||
new_expr["style"]
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" +
|
||||
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
|
||||
trained_chat_ids.add(chat_id)
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {chat_id} - {e}")
|
||||
has_new_expressions = True
|
||||
|
||||
# 限制最大数量
|
||||
# exprs = list(
|
||||
# Expression.select()
|
||||
# .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
# .order_by(Expression.last_active_time.asc())
|
||||
# )
|
||||
# if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除最久未活跃的多余表达方式
|
||||
# for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
# expr.delete_instance()
|
||||
|
||||
# 保存训练好的 style_learner 模型
|
||||
if trained_chat_ids:
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
|
||||
save_success = style_learner_manager.save_all_models()
|
||||
learner.add_style(style, situation)
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
self.chat_id,
|
||||
up_content,
|
||||
style
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
||||
|
||||
|
||||
# 保存当前聊天室的 style_learner 模型
|
||||
if has_new_expressions:
|
||||
try:
|
||||
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
|
||||
save_success = learner.save(style_learner_manager.model_save_path)
|
||||
|
||||
if save_success:
|
||||
logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
|
||||
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
|
||||
else:
|
||||
logger.warning("StyleLearner 模型保存失败")
|
||||
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner 模型保存异常: {e}")
|
||||
|
|
@ -415,15 +370,12 @@ class ExpressionLearner:
|
|||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, List[str], str]]]:
|
||||
) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
"""
|
||||
type_str = "语言风格"
|
||||
prompt = "learn_style_prompt"
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习之后的消息
|
||||
|
|
@ -436,14 +388,14 @@ class ExpressionLearner:
|
|||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
# 转化成str
|
||||
_chat_id: str = random_msg[0].chat_id
|
||||
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
|
||||
|
||||
# 学习用
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# 溯源用
|
||||
random_msg_match_str: str = await build_bare_messages(random_msg)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
"learn_style_prompt",
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
|
|
@ -453,20 +405,18 @@ class ExpressionLearner:
|
|||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习{type_str}失败: {e}")
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
|
||||
print(f"matched_expressions: {matched_expressions}")
|
||||
|
||||
# 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射
|
||||
# 这里有待斟酌,需要进一步处理图片和表情包
|
||||
bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content)
|
||||
pic_pattern = r"\[picid:[^\]]+\]"
|
||||
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
|
||||
|
|
@ -479,7 +429,6 @@ class ExpressionLearner:
|
|||
content = content.strip()
|
||||
if content:
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
# 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过)
|
||||
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
|
||||
for situation, style, context in matched_expressions:
|
||||
|
|
@ -503,11 +452,7 @@ class ExpressionLearner:
|
|||
if not filtered_with_up:
|
||||
return None
|
||||
|
||||
results: List[Tuple[str, str, str, str]] = []
|
||||
for (situation, style, context, up_content) in filtered_with_up:
|
||||
results.append((self.chat_id, situation, style, context, up_content))
|
||||
|
||||
return results
|
||||
return filtered_with_up
|
||||
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
|
|
|
|||
|
|
@ -136,13 +136,13 @@ class ExpressionSelector:
|
|||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
||||
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
使用 style_learner 模型预测最合适的表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
chat_info: 聊天内容信息
|
||||
target_message: 目标消息内容
|
||||
total_num: 需要预测的数量
|
||||
|
||||
Returns:
|
||||
|
|
@ -152,10 +152,7 @@ class ExpressionSelector:
|
|||
# 支持多chat_id合并预测
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 从聊天信息中提取关键内容作为预测输入
|
||||
# 这里可以进一步优化,提取更合适的预测输入
|
||||
prediction_input = self._extract_prediction_input(chat_info)
|
||||
|
||||
|
||||
predicted_expressions = []
|
||||
|
||||
# 为每个相关的chat_id进行预测
|
||||
|
|
@ -163,7 +160,7 @@ class ExpressionSelector:
|
|||
try:
|
||||
# 使用 style_learner 预测最合适的风格
|
||||
best_style, scores = style_learner_manager.predict_style(
|
||||
related_chat_id, prediction_input, top_k=total_num
|
||||
related_chat_id, target_message, top_k=total_num
|
||||
)
|
||||
|
||||
if best_style and scores:
|
||||
|
|
@ -175,7 +172,6 @@ class ExpressionSelector:
|
|||
# 从数据库查找对应的表达记录
|
||||
expr_query = Expression.select().where(
|
||||
(Expression.chat_id == related_chat_id) &
|
||||
(Expression.type == "style") &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == best_style)
|
||||
)
|
||||
|
|
@ -188,11 +184,12 @@ class ExpressionSelector:
|
|||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"prediction_score": scores.get(best_style, 0.0),
|
||||
"prediction_input": prediction_input
|
||||
"prediction_input": target_message
|
||||
})
|
||||
else:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
|
||||
|
|
@ -208,39 +205,11 @@ class ExpressionSelector:
|
|||
except Exception as e:
|
||||
logger.error(f"模型预测表达方式失败: {e}")
|
||||
# 如果预测失败,回退到随机选择
|
||||
return self._fallback_random_expressions(chat_id, total_num)
|
||||
return self._random_expressions(chat_id, total_num)
|
||||
|
||||
def _extract_prediction_input(self, chat_info: str) -> str:
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从聊天信息中提取用于预测的关键内容
|
||||
|
||||
Args:
|
||||
chat_info: 聊天内容信息
|
||||
|
||||
Returns:
|
||||
str: 提取的预测输入
|
||||
"""
|
||||
try:
|
||||
# 简单的提取策略:取最后几句话作为预测输入
|
||||
lines = chat_info.strip().split('\n')
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
# 取最后3行作为预测输入
|
||||
recent_lines = lines[-1:]
|
||||
prediction_input = ' '.join(recent_lines).strip()
|
||||
|
||||
logger.info(f"提取预测输入: {prediction_input}")
|
||||
|
||||
return prediction_input
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"提取预测输入失败: {e}")
|
||||
return ""
|
||||
|
||||
def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
回退到随机选择表达方式
|
||||
随机选择表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
|
|
@ -255,7 +224,7 @@ class ExpressionSelector:
|
|||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
(Expression.chat_id.in_(related_chat_ids))
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
|
|
@ -265,7 +234,6 @@ class ExpressionSelector:
|
|||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
|
|
@ -277,7 +245,7 @@ class ExpressionSelector:
|
|||
else:
|
||||
selected_style = []
|
||||
|
||||
logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
return selected_style
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -315,7 +283,7 @@ class ExpressionSelector:
|
|||
if expression_mode == "exp_model":
|
||||
# exp_model模式:直接使用模型预测,不经过LLM
|
||||
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_model_only(chat_id, chat_info, max_num)
|
||||
return await self._select_expressions_model_only(chat_id, target_message, max_num)
|
||||
elif expression_mode == "classic":
|
||||
# classic模式:随机选择+LLM选择
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
|
|
@ -327,7 +295,7 @@ class ExpressionSelector:
|
|||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
target_message: str,
|
||||
max_num: int = 10,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
|
|
@ -335,7 +303,7 @@ class ExpressionSelector:
|
|||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
target_message: 目标消息内容
|
||||
max_num: 最大选择数量
|
||||
|
||||
Returns:
|
||||
|
|
@ -343,11 +311,7 @@ class ExpressionSelector:
|
|||
"""
|
||||
try:
|
||||
# 使用模型预测最合适的表达方式
|
||||
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2)
|
||||
|
||||
|
||||
# 直接取前max_num个预测结果
|
||||
selected_expressions = style_exprs[:max_num]
|
||||
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
|
||||
selected_ids = [expr["id"] for expr in selected_expressions]
|
||||
|
||||
# 更新last_active_time
|
||||
|
|
@ -381,8 +345,8 @@ class ExpressionSelector:
|
|||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 1. 使用模型预测最合适的表达方式
|
||||
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20)
|
||||
# 1. 使用随机抽样选择表达方式
|
||||
style_exprs = self._random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
|
|
@ -460,21 +424,6 @@ class ExpressionSelector:
|
|||
logger.error(f"classic模式处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
使用LLM选择适合的表达方式(保持向后兼容)
|
||||
|
||||
注意:此方法已被 select_suitable_expressions 替代,建议使用新方法
|
||||
"""
|
||||
logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions")
|
||||
return await self.select_suitable_expressions(chat_id, chat_info, max_num, target_message)
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
|
||||
"""对一批表达方式更新last_active_time"""
|
||||
if not expressions_to_update:
|
||||
|
|
@ -482,19 +431,17 @@ class ExpressionSelector:
|
|||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
expr_type: str = expr.get("type", "style")
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, expr_type, situation, style)
|
||||
key = (source_id, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -58,8 +58,18 @@ class ExpressorModel:
|
|||
# 取最高分
|
||||
if not scores:
|
||||
return None, {}
|
||||
best = max(scores.items(), key=lambda x: x[1])[0]
|
||||
return best, scores
|
||||
|
||||
# 根据k参数限制返回的候选数量
|
||||
if k is not None and k > 0:
|
||||
# 按分数降序排序,取前k个
|
||||
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
limited_scores = dict(sorted_scores[:k])
|
||||
best = sorted_scores[0][0] if sorted_scores else None
|
||||
return best, limited_scores
|
||||
else:
|
||||
# 如果没有指定k,返回所有分数
|
||||
best = max(scores.items(), key=lambda x: x[1])[0]
|
||||
return best, scores
|
||||
|
||||
def update_positive(self, text: str, cid: str):
|
||||
"""更新正反馈学习"""
|
||||
|
|
|
|||
|
|
@ -60,13 +60,8 @@ class QuestionMaker:
|
|||
|
||||
# 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个
|
||||
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
|
||||
if not conflicts_with_zero:
|
||||
if random.random() >= 0.01:
|
||||
return None
|
||||
# 以均匀概率选择一个(此时权重都等同于 0.05,无需再按权重)
|
||||
chosen_conflict = random.choice(conflicts)
|
||||
else:
|
||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.05
|
||||
if conflicts_with_zero:
|
||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.01
|
||||
weights = []
|
||||
for conflict in conflicts:
|
||||
current_raise_time = getattr(conflict, "raise_time", 0) or 0
|
||||
|
|
@ -77,12 +72,9 @@ class QuestionMaker:
|
|||
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
|
||||
|
||||
# 选中后,自增 raise_time 并保存
|
||||
try:
|
||||
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
|
||||
chosen_conflict.save()
|
||||
except Exception:
|
||||
# 静默失败不影响流程
|
||||
pass
|
||||
|
||||
|
||||
return chosen_conflict
|
||||
|
||||
|
|
|
|||
|
|
@ -288,7 +288,7 @@ class ConflictTracker:
|
|||
if existing_conflict:
|
||||
# 检查raise_time是否大于3且没有答案
|
||||
current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0
|
||||
if current_raise_time > 1 and not existing_conflict.answer:
|
||||
if current_raise_time > 0 and not existing_conflict.answer:
|
||||
# 删除该条目
|
||||
await self.delete_conflict(original_question, tracker.chat_id)
|
||||
logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}")
|
||||
|
|
|
|||
|
|
@ -1,188 +0,0 @@
|
|||
"""
|
||||
测试修改后的 expression_learner 与 style_learner 的集成
|
||||
验证学习新表达时是否正确处理 situation 字段
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.express.expression_learner import ExpressionLearner
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expression_style_integration_test")
|
||||
|
||||
|
||||
async def test_expression_style_integration():
|
||||
"""测试 expression_learner 与 style_learner 的集成(包含 situation)"""
|
||||
print("=== Expression Learner 与 Style Learner 集成测试(含 Situation) ===\n")
|
||||
|
||||
# 创建测试聊天室ID
|
||||
test_chat_id = "test_integration_situation_chat"
|
||||
|
||||
# 创建 ExpressionLearner 实例
|
||||
expression_learner = ExpressionLearner(test_chat_id)
|
||||
|
||||
print(f"测试聊天室: {test_chat_id}")
|
||||
|
||||
# 模拟学习到的表达数据(包含 situation)
|
||||
mock_learnt_expressions = [
|
||||
(test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"),
|
||||
(test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"),
|
||||
(test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"),
|
||||
(test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"),
|
||||
(test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"),
|
||||
]
|
||||
|
||||
print("模拟学习到的表达数据(包含 situation):")
|
||||
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
|
||||
print(f" {situation} -> {style} (输入: {up_content})")
|
||||
|
||||
# 模拟 learn_and_store 方法的处理逻辑
|
||||
print(f"\n开始处理学习数据...")
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict = {}
|
||||
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append({
|
||||
"situation": situation,
|
||||
"style": style,
|
||||
"context": context,
|
||||
"up_content": up_content,
|
||||
})
|
||||
|
||||
# 训练 style_learner(包含 situation 处理)
|
||||
trained_chat_ids = set()
|
||||
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
print(f"\n处理聊天室: {chat_id}")
|
||||
|
||||
for new_expr in expr_list:
|
||||
# 训练 style_learner(包含 situation)
|
||||
if new_expr.get("up_content") and new_expr.get("style"):
|
||||
try:
|
||||
# 获取 learner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
# 先添加风格和对应的 situation(如果不存在)
|
||||
if new_expr.get("situation"):
|
||||
learner.add_style(new_expr["style"], new_expr["situation"])
|
||||
print(f" ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')")
|
||||
else:
|
||||
learner.add_style(new_expr["style"])
|
||||
print(f" ✓ 添加风格: '{new_expr['style']}' (无 situation)")
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
chat_id,
|
||||
new_expr["up_content"],
|
||||
new_expr["style"]
|
||||
)
|
||||
if success:
|
||||
print(f" ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" +
|
||||
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
|
||||
trained_chat_ids.add(chat_id)
|
||||
else:
|
||||
print(f" ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}")
|
||||
except Exception as e:
|
||||
print(f" ✗ StyleLearner学习异常: {e}")
|
||||
|
||||
# 保存模型
|
||||
if trained_chat_ids:
|
||||
print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
|
||||
try:
|
||||
save_success = style_learner_manager.save_all_models()
|
||||
|
||||
if save_success:
|
||||
print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
|
||||
else:
|
||||
print("✗ StyleLearner 模型保存失败")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ StyleLearner 模型保存异常: {e}")
|
||||
|
||||
# 测试预测功能
|
||||
print(f"\n测试 StyleLearner 预测功能:")
|
||||
test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"]
|
||||
|
||||
for test_input in test_inputs:
|
||||
try:
|
||||
best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3)
|
||||
if best_style:
|
||||
# 获取对应的 situation
|
||||
learner = style_learner_manager.get_learner(test_chat_id)
|
||||
situation = learner.get_situation(best_style)
|
||||
print(f" 输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')")
|
||||
if scores:
|
||||
top_scores = dict(list(scores.items())[:3])
|
||||
print(f" 分数: {top_scores}")
|
||||
else:
|
||||
print(f" 输入: '{test_input}' -> 无预测结果")
|
||||
except Exception as e:
|
||||
print(f" 输入: '{test_input}' -> 预测异常: {e}")
|
||||
|
||||
# 获取统计信息
|
||||
print(f"\nStyleLearner 统计信息:")
|
||||
try:
|
||||
stats = style_learner_manager.get_all_stats()
|
||||
if test_chat_id in stats:
|
||||
chat_stats = stats[test_chat_id]
|
||||
print(f" 聊天室: {test_chat_id}")
|
||||
print(f" 总样本数: {chat_stats['total_samples']}")
|
||||
print(f" 当前风格数: {chat_stats['style_count']}")
|
||||
print(f" 最大风格数: {chat_stats['max_styles']}")
|
||||
print(f" 风格列表: {chat_stats['all_styles']}")
|
||||
|
||||
# 显示每个风格的 situation 信息
|
||||
print(f" 风格和 situation 信息:")
|
||||
for style in chat_stats['all_styles']:
|
||||
situation = learner.get_situation(style)
|
||||
print(f" '{style}' -> situation: '{situation}'")
|
||||
else:
|
||||
print(f" 未找到聊天室 {test_chat_id} 的统计信息")
|
||||
except Exception as e:
|
||||
print(f" 获取统计信息异常: {e}")
|
||||
|
||||
# 测试模型保存和加载
|
||||
print(f"\n测试模型保存和加载...")
|
||||
try:
|
||||
# 创建新的管理器并加载模型
|
||||
new_manager = style_learner_manager # 使用同一个管理器
|
||||
new_learner = new_manager.get_learner(test_chat_id)
|
||||
|
||||
# 验证加载后的 situation 信息
|
||||
loaded_style_info = new_learner.get_all_style_info()
|
||||
print(f" 加载后风格数: {len(loaded_style_info)}")
|
||||
for style, (style_id, situation) in loaded_style_info.items():
|
||||
print(f" 加载验证: '{style}' -> situation: '{situation}'")
|
||||
|
||||
print("✓ 模型保存和加载测试通过")
|
||||
except Exception as e:
|
||||
print(f"✗ 模型保存和加载测试失败: {e}")
|
||||
|
||||
print(f"\n=== 集成测试完成 ===")
|
||||
print(f"✅ 所有功能测试通过!")
|
||||
print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner")
|
||||
print(f"✓ StyleLearner 正确存储和获取 situation 信息")
|
||||
print(f"✓ 预测功能正常工作,可以获取对应的 situation")
|
||||
print(f"✓ 模型保存和加载支持 situation 字段")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("Expression Learner 与 Style Learner 集成测试(含 Situation)")
|
||||
print("=" * 70)
|
||||
|
||||
# 运行异步测试
|
||||
asyncio.run(test_expression_style_integration())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue