mirror of https://github.com/Mai-with-u/MaiBot.git
622 lines
20 KiB
Python
622 lines
20 KiB
Python
"""
|
||
多聊天室表达风格学习系统
|
||
支持为每个chat_id维护独立的表达模型,学习从up_content到style的映射
|
||
"""
|
||
|
||
import asyncio
import os
import pickle
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

from src.common.logger import get_logger

from .expressor_model.model import ExpressorModel
|
||
|
||
# Module-level logger used by every class in this file; registered under the name "style_learner".
logger = get_logger("style_learner")
|
||
|
||
|
||
class StyleLearner:
    """
    Expression-style learner for a single chat room.

    Learns the mapping from ``up_content`` to a style and keeps the set of
    known styles dynamic (no upper bound is enforced here). Each style text is
    assigned an id of the form ``style_<n>``; the id is what gets passed to
    the underlying ``ExpressorModel``.
    """

    def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
        """
        Create a learner for ``chat_id``.

        Args:
            chat_id: Identifier of the chat room; used in log lines and in
                the file names written by ``save``.
            model_config: Keyword arguments forwarded to ``ExpressorModel``.
                When falsy, the defaults below are used.
        """
        self.chat_id = chat_id
        self.model_config = model_config or {
            "alpha": 0.5,
            "beta": 0.5,
            "gamma": 0.99,  # decay factor, enables forgetting
            "vocab_size": 200000,
            "use_jieba": True,
        }

        # Underlying expression model.
        self.expressor = ExpressorModel(**self.model_config)

        # Dynamic style registry.
        self.style_to_id: Dict[str, str] = {}  # style text -> style_id
        self.id_to_style: Dict[str, str] = {}  # style_id -> style text
        self.id_to_situation: Dict[str, str] = {}  # style_id -> situation text
        self.next_style_id = 0  # numeric part of the next style_id to hand out

        # Learning statistics.
        self.learning_stats = {
            "total_samples": 0,
            "style_counts": defaultdict(int),  # keyed by style_id
            "last_update": None,
            "style_usage_frequency": defaultdict(int),  # keyed by style text
        }

    def add_style(self, style: str, situation: Optional[str] = None) -> bool:
        """
        Register a new style.

        Args:
            style: Style text.
            situation: Situation text associated with the style (optional).

        Returns:
            bool: True if the style was added or already existed, False if
            registering it raised an exception.
        """
        try:
            if style in self.style_to_id:
                logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
                return True

            style_id = f"style_{self.next_style_id}"

            # Register with the model first: if this raises, none of the
            # bookkeeping below has been touched, so the registry never
            # claims a style the model was not given.
            self.expressor.add_candidate(style_id, style, situation)

            self.next_style_id += 1
            self.style_to_id[style] = style_id
            self.id_to_style[style_id] = style
            if situation:
                self.id_to_situation[style_id] = situation

            logger.info(
                f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})"
                + (f", situation: '{situation}'" if situation else "")
            )
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
            return False

    def remove_style(self, style: str) -> bool:
        """
        Remove a style.

        Args:
            style: Style text to remove.

        Returns:
            bool: True if the style was removed, False if it was unknown or
            an exception occurred.
        """
        try:
            if style not in self.style_to_id:
                logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
                return False

            style_id = self.style_to_id[style]

            del self.style_to_id[style]
            del self.id_to_style[style_id]
            self.id_to_situation.pop(style_id, None)

            # The model is replaced by a fresh instance that is re-populated
            # from the remaining registry entries.
            # NOTE(review): anything the previous model instance had learned
            # that add_candidate does not restore is presumably discarded
            # here - confirm against ExpressorModel.
            self._rebuild_expressor()

            logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
            return False

    def update_style(self, old_style: str, new_style: str) -> bool:
        """
        Replace the text of an existing style, keeping its id and situation.

        Args:
            old_style: Current style text.
            new_style: New style text.

        Returns:
            bool: True on success; False if ``old_style`` is unknown,
            ``new_style`` already belongs to a different style, or an
            exception occurred.
        """
        try:
            if old_style not in self.style_to_id:
                logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
                return False

            if new_style in self.style_to_id and new_style != old_style:
                logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
                return False

            style_id = self.style_to_id[old_style]
            situation = self.id_to_situation.get(style_id)

            # Hand the new text to the model under the same id before
            # touching the registry, so a failure leaves the mappings intact.
            self.expressor.add_candidate(style_id, new_style, situation)

            del self.style_to_id[old_style]
            self.style_to_id[new_style] = style_id
            self.id_to_style[style_id] = new_style

            logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
            return False

    def add_styles_batch(self, styles: List[str], situations: Optional[List[str]] = None) -> int:
        """
        Register several styles.

        Args:
            styles: Style texts.
            situations: Situation texts aligned by index with ``styles``
                (optional). Styles beyond the end of this list get no
                situation.

        Returns:
            int: Number of styles for which ``add_style`` returned True
            (already-existing styles count as successes).
        """
        success_count = 0
        for i, style in enumerate(styles):
            situation = situations[i] if situations and i < len(situations) else None
            if self.add_style(style, situation):
                success_count += 1

        logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
        return success_count

    def get_all_styles(self) -> List[str]:
        """Return the texts of all registered styles."""
        return list(self.style_to_id)

    def get_style_count(self) -> int:
        """Return the number of registered styles."""
        return len(self.style_to_id)

    def get_situation(self, style: str) -> Optional[str]:
        """
        Return the situation recorded for a style.

        Args:
            style: Style text.

        Returns:
            Optional[str]: The situation, or None if the style is unknown or
            has no situation.
        """
        style_id = self.style_to_id.get(style)
        if style_id is None:
            return None
        return self.id_to_situation.get(style_id)

    def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Return the id and situation of a style.

        Args:
            style: Style text.

        Returns:
            Tuple[Optional[str], Optional[str]]: ``(style_id, situation)``,
            or ``(None, None)`` if the style is unknown.
        """
        style_id = self.style_to_id.get(style)
        if style_id is None:
            return None, None
        return style_id, self.id_to_situation.get(style_id)

    def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """
        Return id and situation for every registered style.

        Returns:
            Dict[str, Tuple[str, Optional[str]]]: ``{style: (style_id, situation)}``.
        """
        return {
            style: (style_id, self.id_to_situation.get(style_id))
            for style, style_id in self.style_to_id.items()
        }

    def _rebuild_expressor(self):
        """
        Replace the model with a fresh one built from the current registry.

        Used after a style has been removed. Errors are logged and swallowed.
        """
        try:
            self.expressor = ExpressorModel(**self.model_config)

            for style_id, style_text in self.id_to_style.items():
                situation = self.id_to_situation.get(style_id)
                self.expressor.add_candidate(style_id, style_text, situation)

            logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")

        except Exception as e:
            logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")

    def _record_sample(self, style: str, style_id: str) -> None:
        """Update the learning statistics for one learned sample."""
        stats = self.learning_stats
        stats["total_samples"] += 1
        stats["style_counts"][style_id] += 1
        stats["style_usage_frequency"][style] += 1
        # Wall-clock epoch seconds: works without an asyncio event loop and
        # stays meaningful after the stats are pickled and reloaded in
        # another process (an event-loop clock does neither).
        stats["last_update"] = time.time()

    def learn_mapping(self, up_content: str, style: str) -> bool:
        """
        Learn one ``up_content`` -> ``style`` mapping.

        An unknown style is registered automatically (without a situation).

        Args:
            up_content: Input content.
            style: Style text the content maps to.

        Returns:
            bool: True if the sample was learned, False otherwise.
        """
        try:
            if style not in self.style_to_id:
                if not self.add_style(style):
                    logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
                    return False

            style_id = self.style_to_id[style]

            # Positive-feedback update on the model.
            self.expressor.update_positive(up_content, style_id)

            self._record_sample(style, style_id)

            logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
            traceback.print_exc()
            return False

    def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the most suitable style for ``up_content``.

        Args:
            up_content: Input content.
            top_k: Passed to the model as ``k`` (number of candidates).

        Returns:
            Tuple of (best style text, ``{style text: score}``). Returns
            ``(None, {})`` when the model yields no best id or an exception
            occurs. Scores whose id has no (non-empty) style text in the
            registry are dropped.
        """
        try:
            best_style_id, scores = self.expressor.predict(up_content, k=top_k)

            if best_style_id is None:
                return None, {}

            best_style = self.id_to_style.get(best_style_id)

            # Translate model ids back to style texts.
            style_scores = {}
            for sid, score in scores.items():
                style_text = self.id_to_style.get(sid)
                if style_text:
                    style_scores[style_text] = score

            return best_style, style_scores

        except Exception as e:
            logger.error(f"[{self.chat_id}] 预测style失败: {e}")
            traceback.print_exc()
            return None, {}

    def decay_learning(self, factor: Optional[float] = None) -> None:
        """
        Decay (forget) learned knowledge.

        Args:
            factor: Decay factor, forwarded unchanged to the model's
                ``decay``. For None the model presumably falls back to its
                configured gamma - confirm against ExpressorModel.
        """
        self.expressor.decay(factor)
        logger.debug(f"[{self.chat_id}] 执行知识衰减")

    def get_stats(self) -> Dict:
        """Return a snapshot of the learning statistics as plain dicts/lists."""
        return {
            "chat_id": self.chat_id,
            "total_samples": self.learning_stats["total_samples"],
            "style_count": len(self.style_to_id),
            "style_counts": dict(self.learning_stats["style_counts"]),
            "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
            "last_update": self.learning_stats["last_update"],
            "all_styles": list(self.style_to_id),
        }

    def save(self, base_path: str) -> bool:
        """
        Persist the learner to disk.

        Writes two files under ``base_path``:
        ``{chat_id}_expressor.pkl`` (via the model's own ``save``) and
        ``{chat_id}_style_model.pkl`` (registry, config and statistics).

        Args:
            base_path: Directory to write into; created if missing.

        Returns:
            bool: True on success, False if an exception occurred.
        """
        try:
            os.makedirs(base_path, exist_ok=True)
            file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")

            save_data = {
                "model_config": self.model_config,
                "style_to_id": self.style_to_id,
                "id_to_style": self.id_to_style,
                "id_to_situation": self.id_to_situation,
                "next_style_id": self.next_style_id,
                "learning_stats": self.learning_stats,
            }

            # The model is saved first, then the registry/statistics.
            expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
            self.expressor.save(expressor_path)

            with open(file_path, "wb") as f:
                pickle.dump(save_data, f)

            logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
            return False

    def load(self, base_path: str) -> bool:
        """
        Restore the learner from the files written by ``save``.

        The instance is only modified once everything has been read
        successfully; on failure it keeps the state it had before the call.

        Args:
            base_path: Directory containing the saved files.

        Returns:
            bool: True if loaded, False if a file is missing or loading
            raised an exception.
        """
        try:
            file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
            expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")

            if not os.path.exists(file_path) or not os.path.exists(expressor_path):
                logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
                return False

            # SECURITY: pickle.load can execute arbitrary code from the file.
            # Only point base_path at directories this application wrote.
            with open(file_path, "rb") as f:
                save_data = pickle.load(f)

            # Read everything into locals first so that a missing key or a
            # failing model load cannot leave this instance half-restored.
            model_config = save_data["model_config"]
            style_to_id = save_data["style_to_id"]
            id_to_style = save_data["id_to_style"]
            id_to_situation = save_data.get("id_to_situation", {})  # absent in older saves
            next_style_id = save_data["next_style_id"]
            learning_stats = save_data["learning_stats"]

            expressor = ExpressorModel(**model_config)
            expressor.load(expressor_path)

            # Commit.
            self.model_config = model_config
            self.style_to_id = style_to_id
            self.id_to_style = id_to_style
            self.id_to_situation = id_to_situation
            self.next_style_id = next_style_id
            self.learning_stats = learning_stats
            self.expressor = expressor

            logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
            return False
|
||
|
||
|
||
class StyleLearnerManager:
    """
    Manager for per-chat expression-style learners.

    Keeps one ``StyleLearner`` per chat_id; each chat manages its own set of
    styles (no upper bound is enforced here). Learners are created lazily by
    ``get_learner`` and can be saved periodically by the auto-save task.
    """

    # Suffix of the registry/statistics file written by StyleLearner.save.
    _MODEL_FILE_SUFFIX = "_style_model.pkl"

    def __init__(self, model_save_path: str = "data/style_models"):
        """
        Args:
            model_save_path: Directory learners are saved to / loaded from.
        """
        self.model_save_path = model_save_path
        self.learners: Dict[str, "StyleLearner"] = {}

        # Auto-save settings.
        self.auto_save_interval = 300  # seconds (5 minutes)
        self._auto_save_task: Optional[asyncio.Task] = None

        logger.info("StyleLearnerManager 已初始化")

    @staticmethod
    def _chat_id_from_filename(filename: str) -> Optional[str]:
        """Return the chat_id encoded in a model file name, or None if it is not a model file."""
        suffix = StyleLearnerManager._MODEL_FILE_SUFFIX
        if not filename.endswith(suffix):
            return None
        # Strip only the trailing suffix; str.replace would also remove the
        # same text if it happened to occur inside the chat_id.
        return filename[: -len(suffix)]

    def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> "StyleLearner":
        """
        Return the learner for ``chat_id``, creating it on first use.

        A newly created learner attempts to load previously saved state from
        ``model_save_path``; it is kept and returned whether or not that
        load succeeds.

        Args:
            chat_id: Chat room id.
            model_config: Model configuration for a newly created learner;
                None selects the learner's defaults. Ignored when the
                learner already exists.

        Returns:
            The StyleLearner instance.
        """
        if chat_id not in self.learners:
            learner = StyleLearner(chat_id, model_config)

            # Best effort: the return value is intentionally not checked.
            learner.load(self.model_save_path)

            self.learners[chat_id] = learner
            logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")

        return self.learners[chat_id]

    def add_style(self, chat_id: str, style: str, situation: Optional[str] = None) -> bool:
        """
        Register a style for a chat.

        Args:
            chat_id: Chat room id.
            style: Style text.
            situation: Situation text associated with the style (optional).

        Returns:
            bool: Result of ``StyleLearner.add_style``.
        """
        return self.get_learner(chat_id).add_style(style, situation)

    def remove_style(self, chat_id: str, style: str) -> bool:
        """
        Remove a style from a chat.

        Args:
            chat_id: Chat room id.
            style: Style text.

        Returns:
            bool: Result of ``StyleLearner.remove_style``.
        """
        return self.get_learner(chat_id).remove_style(style)

    def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
        """
        Change the text of a style in a chat.

        Args:
            chat_id: Chat room id.
            old_style: Current style text.
            new_style: New style text.

        Returns:
            bool: Result of ``StyleLearner.update_style``.
        """
        return self.get_learner(chat_id).update_style(old_style, new_style)

    def get_chat_styles(self, chat_id: str) -> List[str]:
        """
        Return all styles registered for a chat.

        Args:
            chat_id: Chat room id.

        Returns:
            List[str]: Style texts.
        """
        return self.get_learner(chat_id).get_all_styles()

    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """
        Learn one ``up_content`` -> ``style`` mapping for a chat.

        Args:
            chat_id: Chat room id.
            up_content: Input content.
            style: Style text the content maps to.

        Returns:
            bool: Result of ``StyleLearner.learn_mapping``.
        """
        return self.get_learner(chat_id).learn_mapping(up_content, style)

    def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the most suitable style for ``up_content`` in a chat.

        Args:
            chat_id: Chat room id.
            up_content: Input content.
            top_k: Number of candidates requested from the learner.

        Returns:
            Tuple of (best style text, ``{style text: score}``).
        """
        return self.get_learner(chat_id).predict_style(up_content, top_k)

    def decay_all_learners(self, factor: Optional[float] = None) -> None:
        """
        Apply decay to every learner currently held in memory.

        Args:
            factor: Decay factor, forwarded to each learner.
        """
        for learner in self.learners.values():
            learner.decay_learning(factor)
        logger.info("已对所有学习器执行衰减")

    def get_all_stats(self) -> Dict[str, Dict]:
        """Return ``{chat_id: stats}`` for every learner currently held in memory."""
        return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}

    def save_all_models(self) -> bool:
        """
        Save every learner currently held in memory.

        Returns:
            bool: True only if every learner saved successfully.
        """
        success_count = sum(1 for learner in self.learners.values() if learner.save(self.model_save_path))

        logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
        return success_count == len(self.learners)

    def load_all_models(self) -> int:
        """
        Load every saved model found in ``model_save_path``.

        A learner that loads successfully replaces any in-memory learner
        with the same chat_id.

        Returns:
            int: Number of learners loaded (0 if the directory is missing).
        """
        if not os.path.exists(self.model_save_path):
            return 0

        loaded_count = 0
        for filename in os.listdir(self.model_save_path):
            chat_id = self._chat_id_from_filename(filename)
            if chat_id is None:
                continue

            learner = StyleLearner(chat_id)
            if learner.load(self.model_save_path):
                self.learners[chat_id] = learner
                loaded_count += 1

        logger.info(f"已加载 {loaded_count} 个模型")
        return loaded_count

    async def start_auto_save(self) -> None:
        """Start the periodic auto-save task unless one is already running."""
        if self._auto_save_task is None or self._auto_save_task.done():
            self._auto_save_task = asyncio.create_task(self._auto_save_loop())
            logger.info("已启动自动保存任务")

    async def stop_auto_save(self) -> None:
        """Cancel the auto-save task, if running, and wait for it to finish."""
        if self._auto_save_task and not self._auto_save_task.done():
            self._auto_save_task.cancel()
            try:
                await self._auto_save_task
            except asyncio.CancelledError:
                pass
            logger.info("已停止自动保存任务")

    async def _auto_save_loop(self) -> None:
        """Sleep ``auto_save_interval`` seconds, save all models, repeat until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.auto_save_interval)
                # save_all_models is synchronous and is called directly here,
                # i.e. it runs on the event loop between sleeps.
                self.save_all_models()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; the next interval retries.
                logger.error(f"自动保存失败: {e}")
||
|
||
# 全局管理器实例
|
||
# Module-level StyleLearnerManager instance, created at import time with the default save path.
style_learner_manager = StyleLearnerManager()
|