remove:移除Exp+model表达方式,一处无用代码,新增配置项

pull/1359/head
SengokuCola 2025-11-15 19:23:23 +08:00
parent d18d77cf4b
commit a2495c7834
13 changed files with 25 additions and 1295 deletions

View File

@ -244,7 +244,7 @@ class DefaultReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)

View File

@ -258,7 +258,7 @@ class PrivateReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)

View File

@ -88,12 +88,6 @@ class ChatConfig(ConfigBase):
mentioned_bot_reply: bool = True
"""是否启用提及必回复"""
auto_chat_value: float = 1
"""自动聊天,越小,麦麦主动聊天的概率越低"""
enable_auto_chat_value_rules: bool = True
"""是否启用动态自动聊天频率规则"""
at_bot_inevitable_reply: float = 1
"""@bot 必然回复1为100%回复0为不额外增幅"""
@ -119,24 +113,7 @@ class ChatConfig(ConfigBase):
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
]
匹配优先级: 先匹配指定 chat 流规则再匹配全局规则(\"\").
时间区间支持跨夜例如 "23:00-02:00"
"""
auto_chat_value_rules: list[dict] = field(default_factory=lambda: [])
"""
自动聊天频率规则列表支持按聊天流/按日内时段配置
规则格式{ target="platform:id:type" "", time="HH:MM-HH:MM", value=0.5 }
示例:
[
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
["", "09:00-22:59", 1.0], # 全局规则:白天正常
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
]
匹配优先级: 先匹配指定 chat 流规则再匹配全局规则(\"\").
匹配优先级: 先匹配指定 chat 流规则再匹配全局规则(\"\").
时间区间支持跨夜例如 "23:00-02:00"
"""
@ -245,61 +222,6 @@ class ChatConfig(ConfigBase):
# 3) 未命中规则返回基础值
return self.talk_value
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 auto_chat_value未匹配则回退到基础值。"""
if not self.enable_auto_chat_value_rules or not self.auto_chat_value_rules:
return self.auto_chat_value
now_min = self._now_minutes()
# 1) 先尝试匹配指定 chat 的规则
if chat_id:
for rule in self.auto_chat_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", "")
time_range = rule.get("time", "")
value = rule.get("value", None)
if not isinstance(time_range, str):
continue
# 跳过全局
if target == "":
continue
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
if config_chat_id is None or config_chat_id != chat_id:
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 2) 再匹配全局规则("")
for rule in self.auto_chat_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", None)
time_range = rule.get("time", "")
value = rule.get("value", None)
if target != "" or not isinstance(time_range, str):
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 3) 未命中规则返回基础值
return self.auto_chat_value
@dataclass
class MessageReceiveConfig(ConfigBase):
@ -316,20 +238,19 @@ class MessageReceiveConfig(ConfigBase):
class MemoryConfig(ConfigBase):
"""记忆配置类"""
max_memory_number: int = 100
"""记忆最大数量"""
max_agent_iterations: int = 5
"""Agent最多迭代轮数最低为1"""
memory_build_frequency: int = 1
"""记忆构建频率"""
def __post_init__(self):
"""验证配置值"""
if self.max_agent_iterations < 1:
raise ValueError(f"max_agent_iterations 必须至少为1当前值: {self.max_agent_iterations}")
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
mode: str = "classic"
"""表达方式模式可选classic经典模式exp_model 表达模型模式"""
learning_list: list[list] = field(default_factory=lambda: [])
"""
表达学习配置列表支持按聊天流配置

View File

@ -14,7 +14,6 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.style_learner import style_learner_manager
from src.express.express_utils import filter_message_content, calculate_similarity
from json_repair import repair_json
@ -180,10 +179,7 @@ class ExpressionLearner:
current_time = time.time()
# 存储到数据库 Expression 表并训练 style_learner
has_new_expressions = False # 记录是否有新的表达方式
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
# 存储到数据库 Expression 表
for (
situation,
style,
@ -199,7 +195,6 @@ class ExpressionLearner:
expr_obj = query.get()
expr_obj.last_active_time = current_time
expr_obj.save()
continue
else:
Expression.create(
situation=situation,
@ -210,37 +205,6 @@ class ExpressionLearner:
context=context,
up_content=up_content,
)
has_new_expressions = True
# 训练 style_learnerup_content 和 style 必定存在)
try:
learner.add_style(style, situation)
# 学习映射关系
success = style_learner_manager.learn_mapping(self.chat_id, up_content, style)
if success:
logger.debug(
f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}"
+ (f" (situation: {situation})" if situation else "")
)
else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型
if has_new_expressions:
try:
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
save_success = learner.save(style_learner_manager.model_save_path)
if save_success:
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
else:
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
return learnt_expressions

View File

@ -10,8 +10,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.style_learner import style_learner_manager
from src.express.express_utils import filter_message_content, weighted_sample
from src.express.express_utils import weighted_sample
logger = get_logger("expression_selector")
@ -44,6 +43,8 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
@ -113,89 +114,6 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def get_model_predicted_expressions(
self, chat_id: str, target_message: str, total_num: int = 10
) -> List[Dict[str, Any]]:
"""
使用 style_learner 模型预测最合适的表达方式
Args:
chat_id: 聊天室ID
target_message: 目标消息内容
total_num: 需要预测的数量
Returns:
List[Dict[str, Any]]: 预测的表达方式列表
"""
try:
# 过滤目标消息内容,移除回复、表情包等特殊格式
filtered_target_message = filter_message_content(target_message)
logger.info(f"{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}")
# 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id)
predicted_expressions = []
# 为每个相关的chat_id进行预测
for related_chat_id in related_chat_ids:
try:
# 使用 style_learner 预测最合适的风格
best_style, scores = style_learner_manager.predict_style(
related_chat_id, filtered_target_message, top_k=total_num
)
if best_style and scores:
# 获取预测风格的完整信息
learner = style_learner_manager.get_learner(related_chat_id)
style_id, situation = learner.get_style_info(best_style)
if style_id and situation:
# 从数据库查找对应的表达记录
expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id)
& (Expression.situation == situation)
& (Expression.style == best_style)
)
if expr_query.exists():
expr = expr_query.get()
predicted_expressions.append(
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date
if expr.create_date is not None
else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": filtered_target_message,
}
)
else:
logger.warning(
f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式"
)
except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
continue
# 按预测分数排序,取前 total_num 个
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
selected_expressions = predicted_expressions[:total_num]
logger.info(f"{chat_id} 预测到 {len(selected_expressions)} 个表达方式")
return selected_expressions
except Exception as e:
logger.error(f"模型预测表达方式失败: {e}")
# 如果预测失败,回退到随机选择
return self._random_expressions(chat_id, total_num)
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
随机选择表达方式
@ -247,7 +165,7 @@ class ExpressionSelector:
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
根据配置模式选择适合的表达方式
选择适合的表达方式使用classic模式随机选择+LLM选择
Args:
chat_id: 聊天流ID
@ -263,53 +181,9 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 获取配置模式
expression_mode = global_config.expression.mode
if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_model_only(chat_id, target_message, max_num)
elif expression_mode == "classic":
# classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
else:
logger.warning(f"未知的表达模式: {expression_mode}回退到classic模式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
async def _select_expressions_model_only(
self,
chat_id: str,
target_message: str,
max_num: int = 10,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
exp_model模式直接使用模型预测不经过LLM
Args:
chat_id: 聊天流ID
target_message: 目标消息内容
max_num: 最大选择数量
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 使用模型预测最合适的表达方式
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
return selected_expressions, selected_ids
except Exception as e:
logger.error(f"exp_model模式选择表达方式失败: {e}")
return [], []
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
async def _select_expressions_classic(
self,

View File

@ -1,148 +0,0 @@
from typing import Dict, Optional, Tuple, List
from collections import Counter, defaultdict
import pickle
import os
from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes
class ExpressorModel:
"""
直接使用朴素贝叶斯精排可在线学习
支持存储situation字段不参与计算仅与style对应
"""
def __init__(
self,
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0,
vocab_size: int = 200000,
use_jieba: bool = True,
):
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
self._candidates: Dict[str, str] = {} # cid -> text (style)
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
def add_candidate(self, cid: str, text: str, situation: str = None):
"""添加候选文本和对应的situation"""
self._candidates[cid] = text
if situation is not None:
self._situations[cid] = situation
# 确保在nb模型中初始化该候选的计数
if cid not in self.nb.cls_counts:
self.nb.cls_counts[cid] = 0.0
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None):
"""批量添加候选文本和对应的situations"""
for i, (cid, text) in enumerate(items):
situation = situations[i] if situations and i < len(situations) else None
self.add_candidate(cid, text, situation)
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
"""直接对所有候选进行朴素贝叶斯评分"""
toks = self.tokenizer.tokenize(text)
if not toks:
return None, {}
if not self._candidates:
return None, {}
# 对所有候选进行评分
tf = Counter(toks)
all_cids = list(self._candidates.keys())
scores = self.nb.score_batch(tf, all_cids)
# 取最高分
if not scores:
return None, {}
# 根据k参数限制返回的候选数量
if k is not None and k > 0:
# 按分数降序排序取前k个
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
limited_scores = dict(sorted_scores[:k])
best = sorted_scores[0][0] if sorted_scores else None
return best, limited_scores
else:
# 如果没有指定k返回所有分数
best = max(scores.items(), key=lambda x: x[1])[0]
return best, scores
def update_positive(self, text: str, cid: str):
"""更新正反馈学习"""
toks = self.tokenizer.tokenize(text)
if not toks:
return
tf = Counter(toks)
self.nb.update_positive(tf, cid)
def decay(self, factor: float):
self.nb.decay(factor=factor)
def get_situation(self, cid: str) -> Optional[str]:
"""获取候选对应的situation"""
return self._situations.get(cid)
def get_style(self, cid: str) -> Optional[str]:
"""获取候选对应的style"""
return self._candidates.get(cid)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
"""获取候选的style和situation信息"""
return self._candidates.get(cid), self._situations.get(cid)
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""获取所有候选的style和situation信息"""
return {cid: (style, self._situations.get(cid)) for cid, style in self._candidates.items()}
def save(self, path: str):
"""保存模型"""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
pickle.dump(
{
"candidates": self._candidates,
"situations": self._situations,
"nb": {
"cls_counts": dict(self.nb.cls_counts),
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
"alpha": self.nb.alpha,
"beta": self.nb.beta,
"gamma": self.nb.gamma,
"V": self.nb.V,
},
},
f,
)
def load(self, path: str):
"""加载模型"""
with open(path, "rb") as f:
obj = pickle.load(f)
# 还原候选文本
self._candidates = obj["candidates"]
# 还原situations兼容旧版本
self._situations = obj.get("situations", {})
# 还原朴素贝叶斯模型
self.nb.cls_counts = obj["nb"]["cls_counts"]
self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
self.nb.alpha = obj["nb"]["alpha"]
self.nb.beta = obj["nb"]["beta"]
self.nb.gamma = obj["nb"]["gamma"]
self.nb.V = obj["nb"]["V"]
self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
from collections import defaultdict
outer = defaultdict(lambda: defaultdict(float))
for k, inner in d.items():
outer[k].update(inner)
return outer

View File

@ -1,61 +0,0 @@
import math
from typing import Dict, List
from collections import defaultdict, Counter
class OnlineNaiveBayes:
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.V = vocab_size
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def _invalidate(self, cid: str):
if cid in self._logZ:
del self._logZ[cid]
def _logZ_c(self, cid: str) -> float:
if cid not in self._logZ:
Z = self.cls_counts[cid] + self.V * self.alpha
self._logZ[cid] = math.log(max(Z, 1e-12))
return self._logZ[cid]
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
total_cls = sum(self.cls_counts.values())
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
for cid in cids:
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
s = prior
logZ = self._logZ_c(cid)
tc = self.token_counts[cid]
for term, qtf in tf.items():
num = tc.get(term, 0.0) + self.alpha
s += qtf * (math.log(num) - logZ)
out[cid] = s
return out
def update_positive(self, tf: Counter, cid: str):
inc = 0.0
tc = self.token_counts[cid]
for term, c in tf.items():
tc[term] += float(c)
inc += float(c)
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: float = None):
g = self.gamma if factor is None else factor
if g >= 1.0:
return
for cid in list(self.cls_counts.keys()):
self.cls_counts[cid] *= g
for term in list(self.token_counts[cid].keys()):
self.token_counts[cid][term] *= g
self._invalidate(cid)

View File

@ -1,34 +0,0 @@
import re
from typing import List, Optional, Set
try:
import jieba
_HAS_JIEBA = True
except Exception:
_HAS_JIEBA = False
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
# 匹配纯符号的正则表达式
_SYMBOL_RE = re.compile(r"^[^\w\u4e00-\u9fff]+$")
def simple_en_tokenize(text: str) -> List[str]:
return _WORD_RE.findall(text.lower())
class Tokenizer:
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
self.stopwords = stopwords or set()
self.use_jieba = use_jieba and _HAS_JIEBA
def tokenize(self, text: str) -> List[str]:
text = (text or "").strip()
if not text:
return []
if self.use_jieba:
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
else:
toks = simple_en_tokenize(text)
# 过滤掉纯符号和停用词
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]

View File

@ -1,621 +0,0 @@
"""
多聊天室表达风格学习系统
支持为每个chat_id维护独立的表达模型学习从up_content到style的映射
"""
import os
import pickle
import traceback
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
import asyncio
from src.common.logger import get_logger
from .expressor_model.model import ExpressorModel
logger = get_logger("style_learner")
class StyleLearner:
"""
单个聊天室的表达风格学习器
学习从up_content到style的映射关系
支持动态管理风格集合无数量上限
"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
self.chat_id = chat_id
self.model_config = model_config or {
"alpha": 0.5,
"beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000,
"use_jieba": True,
}
# 初始化表达模型
self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0 # 下一个可用的style_id
# 学习统计
self.learning_stats = {
"total_samples": 0,
"style_counts": defaultdict(int),
"last_update": None,
"style_usage_frequency": defaultdict(int), # 风格使用频率
}
def add_style(self, style: str, situation: str = None) -> bool:
"""
动态添加一个新的风格
Args:
style: 风格文本
situation: 对应的situation文本可选
Returns:
bool: 添加是否成功
"""
try:
# 检查是否已存在
if style in self.style_to_id:
logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
return True
# 生成新的style_id
style_id = f"style_{self.next_style_id}"
self.next_style_id += 1
# 添加到映射
self.style_to_id[style] = style_id
self.id_to_style[style_id] = style
if situation:
self.id_to_situation[style_id] = situation
# 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation)
logger.info(
f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})"
+ (f", situation: '{situation}'" if situation else "")
)
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
return False
def remove_style(self, style: str) -> bool:
"""
删除一个风格
Args:
style: 要删除的风格文本
Returns:
bool: 删除是否成功
"""
try:
if style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
return False
style_id = self.style_to_id[style]
# 从映射中删除
del self.style_to_id[style]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# 从expressor模型中删除通过重新构建
self._rebuild_expressor()
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
return False
def update_style(self, old_style: str, new_style: str) -> bool:
"""
更新一个风格
Args:
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
try:
if old_style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
return False
if new_style in self.style_to_id and new_style != old_style:
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
return False
style_id = self.style_to_id[old_style]
# 更新映射
del self.style_to_id[old_style]
self.style_to_id[new_style] = style_id
self.id_to_style[style_id] = new_style
# 更新expressor模型保留原有的situation
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, new_style, situation)
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
return False
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
"""
批量添加风格
Args:
styles: 风格文本列表
situations: 对应的situation文本列表可选
Returns:
int: 成功添加的数量
"""
success_count = 0
for i, style in enumerate(styles):
situation = situations[i] if situations and i < len(situations) else None
if self.add_style(style, situation):
success_count += 1
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
return success_count
def get_all_styles(self) -> List[str]:
"""获取所有已注册的风格"""
return list(self.style_to_id.keys())
def get_style_count(self) -> int:
"""获取当前风格数量"""
return len(self.style_to_id)
def get_situation(self, style: str) -> Optional[str]:
"""
获取风格对应的situation
Args:
style: 风格文本
Returns:
Optional[str]: 对应的situation如果不存在则返回None
"""
if style not in self.style_to_id:
return None
style_id = self.style_to_id[style]
return self.id_to_situation.get(style_id)
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
"""
获取风格的完整信息
Args:
style: 风格文本
Returns:
Tuple[Optional[str], Optional[str]]: (style_id, situation)
"""
if style not in self.style_to_id:
return None, None
style_id = self.style_to_id[style]
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""
获取所有风格的完整信息
Returns:
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
"""
result = {}
for style, style_id in self.style_to_id.items():
situation = self.id_to_situation.get(style_id)
result[style] = (style_id, situation)
return result
def _rebuild_expressor(self):
"""重新构建expressor模型删除风格后使用"""
try:
# 重新创建expressor
self.expressor = ExpressorModel(**self.model_config)
# 重新添加所有风格和situation
for style_id, style_text in self.id_to_style.items():
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, style_text, situation)
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
except Exception as e:
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
学习一个up_content到style的映射
如果style不存在会自动添加
Args:
up_content: 输入内容
style: 对应的style文本
Returns:
bool: 学习是否成功
"""
try:
# 如果style不存在先添加它
if style not in self.style_to_id:
if not self.add_style(style):
logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
return False
# 获取style_id
style_id = self.style_to_id[style]
# 使用正反馈学习
self.expressor.update_positive(up_content, style_id)
# 更新统计
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["style_usage_frequency"][style] += 1
self.learning_stats["last_update"] = asyncio.get_event_loop().time()
logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
traceback.print_exc()
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
根据up_content预测最合适的style
Args:
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style文本, 所有候选的分数]
"""
try:
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
return None, {}
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
# 转换所有分数
style_scores = {}
for sid, score in scores.items():
style_text = self.id_to_style.get(sid)
if style_text:
style_scores[style_text] = score
return best_style, style_scores
except Exception as e:
logger.error(f"[{self.chat_id}] 预测style失败: {e}")
traceback.print_exc()
return None, {}
def decay_learning(self, factor: Optional[float] = None) -> None:
"""
对学习到的知识进行衰减遗忘
Args:
factor: 衰减因子None则使用配置中的gamma
"""
self.expressor.decay(factor)
logger.debug(f"[{self.chat_id}] 执行知识衰减")
def get_stats(self) -> Dict:
"""获取学习统计信息"""
return {
"chat_id": self.chat_id,
"total_samples": self.learning_stats["total_samples"],
"style_count": len(self.style_to_id),
"style_counts": dict(self.learning_stats["style_counts"]),
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
"last_update": self.learning_stats["last_update"],
"all_styles": list(self.style_to_id.keys()),
}
def save(self, base_path: str) -> bool:
"""
保存模型到文件
Args:
base_path: 基础路径实际文件为 {base_path}/{chat_id}_style_model.pkl
"""
try:
os.makedirs(base_path, exist_ok=True)
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
# 保存模型和统计信息
save_data = {
"model_config": self.model_config,
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id,
"learning_stats": self.learning_stats,
}
# 先保存expressor模型
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
self.expressor.save(expressor_path)
# 保存其他数据
with open(file_path, "wb") as f:
pickle.dump(save_data, f)
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
return False
def load(self, base_path: str) -> bool:
"""
从文件加载模型
Args:
base_path: 基础路径
"""
try:
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
if not os.path.exists(file_path) or not os.path.exists(expressor_path):
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
return False
# 加载其他数据
with open(file_path, "rb") as f:
save_data = pickle.load(f)
# 恢复配置和状态
self.model_config = save_data["model_config"]
self.style_to_id = save_data["style_to_id"]
self.id_to_style = save_data["id_to_style"]
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
self.next_style_id = save_data["next_style_id"]
self.learning_stats = save_data["learning_stats"]
# 重新创建expressor并加载
self.expressor = ExpressorModel(**self.model_config)
self.expressor.load(expressor_path)
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
return False
class StyleLearnerManager:
"""
多聊天室表达风格学习管理器
为每个chat_id维护独立的StyleLearner实例
每个chat_id可以动态管理自己的风格集合无数量上限
"""
def __init__(self, model_save_path: str = "data/style_models"):
self.model_save_path = model_save_path
self.learners: Dict[str, StyleLearner] = {}
# 自动保存配置
self.auto_save_interval = 300 # 5分钟
self._auto_save_task: Optional[asyncio.Task] = None
logger.info("StyleLearnerManager 已初始化")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
Args:
chat_id: 聊天室ID
model_config: 模型配置None则使用默认配置
Returns:
StyleLearner实例
"""
if chat_id not in self.learners:
# 创建新的学习器
learner = StyleLearner(chat_id, model_config)
# 尝试加载已保存的模型
learner.load(self.model_save_path)
self.learners[chat_id] = learner
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
return self.learners[chat_id]
def add_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id添加风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 添加是否成功
"""
learner = self.get_learner(chat_id)
return learner.add_style(style)
def remove_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id删除风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 删除是否成功
"""
learner = self.get_learner(chat_id)
return learner.remove_style(style)
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
"""
为指定chat_id更新风格
Args:
chat_id: 聊天室ID
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
learner = self.get_learner(chat_id)
return learner.update_style(old_style, new_style)
def get_chat_styles(self, chat_id: str) -> List[str]:
"""
获取指定chat_id的所有风格
Args:
chat_id: 聊天室ID
Returns:
List[str]: 风格列表
"""
learner = self.get_learner(chat_id)
return learner.get_all_styles()
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
"""
学习一个映射关系
Args:
chat_id: 聊天室ID
up_content: 输入内容
style: 对应的style
Returns:
bool: 学习是否成功
"""
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
预测最合适的style
Args:
chat_id: 聊天室ID
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style, 所有候选分数]
"""
learner = self.get_learner(chat_id)
return learner.predict_style(up_content, top_k)
def decay_all_learners(self, factor: Optional[float] = None) -> None:
"""
对所有学习器执行衰减
Args:
factor: 衰减因子
"""
for learner in self.learners.values():
learner.decay_learning(factor)
logger.info("已对所有学习器执行衰减")
def get_all_stats(self) -> Dict[str, Dict]:
"""获取所有学习器的统计信息"""
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
def save_all_models(self) -> bool:
"""保存所有模型"""
success_count = 0
for learner in self.learners.values():
if learner.save(self.model_save_path):
success_count += 1
logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
return success_count == len(self.learners)
def load_all_models(self) -> int:
"""加载所有已保存的模型"""
if not os.path.exists(self.model_save_path):
return 0
loaded_count = 0
for filename in os.listdir(self.model_save_path):
if filename.endswith("_style_model.pkl"):
chat_id = filename.replace("_style_model.pkl", "")
learner = StyleLearner(chat_id)
if learner.load(self.model_save_path):
self.learners[chat_id] = learner
loaded_count += 1
logger.info(f"已加载 {loaded_count} 个模型")
return loaded_count
async def start_auto_save(self) -> None:
"""启动自动保存任务"""
if self._auto_save_task is None or self._auto_save_task.done():
self._auto_save_task = asyncio.create_task(self._auto_save_loop())
logger.info("已启动自动保存任务")
async def stop_auto_save(self) -> None:
"""停止自动保存任务"""
if self._auto_save_task and not self._auto_save_task.done():
self._auto_save_task.cancel()
try:
await self._auto_save_task
except asyncio.CancelledError:
pass
logger.info("已停止自动保存任务")
async def _auto_save_loop(self) -> None:
"""自动保存循环"""
while True:
try:
await asyncio.sleep(self.auto_save_interval)
self.save_all_models()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"自动保存失败: {e}")
# 全局管理器实例
style_learner_manager = StyleLearnerManager()

View File

@ -891,11 +891,11 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
# 处理人物信息
person_name = analysis_result.get("person_name", "").strip()
memory_content = analysis_result.get("memory_content", "").strip()
if person_name and memory_content:
from src.person_info.person_info import store_person_memory_from_answer
await store_person_memory_from_answer(person_name, memory_content, chat_id)
else:
logger.warning(f"分析为人物信息但未提取到人物名称或记忆内容,问题: {question[:50]}...")
# if person_name and memory_content:
# from src.person_info.person_info import store_person_memory_from_answer
# await store_person_memory_from_answer(person_name, memory_content, chat_id)
# else:
# logger.warning(f"分析为人物信息但未提取到人物名称或记忆内容,问题: {question[:50]}...")
else:
logger.info(f"问题和答案类别为'其他',不进行存储,问题: {question[:50]}...")

View File

@ -1,5 +1,5 @@
[inner]
version = "6.21.1"
version = "6.21.3"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@ -58,10 +58,6 @@ states = [
state_probability = 0.3
[expression]
# 表达方式模式
mode = "classic"
# 可选classic经典模式exp_model 表达模型模式,这个模式需要一定时间学习才会有比较好的效果
# 表达学习配置
learning_list = [ # 表达学习配置列表,支持按聊天流配置
["", "enable", "enable", "1.0"], # 全局配置使用表达启用学习学习强度1.0
@ -89,11 +85,9 @@ expression_groups = [
talk_value = 1 #聊天频率越小越沉默范围0-1
mentioned_bot_reply = true # 是否启用提及必回复
max_context_size = 30 # 上下文长度
auto_chat_value = 1 # 自动聊天,越小,麦麦主动聊天的概率越低
planner_smooth = 5 #规划器平滑增大数值会减小planner负荷略微降低反应速度推荐2-80为关闭必须大于等于0
planner_smooth = 2 #规划器平滑增大数值会减小planner负荷略微降低反应速度推荐1-50为关闭必须大于等于0
enable_talk_value_rules = true # 是否启用动态发言频率规则
enable_auto_chat_value_rules = false # 是否启用动态自动聊天频率规则
# 动态发言频率规则:按时段/按chat_id调整 talk_value优先匹配具体chat再匹配全局
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
@ -107,22 +101,8 @@ talk_value_rules = [
{ target = "qq:114514:private", time = "00:00-23:59", value = 0.3 },
]
# 动态自动聊天频率规则:按时段/按chat_id调整 auto_chat_value优先匹配具体chat再匹配全局
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
# 说明:
# - target 为空字符串表示全局type 为 group/private例如"qq:1919810:group" 或 "qq:114514:private"
# - 支持跨夜区间,例如 "23:00-02:00";数值范围建议 0-1。
auto_chat_value_rules = [
{ target = "", time = "00:00-08:59", value = 0.3 },
{ target = "", time = "09:00-22:59", value = 1.0 },
{ target = "qq:1919810:group", time = "20:00-23:59", value = 0.8 },
{ target = "qq:114514:private", time = "00:00-23:59", value = 0.5 },
]
[memory]
max_memory_number = 100 # 记忆最大数量
max_memory_size = 2048 # 记忆最大大小
memory_build_frequency = 1 # 记忆构建频率
max_agent_iterations = 5 # 记忆思考深度最低为1不深入思考
[jargon]
all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除

View File

@ -1,79 +0,0 @@
#!/usr/bin/env python3
"""
查看 .pkl 文件内容的工具脚本
"""
import pickle
import sys
import os
from pprint import pprint
def view_pkl_file(file_path):
"""查看 pkl 文件内容"""
if not os.path.exists(file_path):
print(f"❌ 文件不存在: {file_path}")
return
try:
with open(file_path, "rb") as f:
data = pickle.load(f)
print(f"📁 文件: {file_path}")
print(f"📊 数据类型: {type(data)}")
print("=" * 50)
if isinstance(data, dict):
print("🔑 字典键:")
for key in data.keys():
print(f" - {key}: {type(data[key])}")
print()
print("📋 详细内容:")
pprint(data, width=120, depth=10)
elif isinstance(data, list):
print(f"📝 列表长度: {len(data)}")
if data:
print(f"📊 第一个元素类型: {type(data[0])}")
print("📋 前几个元素:")
for i, item in enumerate(data[:3]):
print(f" [{i}]: {item}")
else:
print("📋 内容:")
pprint(data, width=120, depth=10)
# 如果是 expressor 模型,特别显示 token_counts 的详细信息
if isinstance(data, dict) and "nb" in data and "token_counts" in data["nb"]:
print("\n" + "=" * 50)
print("🔍 详细词汇统计 (token_counts):")
token_counts = data["nb"]["token_counts"]
for style_id, tokens in token_counts.items():
print(f"\n📝 {style_id}:")
if tokens:
# 按词频排序显示前10个词
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
for word, count in sorted_tokens[:10]:
print(f" '{word}': {count}")
if len(sorted_tokens) > 10:
print(f" ... 还有 {len(sorted_tokens) - 10} 个词")
else:
print(" (无词汇数据)")
except Exception as e:
print(f"❌ 读取文件失败: {e}")
def main():
if len(sys.argv) != 2:
print("用法: python view_pkl.py <pkl文件路径>")
print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
return
file_path = sys.argv[1]
view_pkl_file(file_path)
if __name__ == "__main__":
main()

View File

@ -1,66 +0,0 @@
#!/usr/bin/env python3
"""
专门查看 expressor.pkl 文件中 token_counts 的脚本
"""
import pickle
import sys
import os
def view_token_counts(file_path):
"""查看 expressor.pkl 文件中的词汇统计"""
if not os.path.exists(file_path):
print(f"❌ 文件不存在: {file_path}")
return
try:
with open(file_path, "rb") as f:
data = pickle.load(f)
print(f"📁 文件: {file_path}")
print("=" * 60)
if "nb" not in data or "token_counts" not in data["nb"]:
print("❌ 这不是一个 expressor 模型文件")
return
token_counts = data["nb"]["token_counts"]
candidates = data.get("candidates", {})
print(f"🎯 找到 {len(token_counts)} 个风格")
print("=" * 60)
for style_id, tokens in token_counts.items():
style_text = candidates.get(style_id, "未知风格")
print(f"\n📝 {style_id}: {style_text}")
print(f"📊 词汇数量: {len(tokens)}")
if tokens:
# 按词频排序
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
print("🔤 词汇统计 (按频率排序):")
for i, (word, count) in enumerate(sorted_tokens):
print(f" {i + 1:2d}. '{word}': {count}")
else:
print(" (无词汇数据)")
print("-" * 40)
except Exception as e:
print(f"❌ 读取文件失败: {e}")
def main():
if len(sys.argv) != 2:
print("用法: python view_tokens.py <expressor.pkl文件路径>")
print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
return
file_path = sys.argv[1]
view_token_counts(file_path)
if __name__ == "__main__":
main()