From a2495c7834600a0bd06781d0514792853516c4fd Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 15 Nov 2025 19:23:23 +0800 Subject: [PATCH] =?UTF-8?q?remove=EF=BC=9A=E7=A7=BB=E9=99=A4Exp+model?= =?UTF-8?q?=E8=A1=A8=E8=BE=BE=E6=96=B9=E5=BC=8F=EF=BC=8C=E4=B8=80=E5=A4=84?= =?UTF-8?q?=E6=97=A0=E7=94=A8=E4=BB=A3=E7=A0=81=EF=BC=8C=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/group_generator.py | 2 +- src/chat/replyer/private_generator.py | 2 +- src/config/official_configs.py | 93 +--- src/express/expression_learner.py | 38 +- src/express/expression_selector.py | 140 +---- src/express/expressor_model/model.py | 148 ------ src/express/expressor_model/online_nb.py | 61 --- src/express/expressor_model/tokenizer.py | 34 -- src/express/style_learner.py | 621 ----------------------- src/memory_system/memory_retrieval.py | 10 +- template/bot_config_template.toml | 26 +- view_pkl.py | 79 --- view_tokens.py | 66 --- 13 files changed, 25 insertions(+), 1295 deletions(-) delete mode 100644 src/express/expressor_model/model.py delete mode 100644 src/express/expressor_model/online_nb.py delete mode 100644 src/express/expressor_model/tokenizer.py delete mode 100644 src/express/style_learner.py delete mode 100644 view_pkl.py delete mode 100644 view_tokens.py diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 38cbdac6..b26ba548 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -244,7 +244,7 @@ class DefaultReplyer: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 - # 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择 + # 使用模型预测选择表达方式 selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( self.chat_stream.stream_id, chat_history, max_num=8, target_message=target ) diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 15887a09..37019724 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -258,7 +258,7 @@ class PrivateReplyer: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 - # 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择 + # 使用模型预测选择表达方式 selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( self.chat_stream.stream_id, chat_history, max_num=8, target_message=target ) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 7997dc32..abbd73d2 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -88,12 +88,6 @@ class ChatConfig(ConfigBase): mentioned_bot_reply: bool = True """是否启用提及必回复""" - auto_chat_value: float = 1 - """自动聊天,越小,麦麦主动聊天的概率越低""" - - enable_auto_chat_value_rules: bool = True - """是否启用动态自动聊天频率规则""" - at_bot_inevitable_reply: float = 1 """@bot 必然回复,1为100%回复,0为不额外增幅""" @@ -119,24 +113,7 @@ class ChatConfig(ConfigBase): ["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静 ] - 匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\"). - 时间区间支持跨夜,例如 "23:00-02:00"。 - """ - - auto_chat_value_rules: list[dict] = field(default_factory=lambda: []) - """ - 自动聊天频率规则列表,支持按聊天流/按日内时段配置。 - 规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 } - - 示例: - [ - ["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静 - ["", "09:00-22:59", 1.0], # 全局规则:白天正常 - ["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言 - ["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静 - ] - - 匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\"). + 匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\"). 时间区间支持跨夜,例如 "23:00-02:00"。 """ @@ -245,61 +222,6 @@ class ChatConfig(ConfigBase): # 3) 未命中规则返回基础值 return self.talk_value - def get_auto_chat_value(self, chat_id: Optional[str]) -> float: - """根据规则返回当前 chat 的动态 auto_chat_value,未匹配则回退到基础值。""" - if not self.enable_auto_chat_value_rules or not self.auto_chat_value_rules: - return self.auto_chat_value - - now_min = self._now_minutes() - - # 1) 先尝试匹配指定 chat 的规则 - if chat_id: - for rule in self.auto_chat_value_rules: - if not isinstance(rule, dict): - continue - target = rule.get("target", "") - time_range = rule.get("time", "") - value = rule.get("value", None) - if not isinstance(time_range, str): - continue - # 跳过全局 - if target == "": - continue - config_chat_id = self._parse_stream_config_to_chat_id(str(target)) - if config_chat_id is None or config_chat_id != chat_id: - continue - parsed = self._parse_range(time_range) - if not parsed: - continue - start_min, end_min = parsed - if self._in_range(now_min, start_min, end_min): - try: - return float(value) - except Exception: - continue - - # 2) 再匹配全局规则("") - for rule in self.auto_chat_value_rules: - if not isinstance(rule, dict): - continue - target = rule.get("target", None) - time_range = rule.get("time", "") - value = rule.get("value", None) - if target != "" or not isinstance(time_range, str): - continue - parsed = self._parse_range(time_range) - if not parsed: - continue - start_min, end_min = parsed - if self._in_range(now_min, start_min, end_min): - try: - return float(value) - except Exception: - continue - - # 3) 未命中规则返回基础值 - return self.auto_chat_value - @dataclass class MessageReceiveConfig(ConfigBase): @@ -316,20 +238,19 @@ class MessageReceiveConfig(ConfigBase): class MemoryConfig(ConfigBase): """记忆配置类""" - max_memory_number: int = 100 - """记忆最大数量""" + max_agent_iterations: int = 5 + """Agent最多迭代轮数(最低为1)""" - memory_build_frequency: int = 1 - """记忆构建频率""" + def __post_init__(self): + """验证配置值""" + if self.max_agent_iterations < 1: + raise ValueError(f"max_agent_iterations 必须至少为1,当前值: {self.max_agent_iterations}") @dataclass class ExpressionConfig(ConfigBase): """表达配置类""" - mode: str = "classic" - """表达方式模式,可选:classic经典模式,exp_model 表达模型模式""" - learning_list: list[list] = field(default_factory=lambda: []) """ 表达学习配置列表,支持按聊天流配置 diff --git a/src/express/expression_learner.py b/src/express/expression_learner.py index b4c357d9..72dd831a 100644 --- a/src/express/expression_learner.py +++ b/src/express/expression_learner.py @@ -14,7 +14,6 @@ from src.chat.utils.chat_message_builder import ( ) from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager -from src.express.style_learner import style_learner_manager from src.express.express_utils import filter_message_content, calculate_similarity from json_repair import repair_json @@ -180,10 +179,7 @@ class ExpressionLearner: current_time = time.time() - # 存储到数据库 Expression 表并训练 style_learner - has_new_expressions = False # 记录是否有新的表达方式 - learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例 - + # 存储到数据库 Expression 表 for ( situation, style, @@ -199,7 +195,6 @@ class ExpressionLearner: expr_obj = query.get() expr_obj.last_active_time = current_time expr_obj.save() - continue else: Expression.create( situation=situation, @@ -210,37 +205,6 @@ class ExpressionLearner: context=context, up_content=up_content, ) - has_new_expressions = True - - # 训练 style_learner(up_content 和 style 必定存在) - try: - learner.add_style(style, situation) - - # 学习映射关系 - success = style_learner_manager.learn_mapping(self.chat_id, up_content, style) - if success: - logger.debug( - f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" - + (f" (situation: {situation})" if situation else "") - ) - else: - logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}") - except Exception as e: - logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}") - - # 保存当前聊天室的 style_learner 模型 - if has_new_expressions: - try: - logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...") - save_success = learner.save(style_learner_manager.model_save_path) - - if save_success: - logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}") - else: - logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}") - - except Exception as e: - logger.error(f"StyleLearner 模型保存异常: {e}") return learnt_expressions diff --git a/src/express/expression_selector.py b/src/express/expression_selector.py index 0650c954..e5daed31 100644 --- a/src/express/expression_selector.py +++ b/src/express/expression_selector.py @@ -10,8 +10,7 @@ from src.config.config import global_config, model_config from src.common.logger import get_logger from src.common.database.database_model import Expression from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.express.style_learner import style_learner_manager -from src.express.express_utils import filter_message_content, weighted_sample +from src.express.express_utils import weighted_sample logger = get_logger("expression_selector") @@ -44,6 +43,8 @@ def init_prompt(): Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") + + class ExpressionSelector: def __init__(self): self.llm_model = LLMRequest( @@ -113,89 +114,6 @@ class ExpressionSelector: return group_chat_ids return [chat_id] - def get_model_predicted_expressions( - self, chat_id: str, target_message: str, total_num: int = 10 - ) -> List[Dict[str, Any]]: - """ - 使用 style_learner 模型预测最合适的表达方式 - - Args: - chat_id: 聊天室ID - target_message: 目标消息内容 - total_num: 需要预测的数量 - - Returns: - List[Dict[str, Any]]: 预测的表达方式列表 - """ - try: - # 过滤目标消息内容,移除回复、表情包等特殊格式 - filtered_target_message = filter_message_content(target_message) - - logger.info(f"为{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}") - - # 支持多chat_id合并预测 - related_chat_ids = self.get_related_chat_ids(chat_id) - - predicted_expressions = [] - - # 为每个相关的chat_id进行预测 - for related_chat_id in related_chat_ids: - try: - # 使用 style_learner 预测最合适的风格 - best_style, scores = style_learner_manager.predict_style( - related_chat_id, filtered_target_message, top_k=total_num - ) - - if best_style and scores: - # 获取预测风格的完整信息 - learner = style_learner_manager.get_learner(related_chat_id) - style_id, situation = learner.get_style_info(best_style) - - if style_id and situation: - # 从数据库查找对应的表达记录 - expr_query = Expression.select().where( - (Expression.chat_id == related_chat_id) - & (Expression.situation == situation) - & (Expression.style == best_style) - ) - - if expr_query.exists(): - expr = expr_query.get() - predicted_expressions.append( - { - "id": expr.id, - "situation": expr.situation, - "style": expr.style, - "last_active_time": expr.last_active_time, - "source_id": expr.chat_id, - "create_date": expr.create_date - if expr.create_date is not None - else expr.last_active_time, - "prediction_score": scores.get(best_style, 0.0), - "prediction_input": filtered_target_message, - } - ) - else: - logger.warning( - f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式" - ) - - except Exception as e: - logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}") - continue - - # 按预测分数排序,取前 total_num 个 - predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True) - selected_expressions = predicted_expressions[:total_num] - - logger.info(f"为{chat_id} 预测到 {len(selected_expressions)} 个表达方式") - return selected_expressions - - except Exception as e: - logger.error(f"模型预测表达方式失败: {e}") - # 如果预测失败,回退到随机选择 - return self._random_expressions(chat_id, total_num) - def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: """ 随机选择表达方式 @@ -247,7 +165,7 @@ class ExpressionSelector: target_message: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], List[int]]: """ - 根据配置模式选择适合的表达方式 + 选择适合的表达方式(使用classic模式:随机选择+LLM选择) Args: chat_id: 聊天流ID @@ -263,53 +181,9 @@ class ExpressionSelector: logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") return [], [] - # 获取配置模式 - expression_mode = global_config.expression.mode - - if expression_mode == "exp_model": - # exp_model模式:直接使用模型预测,不经过LLM - logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式") - return await self._select_expressions_model_only(chat_id, target_message, max_num) - elif expression_mode == "classic": - # classic模式:随机选择+LLM选择 - logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式") - return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message) - else: - logger.warning(f"未知的表达模式: {expression_mode},回退到classic模式") - return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message) - - async def _select_expressions_model_only( - self, - chat_id: str, - target_message: str, - max_num: int = 10, - ) -> Tuple[List[Dict[str, Any]], List[int]]: - """ - exp_model模式:直接使用模型预测,不经过LLM - - Args: - chat_id: 聊天流ID - target_message: 目标消息内容 - max_num: 最大选择数量 - - Returns: - Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 - """ - try: - # 使用模型预测最合适的表达方式 - selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num) - selected_ids = [expr["id"] for expr in selected_expressions] - - # 更新last_active_time - if selected_expressions: - self.update_expressions_last_active_time(selected_expressions) - - logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式") - return selected_expressions, selected_ids - - except Exception as e: - logger.error(f"exp_model模式选择表达方式失败: {e}") - return [], [] + # 使用classic模式(随机选择+LLM选择) + logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式") + return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message) async def _select_expressions_classic( self, diff --git a/src/express/expressor_model/model.py b/src/express/expressor_model/model.py deleted file mode 100644 index 563821e2..00000000 --- a/src/express/expressor_model/model.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Dict, Optional, Tuple, List -from collections import Counter, defaultdict -import pickle -import os - -from .tokenizer import Tokenizer -from .online_nb import OnlineNaiveBayes - - -class ExpressorModel: - """ - 直接使用朴素贝叶斯精排(可在线学习) - 支持存储situation字段,不参与计算,仅与style对应 - """ - - def __init__( - self, - alpha: float = 0.5, - beta: float = 0.5, - gamma: float = 1.0, - vocab_size: int = 200000, - use_jieba: bool = True, - ): - self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba) - self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) - self._candidates: Dict[str, str] = {} # cid -> text (style) - self._situations: Dict[str, str] = {} # cid -> situation (不参与计算) - - def add_candidate(self, cid: str, text: str, situation: str = None): - """添加候选文本和对应的situation""" - self._candidates[cid] = text - if situation is not None: - self._situations[cid] = situation - - # 确保在nb模型中初始化该候选的计数 - if cid not in self.nb.cls_counts: - self.nb.cls_counts[cid] = 0.0 - if cid not in self.nb.token_counts: - self.nb.token_counts[cid] = defaultdict(float) - - def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None): - """批量添加候选文本和对应的situations""" - for i, (cid, text) in enumerate(items): - situation = situations[i] if situations and i < len(situations) else None - self.add_candidate(cid, text, situation) - - def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]: - """直接对所有候选进行朴素贝叶斯评分""" - toks = self.tokenizer.tokenize(text) - if not toks: - return None, {} - - if not self._candidates: - return None, {} - - # 对所有候选进行评分 - tf = Counter(toks) - all_cids = list(self._candidates.keys()) - scores = self.nb.score_batch(tf, all_cids) - - # 取最高分 - if not scores: - return None, {} - - # 根据k参数限制返回的候选数量 - if k is not None and k > 0: - # 按分数降序排序,取前k个 - sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) - limited_scores = dict(sorted_scores[:k]) - best = sorted_scores[0][0] if sorted_scores else None - return best, limited_scores - else: - # 如果没有指定k,返回所有分数 - best = max(scores.items(), key=lambda x: x[1])[0] - return best, scores - - def update_positive(self, text: str, cid: str): - """更新正反馈学习""" - toks = self.tokenizer.tokenize(text) - if not toks: - return - tf = Counter(toks) - self.nb.update_positive(tf, cid) - - def decay(self, factor: float): - self.nb.decay(factor=factor) - - def get_situation(self, cid: str) -> Optional[str]: - """获取候选对应的situation""" - return self._situations.get(cid) - - def get_style(self, cid: str) -> Optional[str]: - """获取候选对应的style""" - return self._candidates.get(cid) - - def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: - """获取候选的style和situation信息""" - return self._candidates.get(cid), self._situations.get(cid) - - def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]: - """获取所有候选的style和situation信息""" - return {cid: (style, self._situations.get(cid)) for cid, style in self._candidates.items()} - - def save(self, path: str): - """保存模型""" - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "wb") as f: - pickle.dump( - { - "candidates": self._candidates, - "situations": self._situations, - "nb": { - "cls_counts": dict(self.nb.cls_counts), - "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()}, - "alpha": self.nb.alpha, - "beta": self.nb.beta, - "gamma": self.nb.gamma, - "V": self.nb.V, - }, - }, - f, - ) - - def load(self, path: str): - """加载模型""" - with open(path, "rb") as f: - obj = pickle.load(f) - # 还原候选文本 - self._candidates = obj["candidates"] - # 还原situations(兼容旧版本) - self._situations = obj.get("situations", {}) - # 还原朴素贝叶斯模型 - self.nb.cls_counts = obj["nb"]["cls_counts"] - self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"]) - self.nb.alpha = obj["nb"]["alpha"] - self.nb.beta = obj["nb"]["beta"] - self.nb.gamma = obj["nb"]["gamma"] - self.nb.V = obj["nb"]["V"] - self.nb._logZ.clear() - - -def defaultdict_dict(d: Dict[str, Dict[str, float]]): - from collections import defaultdict - - outer = defaultdict(lambda: defaultdict(float)) - for k, inner in d.items(): - outer[k].update(inner) - return outer diff --git a/src/express/expressor_model/online_nb.py b/src/express/expressor_model/online_nb.py deleted file mode 100644 index fff25c08..00000000 --- a/src/express/expressor_model/online_nb.py +++ /dev/null @@ -1,61 +0,0 @@ -import math -from typing import Dict, List -from collections import defaultdict, Counter - - -class OnlineNaiveBayes: - def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000): - self.alpha = alpha - self.beta = beta - self.gamma = gamma - self.V = vocab_size - - self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count - self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count - self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) - - def _invalidate(self, cid: str): - if cid in self._logZ: - del self._logZ[cid] - - def _logZ_c(self, cid: str) -> float: - if cid not in self._logZ: - Z = self.cls_counts[cid] + self.V * self.alpha - self._logZ[cid] = math.log(max(Z, 1e-12)) - return self._logZ[cid] - - def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]: - total_cls = sum(self.cls_counts.values()) - n_cls = max(1, len(self.cls_counts)) - denom_prior = math.log(total_cls + self.beta * n_cls) - - out: Dict[str, float] = {} - for cid in cids: - prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior - s = prior - logZ = self._logZ_c(cid) - tc = self.token_counts[cid] - for term, qtf in tf.items(): - num = tc.get(term, 0.0) + self.alpha - s += qtf * (math.log(num) - logZ) - out[cid] = s - return out - - def update_positive(self, tf: Counter, cid: str): - inc = 0.0 - tc = self.token_counts[cid] - for term, c in tf.items(): - tc[term] += float(c) - inc += float(c) - self.cls_counts[cid] += inc - self._invalidate(cid) - - def decay(self, factor: float = None): - g = self.gamma if factor is None else factor - if g >= 1.0: - return - for cid in list(self.cls_counts.keys()): - self.cls_counts[cid] *= g - for term in list(self.token_counts[cid].keys()): - self.token_counts[cid][term] *= g - self._invalidate(cid) diff --git a/src/express/expressor_model/tokenizer.py b/src/express/expressor_model/tokenizer.py deleted file mode 100644 index 61a55950..00000000 --- a/src/express/expressor_model/tokenizer.py +++ /dev/null @@ -1,34 +0,0 @@ -import re -from typing import List, Optional, Set - -try: - import jieba - - _HAS_JIEBA = True -except Exception: - _HAS_JIEBA = False - -_WORD_RE = re.compile(r"[A-Za-z0-9_]+") -# 匹配纯符号的正则表达式 -_SYMBOL_RE = re.compile(r"^[^\w\u4e00-\u9fff]+$") - - -def simple_en_tokenize(text: str) -> List[str]: - return _WORD_RE.findall(text.lower()) - - -class Tokenizer: - def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True): - self.stopwords = stopwords or set() - self.use_jieba = use_jieba and _HAS_JIEBA - - def tokenize(self, text: str) -> List[str]: - text = (text or "").strip() - if not text: - return [] - if self.use_jieba: - toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()] - else: - toks = simple_en_tokenize(text) - # 过滤掉纯符号和停用词 - return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)] diff --git a/src/express/style_learner.py b/src/express/style_learner.py deleted file mode 100644 index 1a40d27b..00000000 --- a/src/express/style_learner.py +++ /dev/null @@ -1,621 +0,0 @@ -""" -多聊天室表达风格学习系统 -支持为每个chat_id维护独立的表达模型,学习从up_content到style的映射 -""" - -import os -import pickle -import traceback -from typing import Dict, List, Optional, Tuple -from collections import defaultdict -import asyncio - -from src.common.logger import get_logger -from .expressor_model.model import ExpressorModel - -logger = get_logger("style_learner") - - -class StyleLearner: - """ - 单个聊天室的表达风格学习器 - 学习从up_content到style的映射关系 - 支持动态管理风格集合(无数量上限) - """ - - def __init__(self, chat_id: str, model_config: Optional[Dict] = None): - self.chat_id = chat_id - self.model_config = model_config or { - "alpha": 0.5, - "beta": 0.5, - "gamma": 0.99, # 衰减因子,支持遗忘 - "vocab_size": 200000, - "use_jieba": True, - } - - # 初始化表达模型 - self.expressor = ExpressorModel(**self.model_config) - - # 动态风格管理 - self.style_to_id: Dict[str, str] = {} # style文本 -> style_id - self.id_to_style: Dict[str, str] = {} # style_id -> style文本 - self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 - self.next_style_id = 0 # 下一个可用的style_id - - # 学习统计 - self.learning_stats = { - "total_samples": 0, - "style_counts": defaultdict(int), - "last_update": None, - "style_usage_frequency": defaultdict(int), # 风格使用频率 - } - - def add_style(self, style: str, situation: str = None) -> bool: - """ - 动态添加一个新的风格 - - Args: - style: 风格文本 - situation: 对应的situation文本(可选) - - Returns: - bool: 添加是否成功 - """ - try: - # 检查是否已存在 - if style in self.style_to_id: - logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在") - return True - - # 生成新的style_id - style_id = f"style_{self.next_style_id}" - self.next_style_id += 1 - - # 添加到映射 - self.style_to_id[style] = style_id - self.id_to_style[style_id] = style - if situation: - self.id_to_situation[style_id] = situation - - # 添加到expressor模型 - self.expressor.add_candidate(style_id, style, situation) - - logger.info( - f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" - + (f", situation: '{situation}'" if situation else "") - ) - return True - - except Exception as e: - logger.error(f"[{self.chat_id}] 添加风格失败: {e}") - return False - - def remove_style(self, style: str) -> bool: - """ - 删除一个风格 - - Args: - style: 要删除的风格文本 - - Returns: - bool: 删除是否成功 - """ - try: - if style not in self.style_to_id: - logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在") - return False - - style_id = self.style_to_id[style] - - # 从映射中删除 - del self.style_to_id[style] - del self.id_to_style[style_id] - if style_id in self.id_to_situation: - del self.id_to_situation[style_id] - - # 从expressor模型中删除(通过重新构建) - self._rebuild_expressor() - - logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})") - return True - - except Exception as e: - logger.error(f"[{self.chat_id}] 删除风格失败: {e}") - return False - - def update_style(self, old_style: str, new_style: str) -> bool: - """ - 更新一个风格 - - Args: - old_style: 原风格文本 - new_style: 新风格文本 - - Returns: - bool: 更新是否成功 - """ - try: - if old_style not in self.style_to_id: - logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在") - return False - - if new_style in self.style_to_id and new_style != old_style: - logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在") - return False - - style_id = self.style_to_id[old_style] - - # 更新映射 - del self.style_to_id[old_style] - self.style_to_id[new_style] = style_id - self.id_to_style[style_id] = new_style - - # 更新expressor模型(保留原有的situation) - situation = self.id_to_situation.get(style_id) - self.expressor.add_candidate(style_id, new_style, situation) - - logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'") - return True - - except Exception as e: - logger.error(f"[{self.chat_id}] 更新风格失败: {e}") - return False - - def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int: - """ - 批量添加风格 - - Args: - styles: 风格文本列表 - situations: 对应的situation文本列表(可选) - - Returns: - int: 成功添加的数量 - """ - success_count = 0 - for i, style in enumerate(styles): - situation = situations[i] if situations and i < len(situations) else None - if self.add_style(style, situation): - success_count += 1 - - logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功") - return success_count - - def get_all_styles(self) -> List[str]: - """获取所有已注册的风格""" - return list(self.style_to_id.keys()) - - def get_style_count(self) -> int: - """获取当前风格数量""" - return len(self.style_to_id) - - def get_situation(self, style: str) -> Optional[str]: - """ - 获取风格对应的situation - - Args: - style: 风格文本 - - Returns: - Optional[str]: 对应的situation,如果不存在则返回None - """ - if style not in self.style_to_id: - return None - - style_id = self.style_to_id[style] - return self.id_to_situation.get(style_id) - - def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: - """ - 获取风格的完整信息 - - Args: - style: 风格文本 - - Returns: - Tuple[Optional[str], Optional[str]]: (style_id, situation) - """ - if style not in self.style_to_id: - return None, None - - style_id = self.style_to_id[style] - situation = self.id_to_situation.get(style_id) - return style_id, situation - - def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]: - """ - 获取所有风格的完整信息 - - Returns: - Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)} - """ - result = {} - for style, style_id in self.style_to_id.items(): - situation = self.id_to_situation.get(style_id) - result[style] = (style_id, situation) - return result - - def _rebuild_expressor(self): - """重新构建expressor模型(删除风格后使用)""" - try: - # 重新创建expressor - self.expressor = ExpressorModel(**self.model_config) - - # 重新添加所有风格和situation - for style_id, style_text in self.id_to_style.items(): - situation = self.id_to_situation.get(style_id) - self.expressor.add_candidate(style_id, style_text, situation) - - logger.debug(f"[{self.chat_id}] 已重新构建expressor模型") - - except Exception as e: - logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}") - - def learn_mapping(self, up_content: str, style: str) -> bool: - """ - 学习一个up_content到style的映射 - 如果style不存在,会自动添加 - - Args: - up_content: 输入内容 - style: 对应的style文本 - - Returns: - bool: 学习是否成功 - """ - try: - # 如果style不存在,先添加它 - if style not in self.style_to_id: - if not self.add_style(style): - logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败") - return False - - # 获取style_id - style_id = self.style_to_id[style] - - # 使用正反馈学习 - self.expressor.update_positive(up_content, style_id) - - # 更新统计 - self.learning_stats["total_samples"] += 1 - self.learning_stats["style_counts"][style_id] += 1 - self.learning_stats["style_usage_frequency"][style] += 1 - self.learning_stats["last_update"] = asyncio.get_event_loop().time() - - logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'") - return True - - except Exception as e: - logger.error(f"[{self.chat_id}] 学习映射失败: {e}") - traceback.print_exc() - return False - - def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: - """ - 根据up_content预测最合适的style - - Args: - up_content: 输入内容 - top_k: 返回前k个候选 - - Returns: - Tuple[最佳style文本, 所有候选的分数] - """ - try: - best_style_id, scores = self.expressor.predict(up_content, k=top_k) - - if best_style_id is None: - return None, {} - - # 将style_id转换为style文本 - best_style = self.id_to_style.get(best_style_id) - - # 转换所有分数 - style_scores = {} - for sid, score in scores.items(): - style_text = self.id_to_style.get(sid) - if style_text: - style_scores[style_text] = score - - return best_style, style_scores - - except Exception as e: - logger.error(f"[{self.chat_id}] 预测style失败: {e}") - traceback.print_exc() - return None, {} - - def decay_learning(self, factor: Optional[float] = None) -> None: - """ - 对学习到的知识进行衰减(遗忘) - - Args: - factor: 衰减因子,None则使用配置中的gamma - """ - self.expressor.decay(factor) - logger.debug(f"[{self.chat_id}] 执行知识衰减") - - def get_stats(self) -> Dict: - """获取学习统计信息""" - return { - "chat_id": self.chat_id, - "total_samples": self.learning_stats["total_samples"], - "style_count": len(self.style_to_id), - "style_counts": dict(self.learning_stats["style_counts"]), - "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]), - "last_update": self.learning_stats["last_update"], - "all_styles": list(self.style_to_id.keys()), - } - - def save(self, base_path: str) -> bool: - """ - 保存模型到文件 - - Args: - base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl - """ - try: - os.makedirs(base_path, exist_ok=True) - file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") - - # 保存模型和统计信息 - save_data = { - "model_config": self.model_config, - "style_to_id": self.style_to_id, - "id_to_style": self.id_to_style, - "id_to_situation": self.id_to_situation, - "next_style_id": self.next_style_id, - "learning_stats": self.learning_stats, - } - - # 先保存expressor模型 - expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") - self.expressor.save(expressor_path) - - # 保存其他数据 - with open(file_path, "wb") as f: - pickle.dump(save_data, f) - - logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}") - return True - - except Exception as e: - logger.error(f"[{self.chat_id}] 保存模型失败: {e}") - return False - - def load(self, base_path: str) -> bool: - """ - 从文件加载模型 - - Args: - base_path: 基础路径 - """ - try: - file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") - expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") - - if not os.path.exists(file_path) or not os.path.exists(expressor_path): - logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置") - return False - - # 加载其他数据 - with open(file_path, "rb") as f: - save_data = pickle.load(f) - - # 恢复配置和状态 - self.model_config = save_data["model_config"] - self.style_to_id = save_data["style_to_id"] - self.id_to_style = save_data["id_to_style"] - self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本 - self.next_style_id = save_data["next_style_id"] - self.learning_stats = save_data["learning_stats"] - - # 重新创建expressor并加载 - self.expressor = ExpressorModel(**self.model_config) - self.expressor.load(expressor_path) - - logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载") - return True - - except Exception as e: - logger.error(f"[{self.chat_id}] 加载模型失败: {e}") - return False - - -class StyleLearnerManager: - """ - 多聊天室表达风格学习管理器 - 为每个chat_id维护独立的StyleLearner实例 - 每个chat_id可以动态管理自己的风格集合(无数量上限) - """ - - def __init__(self, model_save_path: str = "data/style_models"): - self.model_save_path = model_save_path - self.learners: Dict[str, StyleLearner] = {} - - # 自动保存配置 - self.auto_save_interval = 300 # 5分钟 - self._auto_save_task: Optional[asyncio.Task] = None - - logger.info("StyleLearnerManager 已初始化") - - def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: - """ - 获取或创建指定chat_id的学习器 - - Args: - chat_id: 聊天室ID - model_config: 模型配置,None则使用默认配置 - - Returns: - StyleLearner实例 - """ - if chat_id not in self.learners: - # 创建新的学习器 - learner = StyleLearner(chat_id, model_config) - - # 尝试加载已保存的模型 - learner.load(self.model_save_path) - - self.learners[chat_id] = learner - logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner") - - return self.learners[chat_id] - - def add_style(self, chat_id: str, style: str) -> bool: - """ - 为指定chat_id添加风格 - - Args: - chat_id: 聊天室ID - style: 风格文本 - - Returns: - bool: 添加是否成功 - """ - learner = self.get_learner(chat_id) - return learner.add_style(style) - - def remove_style(self, chat_id: str, style: str) -> bool: - """ - 为指定chat_id删除风格 - - Args: - chat_id: 聊天室ID - style: 风格文本 - - Returns: - bool: 删除是否成功 - """ - learner = self.get_learner(chat_id) - return learner.remove_style(style) - - def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool: - """ - 为指定chat_id更新风格 - - Args: - chat_id: 聊天室ID - old_style: 原风格文本 - new_style: 新风格文本 - - Returns: - bool: 更新是否成功 - """ - learner = self.get_learner(chat_id) - return learner.update_style(old_style, new_style) - - def get_chat_styles(self, chat_id: str) -> List[str]: - """ - 获取指定chat_id的所有风格 - - Args: - chat_id: 聊天室ID - - Returns: - List[str]: 风格列表 - """ - learner = self.get_learner(chat_id) - return learner.get_all_styles() - - def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: - """ - 学习一个映射关系 - - Args: - chat_id: 聊天室ID - up_content: 输入内容 - style: 对应的style - - Returns: - bool: 学习是否成功 - """ - learner = self.get_learner(chat_id) - return learner.learn_mapping(up_content, style) - - def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: - """ - 预测最合适的style - - Args: - chat_id: 聊天室ID - up_content: 输入内容 - top_k: 返回前k个候选 - - Returns: - Tuple[最佳style, 所有候选分数] - """ - learner = self.get_learner(chat_id) - return learner.predict_style(up_content, top_k) - - def decay_all_learners(self, factor: Optional[float] = None) -> None: - """ - 对所有学习器执行衰减 - - Args: - factor: 衰减因子 - """ - for learner in self.learners.values(): - learner.decay_learning(factor) - logger.info("已对所有学习器执行衰减") - - def get_all_stats(self) -> Dict[str, Dict]: - """获取所有学习器的统计信息""" - return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} - - def save_all_models(self) -> bool: - """保存所有模型""" - success_count = 0 - for learner in self.learners.values(): - if learner.save(self.model_save_path): - success_count += 1 - - logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型") - return success_count == len(self.learners) - - def load_all_models(self) -> int: - """加载所有已保存的模型""" - if not os.path.exists(self.model_save_path): - return 0 - - loaded_count = 0 - for filename in os.listdir(self.model_save_path): - if filename.endswith("_style_model.pkl"): - chat_id = filename.replace("_style_model.pkl", "") - learner = StyleLearner(chat_id) - if learner.load(self.model_save_path): - self.learners[chat_id] = learner - loaded_count += 1 - - logger.info(f"已加载 {loaded_count} 个模型") - return loaded_count - - async def start_auto_save(self) -> None: - """启动自动保存任务""" - if self._auto_save_task is None or self._auto_save_task.done(): - self._auto_save_task = asyncio.create_task(self._auto_save_loop()) - logger.info("已启动自动保存任务") - - async def stop_auto_save(self) -> None: - """停止自动保存任务""" - if self._auto_save_task and not self._auto_save_task.done(): - self._auto_save_task.cancel() - try: - await self._auto_save_task - except asyncio.CancelledError: - pass - logger.info("已停止自动保存任务") - - async def _auto_save_loop(self) -> None: - """自动保存循环""" - while True: - try: - await asyncio.sleep(self.auto_save_interval) - self.save_all_models() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"自动保存失败: {e}") - - -# 全局管理器实例 -style_learner_manager = StyleLearnerManager() diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 6bc42807..7e4a4a6a 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -891,11 +891,11 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) -> # 处理人物信息 person_name = analysis_result.get("person_name", "").strip() memory_content = analysis_result.get("memory_content", "").strip() - if person_name and memory_content: - from src.person_info.person_info import store_person_memory_from_answer - await store_person_memory_from_answer(person_name, memory_content, chat_id) - else: - logger.warning(f"分析为人物信息但未提取到人物名称或记忆内容,问题: {question[:50]}...") + # if person_name and memory_content: + # from src.person_info.person_info import store_person_memory_from_answer + # await store_person_memory_from_answer(person_name, memory_content, chat_id) + # else: + # logger.warning(f"分析为人物信息但未提取到人物名称或记忆内容,问题: {question[:50]}...") else: logger.info(f"问题和答案类别为'其他',不进行存储,问题: {question[:50]}...") diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 4ab75552..45d97c30 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.21.1" +version = "6.21.3" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -58,10 +58,6 @@ states = [ state_probability = 0.3 [expression] -# 表达方式模式 -mode = "classic" -# 可选:classic经典模式,exp_model 表达模型模式,这个模式需要一定时间学习才会有比较好的效果 - # 表达学习配置 learning_list = [ # 表达学习配置列表,支持按聊天流配置 ["", "enable", "enable", "1.0"], # 全局配置:使用表达,启用学习,学习强度1.0 @@ -89,11 +85,9 @@ expression_groups = [ talk_value = 1 #聊天频率,越小越沉默,范围0-1 mentioned_bot_reply = true # 是否启用提及必回复 max_context_size = 30 # 上下文长度 -auto_chat_value = 1 # 自动聊天,越小,麦麦主动聊天的概率越低 -planner_smooth = 5 #规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐2-8,0为关闭,必须大于等于0 +planner_smooth = 2 #规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0 enable_talk_value_rules = true # 是否启用动态发言频率规则 -enable_auto_chat_value_rules = false # 是否启用动态自动聊天频率规则 # 动态发言频率规则:按时段/按chat_id调整 talk_value(优先匹配具体chat,再匹配全局) # 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 } @@ -107,22 +101,8 @@ talk_value_rules = [ { target = "qq:114514:private", time = "00:00-23:59", value = 0.3 }, ] -# 动态自动聊天频率规则:按时段/按chat_id调整 auto_chat_value(优先匹配具体chat,再匹配全局) -# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 } -# 说明: -# - target 为空字符串表示全局;type 为 group/private,例如:"qq:1919810:group" 或 "qq:114514:private"; -# - 支持跨夜区间,例如 "23:00-02:00";数值范围建议 0-1。 -auto_chat_value_rules = [ - { target = "", time = "00:00-08:59", value = 0.3 }, - { target = "", time = "09:00-22:59", value = 1.0 }, - { target = "qq:1919810:group", time = "20:00-23:59", value = 0.8 }, - { target = "qq:114514:private", time = "00:00-23:59", value = 0.5 }, -] - [memory] -max_memory_number = 100 # 记忆最大数量 -max_memory_size = 2048 # 记忆最大大小 -memory_build_frequency = 1 # 记忆构建频率 +max_agent_iterations = 5 # 记忆思考深度(最低为1(不深入思考)) [jargon] all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除 diff --git a/view_pkl.py b/view_pkl.py deleted file mode 100644 index 2d50681b..00000000 --- a/view_pkl.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3 -""" -查看 .pkl 文件内容的工具脚本 -""" - -import pickle -import sys -import os -from pprint import pprint - - -def view_pkl_file(file_path): - """查看 pkl 文件内容""" - if not os.path.exists(file_path): - print(f"❌ 文件不存在: {file_path}") - return - - try: - with open(file_path, "rb") as f: - data = pickle.load(f) - - print(f"📁 文件: {file_path}") - print(f"📊 数据类型: {type(data)}") - print("=" * 50) - - if isinstance(data, dict): - print("🔑 字典键:") - for key in data.keys(): - print(f" - {key}: {type(data[key])}") - print() - - print("📋 详细内容:") - pprint(data, width=120, depth=10) - - elif isinstance(data, list): - print(f"📝 列表长度: {len(data)}") - if data: - print(f"📊 第一个元素类型: {type(data[0])}") - print("📋 前几个元素:") - for i, item in enumerate(data[:3]): - print(f" [{i}]: {item}") - - else: - print("📋 内容:") - pprint(data, width=120, depth=10) - - # 如果是 expressor 模型,特别显示 token_counts 的详细信息 - if isinstance(data, dict) and "nb" in data and "token_counts" in data["nb"]: - print("\n" + "=" * 50) - print("🔍 详细词汇统计 (token_counts):") - token_counts = data["nb"]["token_counts"] - for style_id, tokens in token_counts.items(): - print(f"\n📝 {style_id}:") - if tokens: - # 按词频排序显示前10个词 - sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True) - for word, count in sorted_tokens[:10]: - print(f" '{word}': {count}") - if len(sorted_tokens) > 10: - print(f" ... 还有 {len(sorted_tokens) - 10} 个词") - else: - print(" (无词汇数据)") - - except Exception as e: - print(f"❌ 读取文件失败: {e}") - - -def main(): - if len(sys.argv) != 2: - print("用法: python view_pkl.py ") - print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl") - return - - file_path = sys.argv[1] - view_pkl_file(file_path) - - -if __name__ == "__main__": - main() diff --git a/view_tokens.py b/view_tokens.py deleted file mode 100644 index 02ca1ea0..00000000 --- a/view_tokens.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python3 -""" -专门查看 expressor.pkl 文件中 token_counts 的脚本 -""" - -import pickle -import sys -import os - - -def view_token_counts(file_path): - """查看 expressor.pkl 文件中的词汇统计""" - if not os.path.exists(file_path): - print(f"❌ 文件不存在: {file_path}") - return - - try: - with open(file_path, "rb") as f: - data = pickle.load(f) - - print(f"📁 文件: {file_path}") - print("=" * 60) - - if "nb" not in data or "token_counts" not in data["nb"]: - print("❌ 这不是一个 expressor 模型文件") - return - - token_counts = data["nb"]["token_counts"] - candidates = data.get("candidates", {}) - - print(f"🎯 找到 {len(token_counts)} 个风格") - print("=" * 60) - - for style_id, tokens in token_counts.items(): - style_text = candidates.get(style_id, "未知风格") - print(f"\n📝 {style_id}: {style_text}") - print(f"📊 词汇数量: {len(tokens)}") - - if tokens: - # 按词频排序 - sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True) - - print("🔤 词汇统计 (按频率排序):") - for i, (word, count) in enumerate(sorted_tokens): - print(f" {i + 1:2d}. '{word}': {count}") - else: - print(" (无词汇数据)") - - print("-" * 40) - - except Exception as e: - print(f"❌ 读取文件失败: {e}") - - -def main(): - if len(sys.argv) != 2: - print("用法: python view_tokens.py ") - print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl") - return - - file_path = sys.argv[1] - view_token_counts(file_path) - - -if __name__ == "__main__": - main()