From 7839acd25d2fadbad360e47e148c659a81a8a145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 13 Nov 2025 13:24:55 +0800 Subject: [PATCH] Ruff fix --- bot.py | 2 +- plugins/ChatFrequency/plugin.py | 52 ++- scripts/build_io_pairs.py | 51 ++- scripts/expression_scatter_analysis.py | 209 ++++++------ src/chat/emoji_system/emoji_manager.py | 4 +- .../frequency_control/frequency_control.py | 23 +- src/chat/heart_flow/heartFC_chat.py | 80 +++-- .../heart_flow/heartflow_message_processor.py | 3 +- src/chat/message_receive/bot.py | 10 +- src/chat/planner_actions/planner.py | 50 +-- src/chat/replyer/group_generator.py | 80 ++--- src/chat/replyer/private_generator.py | 87 +++-- src/chat/replyer/prompt/replyer_prompt.py | 19 +- src/chat/utils/chat_history_summarizer.py | 181 ++++++----- src/chat/utils/chat_message_builder.py | 5 +- src/chat/utils/memory_forget_task.py | 198 ++++++------ src/chat/utils/statistic.py | 137 ++++---- src/chat/utils/utils.py | 26 +- src/chat/utils/utils_image.py | 14 +- src/common/data_models/database_data_model.py | 4 +- src/common/database/database_model.py | 13 +- src/config/official_configs.py | 21 +- src/express/express_utils.py | 29 +- src/express/expression_learner.py | 61 ++-- src/express/expression_selector.py | 108 ++++--- src/express/expressor_model/model.py | 63 ++-- src/express/expressor_model/online_nb.py | 7 +- src/express/expressor_model/tokenizer.py | 7 +- src/express/style_learner.py | 258 +++++++-------- src/jargon/__init__.py | 2 - src/jargon/jargon_miner.py | 207 ++++++------ src/llm_models/model_client/gemini_client.py | 10 +- src/llm_models/model_client/openai_client.py | 20 +- src/llm_models/utils_model.py | 8 +- src/main.py | 5 +- src/memory_system/curious.py | 50 +-- src/memory_system/memory_retrieval.py | 299 ++++++++---------- src/memory_system/memory_utils.py | 32 +- .../retrieval_tools/query_chat_history.py | 87 +++-- .../retrieval_tools/query_jargon.py | 49 +-- .../retrieval_tools/tool_registry.py | 36 +-- .../retrieval_tools/tool_utils.py | 13 +- src/mood/mood_manager.py | 3 - src/plugin_system/apis/frequency_api.py | 4 +- src/plugin_system/apis/message_api.py | 2 +- src/plugin_system/apis/tool_api.py | 4 +- src/plugin_system/base/base_action.py | 17 +- src/plugin_system/base/base_tool.py | 6 +- src/plugin_system/core/events_manager.py | 5 +- src/plugin_system/core/tool_use.py | 5 +- view_pkl.py | 29 +- view_tokens.py | 35 +- 52 files changed, 1322 insertions(+), 1408 deletions(-) diff --git a/bot.py b/bot.py index 3f47e435..cf342507 100644 --- a/bot.py +++ b/bot.py @@ -30,7 +30,7 @@ else: raise # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 -from src.common.logger import initialize_logging, get_logger, shutdown_logging +from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa initialize_logging() diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py index 93cb522b..e2231c86 100644 --- a/plugins/ChatFrequency/plugin.py +++ b/plugins/ChatFrequency/plugin.py @@ -1,15 +1,11 @@ from typing import List, Tuple, Type, Optional -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseCommand, - ComponentInfo, - ConfigField -) +from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField from src.plugin_system.apis import send_api, frequency_api + class SetTalkFrequencyCommand(BaseCommand): """设置当前聊天的talk_frequency值""" + command_name = "set_talk_frequency" command_description = "设置当前聊天的talk_frequency值:/chat 
talk_frequency <数字> 或 /chat t <数字>" command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P[+-]?\d*\.?\d+)$" @@ -19,35 +15,35 @@ class SetTalkFrequencyCommand(BaseCommand): # 获取命令参数 - 使用命名捕获组 if not self.matched_groups or "value" not in self.matched_groups: return False, "命令格式错误", False - + value_str = self.matched_groups["value"] if not value_str: return False, "无法获取数值参数", False - + value = float(value_str) - + # 获取聊天流ID if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"): return False, "无法获取聊天流信息", False - + chat_id = self.message.chat_stream.stream_id - + # 设置talk_frequency frequency_api.set_talk_frequency_adjust(chat_id, value) - + final_value = frequency_api.get_current_talk_value(chat_id) adjust_value = frequency_api.get_talk_frequency_adjust(chat_id) base_value = final_value / adjust_value - + # 发送反馈消息(不保存到数据库) await send_api.text_to_stream( f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}", chat_id, - storage_message=False + storage_message=False, ) - + return True, None, False - + except ValueError: error_msg = "数值格式错误,请输入有效的数字" await self.send_text(error_msg, storage_message=False) @@ -60,6 +56,7 @@ class SetTalkFrequencyCommand(BaseCommand): class ShowFrequencyCommand(BaseCommand): """显示当前聊天的频率控制状态""" + command_name = "show_frequency" command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s" command_pattern = r"^/chat\s+(?:show|s)$" @@ -116,11 +113,7 @@ class BetterFrequencyPlugin(BasePlugin): config_file_name: str = "config.toml" # 配置节描述 - config_section_descriptions = { - "plugin": "插件基本信息", - "frequency": "频率控制配置", - "features": "功能开关配置" - } + config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"} # 配置Schema定义 config_schema: dict = { @@ -138,13 +131,14 @@ class BetterFrequencyPlugin(BasePlugin): def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: components = [] - + # 根据配置决定是否注册命令组件 if self.config.get("features", {}).get("enable_commands", True): - components.extend([ - (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand), - (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand), - ]) - - + components.extend( + [ + (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand), + (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand), + ] + ) + return components diff --git a/scripts/build_io_pairs.py b/scripts/build_io_pairs.py index b298dcd2..f934566a 100644 --- a/scripts/build_io_pairs.py +++ b/scripts/build_io_pairs.py @@ -6,15 +6,16 @@ import sys import os from datetime import datetime from typing import Dict, Iterable, List, Optional, Tuple +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.message_repository import find_messages +from src.chat.utils.chat_message_builder import build_readable_messages # 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级) PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT) -from src.common.data_models.database_data_model import DatabaseMessages -from src.common.message_repository import find_messages -from src.chat.utils.chat_message_builder import build_readable_messages + SECONDS_5_MINUTES = 5 * 60 @@ -28,16 +29,16 @@ def clean_output_text(text: str) -> str: """ if not text: return text - + # 移除表情包内容:[表情包:...] 
- text = re.sub(r'\[表情包:[^\]]*\]', '', text) - + text = re.sub(r"\[表情包:[^\]]*\]", "", text) + # 移除回复内容:[回复...],说:... 的完整模式 - text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text) - + text = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", text) + # 清理多余的空格和换行 - text = re.sub(r'\s+', ' ', text).strip() - + text = re.sub(r"\s+", " ", text).strip() + return text @@ -89,7 +90,7 @@ def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[Databa for msg in messages: groups.setdefault(msg.chat_id, []).append(msg) # 保证每个分组内按时间升序 - for chat_id, msgs in groups.items(): + for _chat_id, msgs in groups.items(): msgs.sort(key=lambda m: m.time or 0) return groups @@ -170,8 +171,8 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM continue last = bucket[-1] - same_user = (msg.user_info.user_id == last.user_info.user_id) - close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES) + same_user = msg.user_info.user_id == last.user_info.user_id + close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES if same_user and close_enough: bucket.append(msg) @@ -199,38 +200,36 @@ def build_pairs_for_chat( pairs: List[Tuple[str, str, str]] = [] n_merged = len(merged_messages) n_original = len(original_messages) - + if n_merged == 0 or n_original == 0: return pairs # 为每个合并后的消息找到对应的原始消息位置 merged_to_original_map = {} original_idx = 0 - + for merged_idx, merged_msg in enumerate(merged_messages): # 找到这个合并消息对应的第一个原始消息 - while (original_idx < n_original and - original_messages[original_idx].time < merged_msg.time): + while original_idx < n_original and original_messages[original_idx].time < merged_msg.time: original_idx += 1 - + # 如果找到了时间匹配的原始消息,建立映射 - if (original_idx < n_original and - original_messages[original_idx].time == merged_msg.time): + if original_idx < n_original and original_messages[original_idx].time == merged_msg.time: merged_to_original_map[merged_idx] = original_idx for merged_idx in range(n_merged): merged_msg = merged_messages[merged_idx] - + # 如果指定了 target_user_id,只处理该用户的消息作为 output if target_user_id and merged_msg.user_info.user_id != target_user_id: continue - + # 找到对应的原始消息位置 if merged_idx not in merged_to_original_map: continue - + original_idx = merged_to_original_map[merged_idx] - + # 选择上下文窗口大小 window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx start = max(0, original_idx - window) @@ -266,7 +265,7 @@ def build_pairs( groups = group_by_chat(messages) all_pairs: List[Tuple[str, str, str]] = [] - for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用 + for _chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用 # 对消息进行合并,用于output merged = merge_adjacent_same_user(msgs) # 传递原始消息和合并后消息,input使用原始消息,output使用合并后消息 @@ -385,5 +384,3 @@ def run_interactive() -> int: if __name__ == "__main__": sys.exit(main()) - - diff --git a/scripts/expression_scatter_analysis.py b/scripts/expression_scatter_analysis.py index f6243ada..b022c94e 100644 --- a/scripts/expression_scatter_analysis.py +++ b/scripts/expression_scatter_analysis.py @@ -1,4 +1,3 @@ -import time import sys import os import matplotlib.pyplot as plt @@ -6,16 +5,17 @@ import matplotlib.dates as mdates from datetime import datetime from typing import List, Tuple import numpy as np +from src.common.database.database_model import Expression, ChatStreams # Add project root to Python path project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) -from src.common.database.database_model import 
Expression, ChatStreams + # 设置中文字体 -plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'] -plt.rcParams['axes.unicode_minus'] = False +plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"] +plt.rcParams["axes.unicode_minus"] = False def get_chat_name(chat_id: str) -> str: @@ -39,19 +39,14 @@ def get_expression_data() -> List[Tuple[float, float, str, str]]: """获取Expression表中的数据,返回(create_date, count, chat_id, expression_type)的列表""" expressions = Expression.select() data = [] - + for expr in expressions: # 如果create_date为空,跳过该记录 if expr.create_date is None: continue - - data.append(( - expr.create_date, - expr.count, - expr.chat_id, - expr.type - )) - + + data.append((expr.create_date, expr.count, expr.chat_id, expr.type)) + return data @@ -60,71 +55,71 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st if not data: print("没有找到有效的表达式数据") return - + # 分离数据 create_dates = [item[0] for item in data] counts = [item[1] for item in data] - chat_ids = [item[2] for item in data] - expression_types = [item[3] for item in data] - + _chat_ids = [item[2] for item in data] + _expression_types = [item[3] for item in data] + # 转换时间戳为datetime对象 dates = [datetime.fromtimestamp(ts) for ts in create_dates] - + # 计算时间跨度,自动调整显示格式 time_span = max(dates) - min(dates) if time_span.days > 30: # 超过30天,按月显示 - date_format = '%Y-%m-%d' + date_format = "%Y-%m-%d" major_locator = mdates.MonthLocator() minor_locator = mdates.DayLocator(interval=7) elif time_span.days > 7: # 超过7天,按天显示 - date_format = '%Y-%m-%d' + date_format = "%Y-%m-%d" major_locator = mdates.DayLocator(interval=1) minor_locator = mdates.HourLocator(interval=12) else: # 7天内,按小时显示 - date_format = '%Y-%m-%d %H:%M' + date_format = "%Y-%m-%d %H:%M" major_locator = mdates.HourLocator(interval=6) minor_locator = mdates.HourLocator(interval=1) - + # 创建图形 fig, ax = plt.subplots(figsize=(12, 8)) - + # 创建散点图 - scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis') - + scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis") + # 设置标签和标题 - ax.set_xlabel('创建日期 (Create Date)', fontsize=12) - ax.set_ylabel('使用次数 (Count)', fontsize=12) - ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold') - + ax.set_xlabel("创建日期 (Create Date)", fontsize=12) + ax.set_ylabel("使用次数 (Count)", fontsize=12) + ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold") + # 设置x轴日期格式 - 根据时间跨度自动调整 ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_locator(major_locator) ax.xaxis.set_minor_locator(minor_locator) plt.xticks(rotation=45) - + # 添加网格 ax.grid(True, alpha=0.3) - + # 添加颜色条 cbar = plt.colorbar(scatter) - cbar.set_label('数据点顺序', fontsize=10) - + cbar.set_label("数据点顺序", fontsize=10) + # 调整布局 plt.tight_layout() - + # 显示统计信息 - print(f"\n=== 数据统计 ===") + print("\n=== 数据统计 ===") print(f"总数据点数量: {len(data)}") print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}") print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}") print(f"平均使用次数: {np.mean(counts):.2f}") print(f"中位数使用次数: {np.median(counts):.2f}") - + # 保存图片 if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.savefig(save_path, dpi=300, bbox_inches="tight") print(f"\n散点图已保存到: {save_path}") - + # 显示图片 plt.show() @@ -134,7 +129,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_ if not data: print("没有找到有效的表达式数据") return - + # 按chat_id分组 chat_groups 
= {} for item in data: @@ -142,75 +137,82 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_ if chat_id not in chat_groups: chat_groups[chat_id] = [] chat_groups[chat_id].append(item) - + # 计算时间跨度,自动调整显示格式 all_dates = [datetime.fromtimestamp(item[0]) for item in data] time_span = max(all_dates) - min(all_dates) if time_span.days > 30: # 超过30天,按月显示 - date_format = '%Y-%m-%d' + date_format = "%Y-%m-%d" major_locator = mdates.MonthLocator() minor_locator = mdates.DayLocator(interval=7) elif time_span.days > 7: # 超过7天,按天显示 - date_format = '%Y-%m-%d' + date_format = "%Y-%m-%d" major_locator = mdates.DayLocator(interval=1) minor_locator = mdates.HourLocator(interval=12) else: # 7天内,按小时显示 - date_format = '%Y-%m-%d %H:%M' + date_format = "%Y-%m-%d %H:%M" major_locator = mdates.HourLocator(interval=6) minor_locator = mdates.HourLocator(interval=1) - + # 创建图形 fig, ax = plt.subplots(figsize=(14, 10)) - + # 为每个聊天分配不同颜色 colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups))) - + for i, (chat_id, chat_data) in enumerate(chat_groups.items()): create_dates = [item[0] for item in chat_data] counts = [item[1] for item in chat_data] dates = [datetime.fromtimestamp(ts) for ts in create_dates] - + chat_name = get_chat_name(chat_id) # 截断过长的聊天名称 display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name - - ax.scatter(dates, counts, alpha=0.7, s=40, - c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)", - edgecolors='black', linewidth=0.5) - + + ax.scatter( + dates, + counts, + alpha=0.7, + s=40, + c=[colors[i]], + label=f"{display_name} ({len(chat_data)}个)", + edgecolors="black", + linewidth=0.5, + ) + # 设置标签和标题 - ax.set_xlabel('创建日期 (Create Date)', fontsize=12) - ax.set_ylabel('使用次数 (Count)', fontsize=12) - ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold') - + ax.set_xlabel("创建日期 (Create Date)", fontsize=12) + ax.set_ylabel("使用次数 (Count)", fontsize=12) + ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold") + # 设置x轴日期格式 - 根据时间跨度自动调整 ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_locator(major_locator) ax.xaxis.set_minor_locator(minor_locator) plt.xticks(rotation=45) - + # 添加图例 - ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8) - + ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8) + # 添加网格 ax.grid(True, alpha=0.3) - + # 调整布局 plt.tight_layout() - + # 显示统计信息 - print(f"\n=== 分组统计 ===") + print("\n=== 分组统计 ===") print(f"总聊天数量: {len(chat_groups)}") for chat_id, chat_data in chat_groups.items(): chat_name = get_chat_name(chat_id) counts = [item[1] for item in chat_data] print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}") - + # 保存图片 if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.savefig(save_path, dpi=300, bbox_inches="tight") print(f"\n分组散点图已保存到: {save_path}") - + # 显示图片 plt.show() @@ -220,7 +222,7 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat if not data: print("没有找到有效的表达式数据") return - + # 按type分组 type_groups = {} for item in data: @@ -228,69 +230,76 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat if expr_type not in type_groups: type_groups[expr_type] = [] type_groups[expr_type].append(item) - + # 计算时间跨度,自动调整显示格式 all_dates = [datetime.fromtimestamp(item[0]) for item in data] time_span = max(all_dates) - min(all_dates) if time_span.days > 30: # 超过30天,按月显示 - date_format = '%Y-%m-%d' + date_format = "%Y-%m-%d" major_locator = 
mdates.MonthLocator() minor_locator = mdates.DayLocator(interval=7) elif time_span.days > 7: # 超过7天,按天显示 - date_format = '%Y-%m-%d' + date_format = "%Y-%m-%d" major_locator = mdates.DayLocator(interval=1) minor_locator = mdates.HourLocator(interval=12) else: # 7天内,按小时显示 - date_format = '%Y-%m-%d %H:%M' + date_format = "%Y-%m-%d %H:%M" major_locator = mdates.HourLocator(interval=6) minor_locator = mdates.HourLocator(interval=1) - + # 创建图形 fig, ax = plt.subplots(figsize=(12, 8)) - + # 为每个类型分配不同颜色 colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups))) - + for i, (expr_type, type_data) in enumerate(type_groups.items()): create_dates = [item[0] for item in type_data] counts = [item[1] for item in type_data] dates = [datetime.fromtimestamp(ts) for ts in create_dates] - - ax.scatter(dates, counts, alpha=0.7, s=40, - c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)", - edgecolors='black', linewidth=0.5) - + + ax.scatter( + dates, + counts, + alpha=0.7, + s=40, + c=[colors[i]], + label=f"{expr_type} ({len(type_data)}个)", + edgecolors="black", + linewidth=0.5, + ) + # 设置标签和标题 - ax.set_xlabel('创建日期 (Create Date)', fontsize=12) - ax.set_ylabel('使用次数 (Count)', fontsize=12) - ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold') - + ax.set_xlabel("创建日期 (Create Date)", fontsize=12) + ax.set_ylabel("使用次数 (Count)", fontsize=12) + ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold") + # 设置x轴日期格式 - 根据时间跨度自动调整 ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_locator(major_locator) ax.xaxis.set_minor_locator(minor_locator) plt.xticks(rotation=45) - + # 添加图例 - ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') - + ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + # 添加网格 ax.grid(True, alpha=0.3) - + # 调整布局 plt.tight_layout() - + # 显示统计信息 - print(f"\n=== 类型统计 ===") + print("\n=== 类型统计 ===") for expr_type, type_data in type_groups.items(): counts = [item[1] for item in type_data] print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}") - + # 保存图片 if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.savefig(save_path, dpi=300, bbox_inches="tight") print(f"\n类型散点图已保存到: {save_path}") - + # 显示图片 plt.show() @@ -298,35 +307,35 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat def main(): """主函数""" print("开始分析表达式数据...") - + # 获取数据 data = get_expression_data() - + if not data: print("没有找到有效的表达式数据(create_date不为空的数据)") return - + print(f"找到 {len(data)} 条有效数据") - + # 创建输出目录 output_dir = os.path.join(project_root, "data", "temp") os.makedirs(output_dir, exist_ok=True) - + # 生成时间戳用于文件名 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + # 1. 创建基础散点图 print("\n1. 创建基础散点图...") create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png")) - + # 2. 创建按聊天分组的散点图 print("\n2. 创建按聊天分组的散点图...") create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png")) - + # 3. 创建按类型分组的散点图 print("\n3. 
创建按类型分组的散点图...") create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png")) - + print("\n分析完成!") diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index b26ab844..1a562fcc 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -945,9 +945,7 @@ class EmojiManager: prompt, image_base64, "jpg", temperature=0.5 ) else: - prompt = ( - "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析,精简回答" - ) + prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析,精简回答" description, _ = await self.vlm.generate_response_for_image( prompt, image_base64, image_format, temperature=0.5 ) diff --git a/src/chat/frequency_control/frequency_control.py b/src/chat/frequency_control/frequency_control.py index 95242972..78041ae7 100644 --- a/src/chat/frequency_control/frequency_control.py +++ b/src/chat/frequency_control/frequency_control.py @@ -12,6 +12,7 @@ from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.plugin_system.apis import frequency_api + def init_prompt(): Prompt( """{name_block} @@ -28,7 +29,7 @@ def init_prompt(): """, "frequency_adjust_prompt", ) - + logger = get_logger("frequency_control") @@ -40,7 +41,7 @@ class FrequencyControl: self.chat_id = chat_id # 发言频率调整值 self.talk_frequency_adjust: float = 1.0 - + self.last_frequency_adjust_time: float = 0.0 self.frequency_model = LLMRequest( model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust" @@ -53,16 +54,14 @@ class FrequencyControl: def set_talk_frequency_adjust(self, value: float) -> None: """设置发言频率调整值""" self.talk_frequency_adjust = max(0.1, min(5.0, value)) - - + async def trigger_frequency_adjust(self) -> None: msg_list = get_raw_msg_by_timestamp_with_chat( chat_id=self.chat_id, timestamp_start=self.last_frequency_adjust_time, timestamp_end=time.time(), ) - - + if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20: return else: @@ -73,7 +72,7 @@ class FrequencyControl: limit=20, limit_mode="latest", ) - + message_str = build_readable_messages( new_msg_list, replace_bot_name=True, @@ -97,15 +96,15 @@ class FrequencyControl: response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async( prompt, ) - + # logger.info(f"频率调整 prompt: {prompt}") # logger.info(f"频率调整 response: {response}") - + if global_config.debug.show_prompt: logger.info(f"频率调整 prompt: {prompt}") logger.info(f"频率调整 response: {response}") logger.info(f"频率调整 reasoning_content: {reasoning_content}") - + final_value_by_api = frequency_api.get_current_talk_value(self.chat_id) # LLM依然输出过多内容时取消本次调整。合法最多4个字,但有的模型可能会输出一些markdown换行符等,需要长度宽限 @@ -118,7 +117,8 @@ class FrequencyControl: self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2)) self.last_frequency_adjust_time = time.time() else: - logger.info(f"频率调整:response不符合要求,取消本次调整") + logger.info("频率调整:response不符合要求,取消本次调整") + class FrequencyControlManager: """频率控制管理器,管理多个聊天流的频率控制实例""" @@ -143,6 +143,7 @@ class FrequencyControlManager: """获取所有有频率控制的聊天ID""" return list(self.frequency_control_dict.keys()) + init_prompt() # 创建全局实例 diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 6bbda587..bd99d2a7 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -1,5 +1,4 @@ import asyncio -from multiprocessing import context import time import traceback import random @@ -19,7 +18,6 @@ 
from src.chat.planner_actions.action_manager import ActionManager from src.chat.heart_flow.hfc_utils import CycleDetail from src.express.expression_learner import expression_learner_manager from src.chat.frequency_control.frequency_control import frequency_control_manager -from src.memory_system.curious import check_and_make_question from src.jargon import extract_and_store_jargon from src.person_info.person_info import Person from src.plugin_system.base.component_types import EventType, ActionInfo @@ -103,14 +101,14 @@ class HeartFChatting: self.is_mute = False - self.last_active_time = time.time() # 记录上一次非noreply时间 + self.last_active_time = time.time() # 记录上一次非noreply时间 self.question_probability_multiplier = 1 self.questioned = False - + # 跟踪连续 no_reply 次数,用于动态调整阈值 self.consecutive_no_reply_count = 0 - + # 聊天内容概括器 self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id) @@ -128,10 +126,10 @@ class HeartFChatting: self._loop_task = asyncio.create_task(self._main_chat_loop()) self._loop_task.add_done_callback(self._handle_loop_completion) - + # 启动聊天内容概括器的后台定期检查循环 await self.chat_history_summarizer.start() - + logger.info(f"{self.log_prefix} HeartFChatting 启动完成") except Exception as e: @@ -181,7 +179,7 @@ class HeartFChatting: + (f"详情: {'; '.join(timer_strings)}" if timer_strings else "") ) - async def _loopbody(self): + async def _loopbody(self): recent_messages_list = message_api.get_messages_by_time_in_chat( chat_id=self.stream_id, start_time=self.last_read_time, @@ -192,9 +190,6 @@ class HeartFChatting: filter_command=True, ) - - - # 根据连续 no_reply 次数动态调整阈值 # 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2) # 5次 no_reply 时,提高到 2(大于等于两条消息的阈值) @@ -205,10 +200,10 @@ class HeartFChatting: threshold = 2 if random.random() < 0.5 else 1 else: threshold = 1 - + if len(recent_messages_list) >= threshold: # for message in recent_messages_list: - # print(message.processed_plain_text) + # print(message.processed_plain_text) # !处理no_reply_until_call逻辑 if self.no_reply_until_call: for message in recent_messages_list: @@ -338,7 +333,7 @@ class HeartFChatting: # 重置连续 no_reply 计数 self.consecutive_no_reply_count = 0 reason = "有人提到了你,进行回复" - + await database_api.store_action_info( chat_stream=self.chat_stream, action_build_into_prompt=False, @@ -396,15 +391,16 @@ class HeartFChatting: ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if if recent_messages_list is None: recent_messages_list = [] - reply_text = "" # 初始化reply_text变量,避免UnboundLocalError + _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError start_time = time.time() - async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) - asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()) - + asyncio.create_task( + frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust() + ) + # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容 # asyncio.create_task(check_and_make_question(self.stream_id)) # 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却) @@ -412,8 +408,7 @@ class HeartFChatting: # 添加聊天内容概括任务 - 累积、打包和压缩聊天记录 # 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理 # asyncio.create_task(self.chat_history_summarizer.process()) - - + cycle_timers, thinking_id = self.start_cycle() logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") @@ -428,7 +423,7 @@ class HeartFChatting: # 如果被提及,让回复生成和planner并行执行 if force_reply_message: 
logger.info(f"{self.log_prefix} 检测到提及,回复生成与planner并行执行") - + # 并行执行planner和回复生成 planner_task = asyncio.create_task( self._run_planner_without_reply( @@ -458,7 +453,12 @@ class HeartFChatting: # 处理回复结果 if isinstance(reply_result, BaseException): logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}") - reply_result = {"action_type": "reply", "success": False, "result": "回复生成异常", "loop_info": None} + reply_result = { + "action_type": "reply", + "success": False, + "result": "回复生成异常", + "loop_info": None, + } else: # 正常流程:只执行planner is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() @@ -517,7 +517,7 @@ class HeartFChatting: # 并行执行所有任务 results = await asyncio.gather(*action_tasks, return_exceptions=True) - + # 如果有独立的回复结果,添加到结果列表中 if reply_result: results = list(results) + [reply_result] @@ -559,7 +559,7 @@ class HeartFChatting: "taken_time": time.time(), } ) - reply_text = reply_text_from_reply + _reply_text = reply_text_from_reply else: # 没有回复信息,构建纯动作的loop_info loop_info = { @@ -572,7 +572,7 @@ class HeartFChatting: "taken_time": time.time(), }, } - reply_text = action_reply_text + _reply_text = action_reply_text self.end_cycle(loop_info, cycle_timers) self.print_cycle_info(cycle_timers) @@ -648,7 +648,6 @@ class HeartFChatting: result = await action_handler.execute() success, action_text = result - return success, action_text except Exception as e: @@ -656,8 +655,6 @@ class HeartFChatting: traceback.print_exc() return False, "" - - async def _send_response( self, reply_set: "ReplySetModel", @@ -733,7 +730,6 @@ class HeartFChatting: action_reasoning=reason, ) - return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""} elif action_planner_info.action_type == "no_reply_until_call": @@ -754,7 +750,12 @@ class HeartFChatting: action_name="no_reply_until_call", action_reasoning=reason, ) - return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""} + return { + "action_type": "no_reply_until_call", + "success": True, + "result": "保持沉默,直到有人直接叫的名字", + "command": "", + } elif action_planner_info.action_type == "reply": # 直接当场执行reply逻辑 @@ -784,19 +785,16 @@ class HeartFChatting: enable_tool=global_config.tool.enable_tool, request_type="replyer", from_plugin=False, - reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()), + reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()), ) if not success or not llm_response or not llm_response.reply_set: if action_planner_info.action_message: - logger.info( - f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败" - ) + logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败") else: logger.info("回复生成失败") return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None} - response_set = llm_response.reply_set selected_expressions = llm_response.selected_expressions loop_info, reply_text, _ = await self._send_and_store_reply( @@ -818,12 +816,12 @@ class HeartFChatting: # 执行普通动作 with Timer("动作执行", cycle_timers): success, result = await self._handle_action( - action = action_planner_info.action_type, - action_reasoning = action_planner_info.action_reasoning or "", - action_data = action_planner_info.action_data or {}, - cycle_timers = cycle_timers, - thinking_id = thinking_id, - action_message= action_planner_info.action_message, + action=action_planner_info.action_type, + action_reasoning=action_planner_info.action_reasoning or "", + 
action_data=action_planner_info.action_data or {}, + cycle_timers=cycle_timers, + thinking_id=thinking_id, + action_message=action_planner_info.action_message, ) self.last_active_time = time.time() diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 032c52cd..90e5e118 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -13,10 +13,11 @@ from src.person_info.person_info import Person from src.common.database.database_model import Images if TYPE_CHECKING: - from src.chat.heart_flow.heartFC_chat import HeartFChatting + pass logger = get_logger("chat") + class HeartFCMessageReceiver: """心流处理器,负责处理接收到的消息并计算兴趣度""" diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 8af62bf2..070f78bd 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -15,7 +15,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.plugin_system.core import component_registry, events_manager, global_announcement_manager from src.plugin_system.base import BaseCommand, EventType -from src.person_info.person_info import Person # 定义日志配置 @@ -171,7 +170,11 @@ class ChatBot: # 撤回事件打印;无法获取被撤回者则省略 if sub_type == "recall": - op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None)) + op_name = ( + getattr(op, "user_cardname", None) + or getattr(op, "user_nickname", None) + or str(getattr(op, "user_id", None)) + ) recalled_name = None try: if isinstance(recalled, dict): @@ -189,7 +192,7 @@ class ChatBot: logger.info(f"{op_name} 撤回了消息") else: logger.debug( - f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) " + f"[notice] sub_type={sub_type} scene={scene} op={getattr(op, 'user_nickname', None)}({getattr(op, 'user_id', None)}) " f"gid={gid} msg_id={msg_id} recalled={recalled_id}" ) except Exception: @@ -234,7 +237,6 @@ class ChatBot: # 确保所有任务已启动 await self._ensure_started() - if message_data["message_info"].get("group_info") is not None: message_data["message_info"]["group_info"]["group_id"] = str( message_data["message_info"]["group_info"]["group_id"] diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 5dfd2578..7af3291a 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -143,7 +143,6 @@ class ActionPlanner: self.last_obs_time_mark = 0.0 - self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = [] def find_message_by_id( @@ -306,7 +305,9 @@ class ActionPlanner: loop_start_time=loop_start_time, ) - logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") + logger.info( + f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}" + ) self.add_plan_log(reasoning, actions) @@ -316,7 +317,7 @@ class ActionPlanner: self.plan_log.append((reasoning, time.time(), actions)) if len(self.plan_log) > 20: self.plan_log.pop(0) - + def add_plan_excute_log(self, result: str): self.plan_log.append(("", time.time(), result)) if len(self.plan_log) > 20: @@ -325,17 +326,17 @@ class ActionPlanner: def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str: """ 获取计划日志字符串 - + 
Args: max_action_records: 显示多少条最新的action记录,默认2 max_execution_records: 显示多少条最新执行结果记录,默认8 - + Returns: 格式化的日志字符串 """ action_records = [] execution_records = [] - + # 从后往前遍历,收集最新的记录 for reasoning, timestamp, content in reversed(self.plan_log): if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content): @@ -346,13 +347,13 @@ class ActionPlanner: # 这是执行结果记录 if len(execution_records) < max_execution_records: execution_records.append((reasoning, timestamp, content, "execution")) - + # 合并所有记录并按时间戳排序 all_records = action_records + execution_records all_records.sort(key=lambda x: x[1]) # 按时间戳排序 - + plan_log_str = "" - + # 按时间顺序添加所有记录 for reasoning, timestamp, content, record_type in all_records: time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S") @@ -361,21 +362,21 @@ class ActionPlanner: plan_log_str += f"{time_str}:{reasoning}\n" else: plan_log_str += f"{time_str}:你执行了action:{content}\n" - + return plan_log_str def _has_consecutive_no_reply(self, min_count: int = 3) -> bool: """ 检查是否有连续min_count次以上的no_reply - + Args: min_count: 需要连续的最少次数,默认3 - + Returns: 如果有连续min_count次以上no_reply返回True,否则返回False """ consecutive_count = 0 - + # 从后往前遍历plan_log,检查最新的连续记录 for _reasoning, _timestamp, content in reversed(self.plan_log): if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content): @@ -387,7 +388,7 @@ class ActionPlanner: else: # 如果遇到非no_reply的action,重置计数 break - + return False async def build_planner_prompt( @@ -402,8 +403,7 @@ class ActionPlanner: ) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]: """构建 Planner LLM 的提示词 (获取模板并填充数据)""" try: - - actions_before_now_block=self.get_plan_log_str() + actions_before_now_block = self.get_plan_log_str() # 构建聊天上下文描述 chat_context_description = "你现在正在一个群聊中" @@ -537,7 +537,7 @@ class ActionPlanner: for require_item in action_info.action_require: require_text += f"- {require_item}\n" require_text = require_text.rstrip("\n") - + if not action_info.parallel_action: parallel_text = "(当选择这个动作时,请不要选择其他动作)" else: @@ -564,7 +564,7 @@ class ActionPlanner: filtered_actions: Dict[str, ActionInfo], available_actions: Dict[str, ActionInfo], loop_start_time: float, - ) -> Tuple[str,List[ActionPlannerInfo]]: + ) -> Tuple[str, List[ActionPlannerInfo]]: """执行主规划器""" llm_content = None actions: List[ActionPlannerInfo] = [] @@ -589,7 +589,7 @@ class ActionPlanner: except Exception as req_e: logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") - return f"LLM 请求失败,模型出现问题: {req_e}",[ + return f"LLM 请求失败,模型出现问题: {req_e}", [ ActionPlannerInfo( action_type="no_reply", reasoning=f"LLM 请求失败,模型出现问题: {req_e}", @@ -608,7 +608,11 @@ class ActionPlanner: logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") filtered_actions_list = list(filtered_actions.items()) for json_obj in json_objects: - actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning)) + actions.extend( + self._parse_single_action( + json_obj, message_id_list, filtered_actions_list, extracted_reasoning + ) + ) else: # 尝试解析为直接的JSON logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}") @@ -631,7 +635,7 @@ class ActionPlanner: logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") - return extracted_reasoning,actions + return extracted_reasoning, actions def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]: """创建no_reply""" @@ -674,7 +678,7 @@ class 
ActionPlanner: json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释 if json_str := json_str.strip(): # 尝试按行分割,每行可能是一个JSON对象 - lines = [line.strip() for line in json_str.split('\n') if line.strip()] + lines = [line.strip() for line in json_str.split("\n") if line.strip()] for line in lines: try: # 尝试解析每一行作为独立的JSON对象 @@ -688,7 +692,7 @@ class ActionPlanner: except json.JSONDecodeError: # 如果单行解析失败,尝试将整个块作为一个JSON对象或数组 pass - + # 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组 if not json_objects: json_obj = json.loads(repair_json(json_str)) diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index bcda39b9..6500026f 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -134,12 +134,12 @@ class DefaultReplyer: try: content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt) # logger.debug(f"replyer生成内容: {content}") - + logger.info(f"replyer生成内容: {content}") if global_config.debug.show_replyer_reasoning: logger.info(f"replyer生成推理:\n{reasoning_content}") logger.info(f"replyer生成模型: {model_name}") - + llm_response.content = content llm_response.reasoning = reasoning_content llm_response.model = model_name @@ -268,14 +268,13 @@ class DefaultReplyer: expression_habits_block += f"{style_habits_str}\n" return f"{expression_habits_title}\n{expression_habits_block}", selected_ids - + async def build_mood_state_prompt(self) -> str: """构建情绪状态提示""" if not global_config.mood.enable_mood: return "" mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood() return f"你现在的心情是:{mood_state}" - async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: """构建工具信息块 @@ -303,7 +302,7 @@ class DefaultReplyer: for tool_result in tool_results: tool_name = tool_result.get("tool_name", "unknown") content = tool_result.get("content", "") - result_type = tool_result.get("type", "tool_result") + _result_type = tool_result.get("type", "tool_result") tool_info_str += f"- 【{tool_name}】: {content}\n" @@ -343,45 +342,45 @@ class DefaultReplyer: def _replace_picids_with_descriptions(self, text: str) -> str: """将文本中的[picid:xxx]替换为具体的图片描述 - + Args: text: 包含picid标记的文本 - + Returns: 替换后的文本 """ # 匹配 [picid:xxxxx] 格式 pic_pattern = r"\[picid:([^\]]+)\]" - + def replace_pic_id(match: re.Match) -> str: pic_id = match.group(1) description = translate_pid_to_description(pic_id) return f"[图片:{description}]" - + return re.sub(pic_pattern, replace_pic_id, text) def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]: """分析target内容类型(基于原始picid格式) - + Args: target: 目标消息内容(包含[picid:xxx]格式) - + Returns: Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分) """ if not target or not target.strip(): return False, False, "", "" - + # 检查是否只包含picid标记 picid_pattern = r"\[picid:[^\]]+\]" picid_matches = re.findall(picid_pattern, target) - + # 移除所有picid标记后检查是否还有文字内容 text_without_picids = re.sub(picid_pattern, "", target).strip() - + has_only_pics = len(picid_matches) > 0 and not text_without_picids has_text = bool(text_without_picids) - + # 提取图片部分(转换为[图片:描述]格式) pic_part = "" if picid_matches: @@ -396,7 +395,7 @@ class DefaultReplyer: else: pic_descriptions.append(f"[图片:{description}]") pic_part = "".join(pic_descriptions) - + return has_only_pics, has_text, pic_part, text_without_picids async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: @@ -481,7 +480,7 @@ class DefaultReplyer: ) return all_dialogue_prompt - + def 
core_background_build_chat_history_prompts( self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str ) -> Tuple[str, str]: @@ -603,25 +602,27 @@ class DefaultReplyer: # 获取基础personality prompt_personality = global_config.personality.personality - + # 检查是否需要随机替换为状态 - if (global_config.personality.states and - global_config.personality.state_probability > 0 and - random.random() < global_config.personality.state_probability): + if ( + global_config.personality.states + and global_config.personality.state_probability > 0 + and random.random() < global_config.personality.state_probability + ): # 随机选择一个状态替换personality selected_state = random.choice(global_config.personality.states) prompt_personality = selected_state - + prompt_personality = f"{prompt_personality};" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]: """ 解析聊天prompt配置字符串并生成对应的 chat_id 和 prompt内容 - + Args: chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串 - + Returns: tuple: (chat_id, prompt_content),如果解析失败则返回 None """ @@ -657,10 +658,10 @@ class DefaultReplyer: def get_chat_prompt_for_chat(self, chat_id: str) -> str: """ 根据聊天流ID获取匹配的额外prompt(仅匹配group类型) - + Args: chat_id: 聊天流ID(哈希值) - + Returns: str: 匹配的额外prompt内容,如果没有匹配则返回空字符串 """ @@ -670,21 +671,21 @@ class DefaultReplyer: for chat_prompt_str in global_config.experimental.chat_prompts: if not isinstance(chat_prompt_str, str): continue - + # 解析配置字符串,检查类型是否为group parts = chat_prompt_str.split(":", 3) if len(parts) != 4: continue - + stream_type = parts[2] # 只匹配group类型 if stream_type != "group": continue - + result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str) if result is None: continue - + config_chat_id, prompt_content = result if config_chat_id == chat_id: logger.debug(f"匹配到群聊prompt配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...") @@ -720,7 +721,7 @@ class DefaultReplyer: available_actions = {} chat_stream = self.chat_stream chat_id = chat_stream.stream_id - is_group_chat = bool(chat_stream.group_info) + _is_group_chat = bool(chat_stream.group_info) platform = chat_stream.platform user_id = "用户ID" @@ -736,10 +737,10 @@ class DefaultReplyer: target = reply_message.processed_plain_text target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) - + # 在picid替换之前分析内容类型(防止prompt注入) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) - + # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) @@ -911,10 +912,10 @@ class DefaultReplyer: sender, target = self._parse_reply_target(reply_to) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) - + # 在picid替换之前分析内容类型(防止prompt注入) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) - + # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) @@ -956,9 +957,7 @@ class DefaultReplyer: ) elif has_text and pic_part: # 既有图片又有文字 - reply_target_block = ( - f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。" - ) + reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。" else: # 只包含文字 reply_target_block = ( @@ -975,7 +974,9 @@ class DefaultReplyer: reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。" elif has_text and pic_part: # 既有图片又有文字 - reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" + reply_target_block = ( 
+ f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" + ) else: # 只包含文字 reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。" @@ -1124,6 +1125,7 @@ class DefaultReplyer: logger.error(f"获取知识库内容时发生异常: {str(e)}") return "" + def weighted_sample_no_replacement(items, weights, k) -> list: """ 加权且不放回地随机抽取k个元素。 diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 58928259..0bbce12a 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -46,6 +46,7 @@ init_memory_retrieval_prompt() logger = get_logger("replyer") + class PrivateReplyer: def __init__( self, @@ -277,9 +278,7 @@ class PrivateReplyer: expression_habits_block = "" expression_habits_title = "" if style_habits_str.strip(): - expression_habits_title = ( - "在回复时,你可以参考以下的语言习惯,不要生硬使用:" - ) + expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:" expression_habits_block += f"{style_habits_str}\n" return f"{expression_habits_title}\n{expression_habits_block}", selected_ids @@ -291,7 +290,6 @@ class PrivateReplyer: mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood() return f"你现在的心情是:{mood_state}" - async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: """构建工具信息块 @@ -358,45 +356,45 @@ class PrivateReplyer: def _replace_picids_with_descriptions(self, text: str) -> str: """将文本中的[picid:xxx]替换为具体的图片描述 - + Args: text: 包含picid标记的文本 - + Returns: 替换后的文本 """ # 匹配 [picid:xxxxx] 格式 pic_pattern = r"\[picid:([^\]]+)\]" - + def replace_pic_id(match: re.Match) -> str: pic_id = match.group(1) description = translate_pid_to_description(pic_id) return f"[图片:{description}]" - + return re.sub(pic_pattern, replace_pic_id, text) def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]: """分析target内容类型(基于原始picid格式) - + Args: target: 目标消息内容(包含[picid:xxx]格式) - + Returns: Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分) """ if not target or not target.strip(): return False, False, "", "" - + # 检查是否只包含picid标记 picid_pattern = r"\[picid:[^\]]+\]" picid_matches = re.findall(picid_pattern, target) - + # 移除所有picid标记后检查是否还有文字内容 text_without_picids = re.sub(picid_pattern, "", target).strip() - + has_only_pics = len(picid_matches) > 0 and not text_without_picids has_text = bool(text_without_picids) - + # 提取图片部分(转换为[图片:描述]格式) pic_part = "" if picid_matches: @@ -411,7 +409,7 @@ class PrivateReplyer: else: pic_descriptions.append(f"[图片:{description}]") pic_part = "".join(pic_descriptions) - + return has_only_pics, has_text, pic_part, text_without_picids async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: @@ -517,25 +515,27 @@ class PrivateReplyer: # 获取基础personality prompt_personality = global_config.personality.personality - + # 检查是否需要随机替换为状态 - if (global_config.personality.states and - global_config.personality.state_probability > 0 and - random.random() < global_config.personality.state_probability): + if ( + global_config.personality.states + and global_config.personality.state_probability > 0 + and random.random() < global_config.personality.state_probability + ): # 随机选择一个状态替换personality selected_state = random.choice(global_config.personality.states) prompt_personality = selected_state - + prompt_personality = f"{prompt_personality};" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]: """ 解析聊天prompt配置字符串并生成对应的 
chat_id 和 prompt内容 - + Args: chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串 - + Returns: tuple: (chat_id, prompt_content),如果解析失败则返回 None """ @@ -571,10 +571,10 @@ class PrivateReplyer: def get_chat_prompt_for_chat(self, chat_id: str) -> str: """ 根据聊天流ID获取匹配的额外prompt(仅匹配private类型) - + Args: chat_id: 聊天流ID(哈希值) - + Returns: str: 匹配的额外prompt内容,如果没有匹配则返回空字符串 """ @@ -584,21 +584,21 @@ class PrivateReplyer: for chat_prompt_str in global_config.experimental.chat_prompts: if not isinstance(chat_prompt_str, str): continue - + # 解析配置字符串,检查类型是否为private parts = chat_prompt_str.split(":", 3) if len(parts) != 4: continue - + stream_type = parts[2] # 只匹配private类型 if stream_type != "private": continue - + result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str) if result is None: continue - + config_chat_id, prompt_content = result if config_chat_id == chat_id: logger.debug(f"匹配到私聊prompt配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...") @@ -647,13 +647,11 @@ class PrivateReplyer: sender = person_name target = reply_message.processed_plain_text - - target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) - + # 在picid替换之前分析内容类型(防止prompt注入) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) - + # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) @@ -662,7 +660,7 @@ class PrivateReplyer: timestamp=time.time(), limit=global_config.chat.max_context_size, ) - + dialogue_prompt = build_readable_messages( message_list_before_now_long, replace_bot_name=True, @@ -710,9 +708,7 @@ class PrivateReplyer: self._time_and_run_task( self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" ), - self._time_and_run_task( - self.build_relation_info(chat_talking_prompt_short, sender), "relation_info" - ), + self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"), self._time_and_run_task( self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" ), @@ -852,15 +848,13 @@ class PrivateReplyer: sender, target = self._parse_reply_target(reply_to) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) - + # 在picid替换之前分析内容类型(防止prompt注入) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) - + # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) - - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), @@ -900,9 +894,7 @@ class PrivateReplyer: ) elif has_text and pic_part: # 既有图片又有文字 - reply_target_block = ( - f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。" - ) + reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。" else: # 只包含文字 reply_target_block = ( @@ -919,7 +911,9 @@ class PrivateReplyer: reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。" elif has_text and pic_part: # 既有图片又有文字 - reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" + reply_target_block = ( + f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" + ) else: # 只包含文字 reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。" @@ -1010,7 +1004,7 @@ class PrivateReplyer: content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async( prompt ) - + content = content.strip() logger.info(f"使用 {model_name} 生成回复内容: {content}") @@ -1102,6 
+1096,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list: pool.pop(idx) break return selected - - - diff --git a/src/chat/replyer/prompt/replyer_prompt.py b/src/chat/replyer/prompt/replyer_prompt.py index 8b5c30f3..871f5460 100644 --- a/src/chat/replyer/prompt/replyer_prompt.py +++ b/src/chat/replyer/prompt/replyer_prompt.py @@ -1,16 +1,13 @@ - from src.chat.utils.prompt_builder import Prompt # from src.chat.memory_system.memory_activator import MemoryActivator - def init_replyer_prompt(): Prompt("正在群里聊天", "chat_target_group2") Prompt("和{sender_name}聊天", "chat_target_private2") - - + Prompt( -"""{knowledge_prompt}{tool_info_block}{extra_info_block} + """{knowledge_prompt}{tool_info_block}{extra_info_block} {expression_habits_block}{memory_retrieval} 你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片: @@ -27,10 +24,9 @@ def init_replyer_prompt(): 现在,你说:""", "replyer_prompt", ) - - + Prompt( -"""{knowledge_prompt}{tool_info_block}{extra_info_block} + """{knowledge_prompt}{tool_info_block}{extra_info_block} {expression_habits_block}{memory_retrieval} 你正在和{sender_name}聊天,这是你们之前聊的内容: @@ -46,10 +42,9 @@ def init_replyer_prompt(): {moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""", "private_replyer_prompt", ) - - + Prompt( - """{knowledge_prompt}{tool_info_block}{extra_info_block} + """{knowledge_prompt}{tool_info_block}{extra_info_block} {expression_habits_block}{memory_retrieval} 你正在和{sender_name}聊天,这是你们之前聊的内容: @@ -65,4 +60,4 @@ def init_replyer_prompt(): {moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。 """, "private_replyer_self_prompt", - ) \ No newline at end of file + ) diff --git a/src/chat/utils/chat_history_summarizer.py b/src/chat/utils/chat_history_summarizer.py index b2a319b1..6b71706d 100644 --- a/src/chat/utils/chat_history_summarizer.py +++ b/src/chat/utils/chat_history_summarizer.py @@ -2,6 +2,7 @@ 聊天内容概括器 用于累积、打包和压缩聊天记录 """ + import asyncio import json import time @@ -23,6 +24,7 @@ logger = get_logger("chat_history_summarizer") @dataclass class MessageBatch: """消息批次""" + messages: List[DatabaseMessages] start_time: float end_time: float @@ -31,11 +33,11 @@ class MessageBatch: class ChatHistorySummarizer: """聊天内容概括器""" - + def __init__(self, chat_id: str, check_interval: int = 60): """ 初始化聊天内容概括器 - + Args: chat_id: 聊天ID check_interval: 定期检查间隔(秒),默认60秒 @@ -43,24 +45,23 @@ class ChatHistorySummarizer: self.chat_id = chat_id self._chat_display_name = self._get_chat_display_name() self.log_prefix = f"[{self._chat_display_name}]" - + # 记录时间点,用于计算新消息 self.last_check_time = time.time() - + # 当前累积的消息批次 self.current_batch: Optional[MessageBatch] = None - + # LLM请求器,用于压缩聊天内容 self.summarizer_llm = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="chat_history_summarizer" + model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer" ) - + # 后台循环相关 self.check_interval = check_interval # 检查间隔(秒) self._periodic_task: Optional[asyncio.Task] = None self._running = False - + def _get_chat_display_name(self) -> str: """获取聊天显示名称""" try: @@ -76,17 +77,17 @@ class ChatHistorySummarizer: if len(self.chat_id) > 20: return f"{self.chat_id[:8]}..." 
return self.chat_id - + async def process(self, current_time: Optional[float] = None): """ 处理聊天内容概括 - + Args: current_time: 当前时间戳,如果为None则使用time.time() """ if current_time is None: current_time = time.time() - + try: logger.info( f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" @@ -101,25 +102,23 @@ class ChatHistorySummarizer: filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言 filter_command=False, ) - + if not new_messages: # 没有新消息,检查是否需要打包 if self.current_batch and self.current_batch.messages: await self._check_and_package(current_time) self.last_check_time = current_time return - + # 有新消息,更新最后检查时间 self.last_check_time = current_time - + # 如果有当前批次,添加新消息 if self.current_batch: before_count = len(self.current_batch.messages) self.current_batch.messages.extend(new_messages) self.current_batch.end_time = current_time - logger.info( - f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息" - ) + logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息") else: # 创建新批次 self.current_batch = MessageBatch( @@ -127,23 +126,22 @@ class ChatHistorySummarizer: start_time=new_messages[0].time if new_messages else current_time, end_time=current_time, ) - logger.info( - f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息" - ) - + logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息") + # 检查是否需要打包 await self._check_and_package(current_time) - + except Exception as e: logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}") import traceback + traceback.print_exc() - + async def _check_and_package(self, current_time: float): """检查是否需要打包""" if not self.current_batch or not self.current_batch.messages: return - + messages = self.current_batch.messages message_count = len(messages) last_message_time = messages[-1].time if messages else current_time @@ -153,48 +151,48 @@ class ChatHistorySummarizer: if time_since_last_message < 60: time_str = f"{time_since_last_message:.1f}秒" elif time_since_last_message < 3600: - time_str = f"{time_since_last_message/60:.1f}分钟" + time_str = f"{time_since_last_message / 60:.1f}分钟" else: - time_str = f"{time_since_last_message/3600:.1f}小时" - + time_str = f"{time_since_last_message / 3600:.1f}小时" + preparing_status = "是" if self.current_batch.is_preparing else "否" - + logger.info( f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距最后消息: {time_str} | 准备结束模式: {preparing_status}" ) - + # 检查打包条件 should_package = False - + # 条件1: 消息长度超过120,直接打包 if message_count >= 120: should_package = True logger.info(f"{self.log_prefix} 触发打包条件: 消息数量达到 {message_count} 条(阈值: 120条)") - + # 条件2: 最后一条消息的时间和当前时间差>600秒,直接打包 elif time_since_last_message > 600: should_package = True logger.info(f"{self.log_prefix} 触发打包条件: 距最后消息 {time_str}(阈值: 10分钟)") - + # 条件3: 消息长度超过100,进入准备结束模式 elif message_count > 100: if not self.current_batch.is_preparing: self.current_batch.is_preparing = True logger.info(f"{self.log_prefix} 消息数量 {message_count} 条超过阈值(100条),进入准备结束模式") - + # 在准备结束模式下,如果最后一条消息的时间和当前时间差>10秒,就打包 if time_since_last_message > 10: should_package = True logger.info(f"{self.log_prefix} 触发打包条件: 准备结束模式下,距最后消息 {time_str}(阈值: 10秒)") - + if should_package: await self._package_and_store() - + async def _package_and_store(self): """打包并存储聊天记录""" if not self.current_batch or not self.current_batch.messages: return - + messages = self.current_batch.messages start_time = self.current_batch.start_time end_time = self.current_batch.end_time @@ -202,12 +200,12 @@ class ChatHistorySummarizer: logger.info( f"{self.log_prefix} 
开始打包批次 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" ) - + # 检查是否有bot发言 # 第一条消息前推600s到最后一条消息的时间内 check_start_time = max(start_time - 600, 0) check_end_time = end_time - + # 使用包含边界的时间范围查询 bot_messages = message_api.get_messages_by_time_in_chat_inclusive( chat_id=self.chat_id, @@ -218,7 +216,7 @@ class ChatHistorySummarizer: filter_mai=False, filter_command=False, ) - + # 检查是否有bot的发言 has_bot_message = False bot_user_id = str(global_config.bot.qq_account) @@ -226,14 +224,14 @@ class ChatHistorySummarizer: if msg.user_info.user_id == bot_user_id: has_bot_message = True break - + if not has_bot_message: logger.info( f"{self.log_prefix} 批次内无Bot发言,丢弃批次 | 检查时间范围: {check_start_time:.2f} - {check_end_time:.2f}" ) self.current_batch = None return - + # 有bot发言,进行压缩和存储 try: # 构建对话原文 @@ -245,39 +243,36 @@ class ChatHistorySummarizer: truncate=False, show_actions=False, ) - + # 获取参与的所有人的昵称 participants_set: Set[str] = set() for msg in messages: # 使用 msg.user_platform(扁平化字段)或 msg.user_info.platform - platform = getattr(msg, 'user_platform', None) or (msg.user_info.platform if msg.user_info else None) or msg.chat_info.platform - person = Person( - platform=platform, - user_id=msg.user_info.user_id + platform = ( + getattr(msg, "user_platform", None) + or (msg.user_info.platform if msg.user_info else None) + or msg.chat_info.platform ) + person = Person(platform=platform, user_id=msg.user_info.user_id) person_name = person.person_name if person_name: participants_set.add(person_name) participants = list(participants_set) - logger.info( - f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}" - ) - + logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}") + # 使用LLM压缩聊天内容 success, theme, keywords, summary = await self._compress_with_llm(original_text) - + if not success: - logger.warning( - f"{self.log_prefix} LLM压缩失败,不存储到数据库 | 消息数: {len(messages)}" - ) + logger.warning(f"{self.log_prefix} LLM压缩失败,不存储到数据库 | 消息数: {len(messages)}") # 清空当前批次,避免重复处理 self.current_batch = None return - + logger.info( f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)} 字" ) - + # 存储到数据库 await self._store_to_database( start_time=start_time, @@ -288,23 +283,24 @@ class ChatHistorySummarizer: keywords=keywords, summary=summary, ) - + logger.info(f"{self.log_prefix} 成功打包并存储聊天记录 | 消息数: {len(messages)} | 主题: {theme}") - + # 清空当前批次 self.current_batch = None - + except Exception as e: logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}") import traceback + traceback.print_exc() # 出错时也清空批次,避免重复处理 self.current_batch = None - + async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]: """ 使用LLM压缩聊天内容 - + Returns: tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括) """ @@ -325,37 +321,37 @@ class ChatHistorySummarizer: {original_text} 请直接返回JSON,不要包含其他内容。""" - + try: response, _ = await self.summarizer_llm.generate_response_async( prompt=prompt, temperature=0.3, max_tokens=500, ) - + # 解析JSON响应 import re - + # 移除可能的markdown代码块标记 json_str = response.strip() - json_str = re.sub(r'^```json\s*', '', json_str, flags=re.MULTILINE) - json_str = re.sub(r'^```\s*', '', json_str, flags=re.MULTILINE) + json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) + json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) json_str = json_str.strip() - + # 尝试找到JSON对象的开始和结束位置 # 查找第一个 { 和最后一个匹配的 } - start_idx = json_str.find('{') + start_idx = json_str.find("{") if start_idx == -1: raise 
ValueError("未找到JSON对象开始标记") - + # 从后往前查找最后一个 } - end_idx = json_str.rfind('}') + end_idx = json_str.rfind("}") if end_idx == -1 or end_idx <= start_idx: raise ValueError("未找到JSON对象结束标记") - + # 提取JSON字符串 - json_str = json_str[start_idx:end_idx + 1] - + json_str = json_str[start_idx : end_idx + 1] + # 尝试解析JSON try: result = json.loads(json_str) @@ -372,7 +368,7 @@ class ChatHistorySummarizer: if escape_next: fixed_chars.append(char) escape_next = False - elif char == '\\': + elif char == "\\": fixed_chars.append(char) escape_next = True elif char == '"' and not escape_next: @@ -384,27 +380,27 @@ class ChatHistorySummarizer: else: fixed_chars.append(char) i += 1 - - json_str = ''.join(fixed_chars) + + json_str = "".join(fixed_chars) # 再次尝试解析 result = json.loads(json_str) - + theme = result.get("theme", "未命名对话") keywords = result.get("keywords", []) summary = result.get("summary", "无概括") - + # 确保keywords是列表 if isinstance(keywords, str): keywords = [keywords] - + return True, theme, keywords, summary - + except Exception as e: logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") # 返回失败标志和默认值 return False, "未命名对话", [], "压缩失败,无法生成概括" - + async def _store_to_database( self, start_time: float, @@ -419,7 +415,7 @@ class ChatHistorySummarizer: try: from src.common.database.database_model import ChatHistory from src.plugin_system.apis import database_api - + # 准备数据 data = { "chat_id": self.chat_id, @@ -432,7 +428,7 @@ class ChatHistorySummarizer: "summary": summary, "count": 0, } - + # 使用db_save存储(使用start_time和chat_id作为唯一标识) # 由于可能有多条记录,我们使用组合键,但peewee不支持,所以使用start_time作为唯一标识 # 但为了避免冲突,我们使用组合键:chat_id + start_time @@ -441,28 +437,29 @@ class ChatHistorySummarizer: ChatHistory, data=data, ) - + if saved_record: logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库") else: logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败") - + except Exception as e: logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}") import traceback + traceback.print_exc() raise - + async def start(self): """启动后台定期检查循环""" if self._running: logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动") return - + self._running = True self._periodic_task = asyncio.create_task(self._periodic_check_loop()) logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒") - + async def stop(self): """停止后台定期检查循环""" self._running = False @@ -474,14 +471,14 @@ class ChatHistorySummarizer: pass self._periodic_task = None logger.info(f"{self.log_prefix} 已停止后台定期检查循环") - + async def _periodic_check_loop(self): """后台定期检查循环""" try: while self._running: # 执行一次检查 await self.process() - + # 等待指定间隔后再次检查 await asyncio.sleep(self.check_interval) except asyncio.CancelledError: @@ -490,6 +487,6 @@ class ChatHistorySummarizer: except Exception as e: logger.error(f"{self.log_prefix} 后台检查循环出错: {e}") import traceback + traceback.print_exc() self._running = False - diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 94288900..4bd7850f 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -2,7 +2,7 @@ import time import random import re -from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable +from typing import List, Dict, Any, Tuple, Optional, Callable from rich.traceback import install from src.config.config import global_config @@ -568,7 +568,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re output_lines = [] current_time = 
time.time() - for action in actions: action_time = action.time or current_time action_name = action.action_name or "未知动作" @@ -595,7 +594,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}”" output_lines.append(line) - return "\n".join(output_lines) @@ -936,7 +934,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str: return formatted_string - async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: """ 从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。 diff --git a/src/chat/utils/memory_forget_task.py b/src/chat/utils/memory_forget_task.py index 11d49171..15a912b4 100644 --- a/src/chat/utils/memory_forget_task.py +++ b/src/chat/utils/memory_forget_task.py @@ -2,6 +2,7 @@ 记忆遗忘任务 每5分钟进行一次遗忘检查,根据不同的遗忘阶段删除记忆 """ + import time import random from typing import List @@ -15,27 +16,27 @@ logger = get_logger("memory_forget_task") class MemoryForgetTask(AsyncTask): """记忆遗忘任务,每5分钟执行一次""" - + def __init__(self): # 每5分钟执行一次(300秒) super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300) - + async def run(self): """执行遗忘检查""" try: current_time = time.time() logger.info("[记忆遗忘] 开始遗忘检查...") - + # 执行4个阶段的遗忘检查 await self._forget_stage_1(current_time) await self._forget_stage_2(current_time) await self._forget_stage_3(current_time) await self._forget_stage_4(current_time) - + logger.info("[记忆遗忘] 遗忘检查完成") except Exception as e: logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True) - + async def _forget_stage_1(self, current_time: float): """ 第一次遗忘检查: @@ -45,38 +46,34 @@ class MemoryForgetTask(AsyncTask): try: # 30分钟 = 1800秒 time_threshold = current_time - 1800 - + # 查询符合条件的记忆:forget_times=0 且 end_time < time_threshold candidates = list( - ChatHistory.select() - .where( - (ChatHistory.forget_times == 0) & - (ChatHistory.end_time < time_threshold) - ) + ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold)) ) - + if not candidates: logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆") return - + logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆") - + # 按count排序 candidates.sort(key=lambda x: x.count, reverse=True) - + # 计算要删除的数量(最高25%和最低25%) total_count = len(candidates) delete_count = int(total_count * 0.25) # 25% - + if delete_count == 0: logger.debug("[记忆遗忘-阶段1] 删除数量为0,跳过") return - + # 选择要删除的记录(处理count相同的情况:随机选择) to_delete = [] to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - + # 去重(避免重复删除),使用id去重 seen_ids = set() unique_to_delete = [] @@ -85,7 +82,7 @@ class MemoryForgetTask(AsyncTask): seen_ids.add(record.id) unique_to_delete.append(record) to_delete = unique_to_delete - + # 删除记录并更新forget_times deleted_count = 0 for record in to_delete: @@ -94,22 +91,22 @@ class MemoryForgetTask(AsyncTask): deleted_count += 1 except Exception as e: logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}") - + # 更新剩余记录的forget_times为1 to_delete_ids = {r.id for r in to_delete} remaining = [r for r in candidates if r.id not in to_delete_ids] if remaining: # 批量更新 ids_to_update = [r.id for r in remaining] - ChatHistory.update(forget_times=1).where( - ChatHistory.id.in_(ids_to_update) - ).execute() - - logger.info(f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1") - + ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute() + + logger.info( + f"[记忆遗忘-阶段1] 完成:删除了 
{deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1" + ) + except Exception as e: logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True) - + async def _forget_stage_2(self, current_time: float): """ 第二次遗忘检查: @@ -119,41 +116,37 @@ class MemoryForgetTask(AsyncTask): try: # 8小时 = 28800秒 time_threshold = current_time - 28800 - + # 查询符合条件的记忆:forget_times=1 且 end_time < time_threshold candidates = list( - ChatHistory.select() - .where( - (ChatHistory.forget_times == 1) & - (ChatHistory.end_time < time_threshold) - ) + ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold)) ) - + if not candidates: logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆") return - + logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆") - + # 按count排序 candidates.sort(key=lambda x: x.count, reverse=True) - + # 计算要删除的数量(最高7%和最低7%) total_count = len(candidates) delete_count = int(total_count * 0.07) # 7% - + if delete_count == 0: logger.debug("[记忆遗忘-阶段2] 删除数量为0,跳过") return - + # 选择要删除的记录 to_delete = [] to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - + # 去重 to_delete = list(set(to_delete)) - + # 删除记录 deleted_count = 0 for record in to_delete: @@ -162,21 +155,21 @@ class MemoryForgetTask(AsyncTask): deleted_count += 1 except Exception as e: logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}") - + # 更新剩余记录的forget_times为2 to_delete_ids = {r.id for r in to_delete} remaining = [r for r in candidates if r.id not in to_delete_ids] if remaining: ids_to_update = [r.id for r in remaining] - ChatHistory.update(forget_times=2).where( - ChatHistory.id.in_(ids_to_update) - ).execute() - - logger.info(f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2") - + ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute() + + logger.info( + f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2" + ) + except Exception as e: logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True) - + async def _forget_stage_3(self, current_time: float): """ 第三次遗忘检查: @@ -186,41 +179,37 @@ class MemoryForgetTask(AsyncTask): try: # 48小时 = 172800秒 time_threshold = current_time - 172800 - + # 查询符合条件的记忆:forget_times=2 且 end_time < time_threshold candidates = list( - ChatHistory.select() - .where( - (ChatHistory.forget_times == 2) & - (ChatHistory.end_time < time_threshold) - ) + ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold)) ) - + if not candidates: logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆") return - + logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆") - + # 按count排序 candidates.sort(key=lambda x: x.count, reverse=True) - + # 计算要删除的数量(最高5%和最低5%) total_count = len(candidates) delete_count = int(total_count * 0.05) # 5% - + if delete_count == 0: logger.debug("[记忆遗忘-阶段3] 删除数量为0,跳过") return - + # 选择要删除的记录 to_delete = [] to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - + # 去重 to_delete = list(set(to_delete)) - + # 删除记录 deleted_count = 0 for record in to_delete: @@ -229,21 +218,21 @@ class MemoryForgetTask(AsyncTask): deleted_count += 1 except Exception as e: logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}") - + # 更新剩余记录的forget_times为3 to_delete_ids = {r.id for r in to_delete} remaining = [r for r in candidates if r.id not in to_delete_ids] if remaining: ids_to_update = [r.id for r in remaining] 
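All four forgetting stages above apply the same selection rule: sort the candidate ChatHistory rows by their retrieval count and drop a fixed fraction from both the most-retrieved and least-retrieved ends (25%, 7%, 5% and 2% for stages 1 to 4), with ties broken randomly by _handle_same_count_random. A compact sketch of that trimming rule, ignoring tie handling, is given below; the helper name and plain-list interface are illustrative only.

def trim_by_count_sketch(candidates, fraction):
    """Return the records to forget: the top and bottom `fraction` of candidates by count."""
    ordered = sorted(candidates, key=lambda r: r.count, reverse=True)
    k = int(len(ordered) * fraction)  # e.g. 0.25 for stage 1, 0.07 for stage 2
    if k == 0:
        return []
    # For very small lists the two slices can overlap; the real task deduplicates afterwards.
    return ordered[:k] + ordered[-k:]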
- ChatHistory.update(forget_times=3).where( - ChatHistory.id.in_(ids_to_update) - ).execute() - - logger.info(f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3") - + ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute() + + logger.info( + f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3" + ) + except Exception as e: logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True) - + async def _forget_stage_4(self, current_time: float): """ 第四次遗忘检查: @@ -253,41 +242,37 @@ class MemoryForgetTask(AsyncTask): try: # 7天 = 604800秒 time_threshold = current_time - 604800 - + # 查询符合条件的记忆:forget_times=3 且 end_time < time_threshold candidates = list( - ChatHistory.select() - .where( - (ChatHistory.forget_times == 3) & - (ChatHistory.end_time < time_threshold) - ) + ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold)) ) - + if not candidates: logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆") return - + logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆") - + # 按count排序 candidates.sort(key=lambda x: x.count, reverse=True) - + # 计算要删除的数量(最高2%和最低2%) total_count = len(candidates) delete_count = int(total_count * 0.02) # 2% - + if delete_count == 0: logger.debug("[记忆遗忘-阶段4] 删除数量为0,跳过") return - + # 选择要删除的记录 to_delete = [] to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - + # 去重 to_delete = list(set(to_delete)) - + # 删除记录 deleted_count = 0 for record in to_delete: @@ -296,38 +281,40 @@ class MemoryForgetTask(AsyncTask): deleted_count += 1 except Exception as e: logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}") - + # 更新剩余记录的forget_times为4 to_delete_ids = {r.id for r in to_delete} remaining = [r for r in candidates if r.id not in to_delete_ids] if remaining: ids_to_update = [r.id for r in remaining] - ChatHistory.update(forget_times=4).where( - ChatHistory.id.in_(ids_to_update) - ).execute() - - logger.info(f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4") - + ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute() + + logger.info( + f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4" + ) + except Exception as e: logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True) - - def _handle_same_count_random(self, candidates: List[ChatHistory], delete_count: int, mode: str) -> List[ChatHistory]: + + def _handle_same_count_random( + self, candidates: List[ChatHistory], delete_count: int, mode: str + ) -> List[ChatHistory]: """ 处理count相同的情况,随机选择要删除的记录 - + Args: candidates: 候选记录列表(已按count排序) delete_count: 要删除的数量 mode: "high" 表示选择最高count的记录,"low" 表示选择最低count的记录 - + Returns: 要删除的记录列表 """ if not candidates or delete_count == 0: return [] - + to_delete = [] - + if mode == "high": # 从最高count开始选择 start_idx = 0 @@ -339,7 +326,7 @@ class MemoryForgetTask(AsyncTask): while idx < len(candidates) and candidates[idx].count == current_count: same_count_records.append(candidates[idx]) idx += 1 - + # 如果相同count的记录数量 <= 还需要删除的数量,全部选择 needed = delete_count - len(to_delete) if len(same_count_records) <= needed: @@ -347,9 +334,9 @@ class MemoryForgetTask(AsyncTask): else: # 随机选择需要的数量 to_delete.extend(random.sample(same_count_records, needed)) - + start_idx = idx - + else: # mode == "low" # 从最低count开始选择 start_idx = len(candidates) - 1 @@ -361,7 +348,7 @@ class MemoryForgetTask(AsyncTask): while idx >= 0 and candidates[idx].count 
== current_count: same_count_records.append(candidates[idx]) idx -= 1 - + # 如果相同count的记录数量 <= 还需要删除的数量,全部选择 needed = delete_count - len(to_delete) if len(same_count_records) <= needed: @@ -369,8 +356,7 @@ class MemoryForgetTask(AsyncTask): else: # 随机选择需要的数量 to_delete.extend(random.sample(same_count_records, needed)) - - start_idx = idx - - return to_delete + start_idx = idx + + return to_delete diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index bcd0a1f8..9b5497e9 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -153,7 +153,7 @@ def _format_large_number(num: float | int, html: bool = False) -> str: else: number_part = f"{value:.1f}" k_suffix = "K" - + if html: # HTML输出:K着色为主题色并加粗大写 return f"{number_part}K" @@ -502,9 +502,13 @@ class StatisticOutputTask(AsyncTask): } for period_key, _ in collect_period } - + # 获取bot的QQ账号 - bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else "" + bot_qq_account = ( + str(global_config.bot.qq_account) + if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account") + else "" + ) query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore @@ -547,7 +551,7 @@ class StatisticOutputTask(AsyncTask): is_bot_reply = False if bot_qq_account and message.user_id == bot_qq_account: is_bot_reply = True - + for idx, (_, period_start_dt) in enumerate(collect_period): if message_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: @@ -588,7 +592,9 @@ class StatisticOutputTask(AsyncTask): continue last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳 - self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段 + self.stat_period = [ + item for item in self.stat_period if item[0] != "all_time" + ] # 删除"所有时间"的统计时段 self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的")) except Exception as e: logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}") @@ -640,12 +646,12 @@ class StatisticOutputTask(AsyncTask): # 更新上次完整统计数据的时间戳 # 将所有defaultdict转换为普通dict以避免类型冲突 clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"]) - + # 将 name_mapping 中的元组转换为列表,因为JSON不支持元组 json_safe_name_mapping = {} for chat_id, (chat_name, timestamp) in self.name_mapping.items(): json_safe_name_mapping[chat_id] = [chat_name, timestamp] - + local_storage["last_full_statistics"] = { "name_mapping": json_safe_name_mapping, "stat_data": clean_stat_data, @@ -682,24 +688,28 @@ class StatisticOutputTask(AsyncTask): """ # 计算总token数(从所有模型的token数中累加) total_tokens = sum(stats[TOTAL_TOK_BY_MODEL].values()) if stats[TOTAL_TOK_BY_MODEL] else 0 - + # 计算花费/消息数量指标(每100条) cost_per_100_messages = (stats[TOTAL_COST] / stats[TOTAL_MSG_CNT] * 100) if stats[TOTAL_MSG_CNT] > 0 else 0.0 - + # 计算花费/时间指标(花费/小时) online_hours = stats[ONLINE_TIME] / 3600.0 if stats[ONLINE_TIME] > 0 else 0.0 cost_per_hour = stats[TOTAL_COST] / online_hours if online_hours > 0 else 0.0 - + # 计算token/时间指标(token/小时) tokens_per_hour = (total_tokens / online_hours) if online_hours > 0 else 0.0 - + # 计算花费/回复数量指标(每100条) total_replies = stats.get(TOTAL_REPLY_CNT, 0) cost_per_100_replies = (stats[TOTAL_COST] / total_replies * 100) if total_replies > 0 else 0.0 - + # 计算花费/消息数量(排除自己回复)指标(每100条) 
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies - cost_per_100_messages_excluding_replies = (stats[TOTAL_COST] / total_messages_excluding_replies * 100) if total_messages_excluding_replies > 0 else 0.0 + cost_per_100_messages_excluding_replies = ( + (stats[TOTAL_COST] / total_messages_excluding_replies * 100) + if total_messages_excluding_replies > 0 + else 0.0 + ) output = [ f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}", @@ -709,7 +719,9 @@ class StatisticOutputTask(AsyncTask): f"总Token数: {_format_large_number(total_tokens)}", f"总花费: {stats[TOTAL_COST]:.2f}¥", f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A", - f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条" if total_messages_excluding_replies > 0 else "花费/消息数量(排除回复): N/A", + f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条" + if total_messages_excluding_replies > 0 + else "花费/消息数量(排除回复): N/A", f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A", f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A", f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A", @@ -745,7 +757,16 @@ class StatisticOutputTask(AsyncTask): formatted_out_tokens = _format_large_number(out_tokens) formatted_tokens = _format_large_number(tokens) output.append( - data_fmt.format(name, formatted_count, formatted_in_tokens, formatted_out_tokens, formatted_tokens, cost, avg_time_cost, std_time_cost) + data_fmt.format( + name, + formatted_count, + formatted_in_tokens, + formatted_out_tokens, + formatted_tokens, + cost, + avg_time_cost, + std_time_cost, + ) ) output.append("") @@ -891,8 +912,12 @@ class StatisticOutputTask(AsyncTask): except (IndexError, TypeError) as e: logger.warning(f"生成HTML聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}") chat_rows.append(f"未知聊天{_format_large_number(count, html=True)}") - - chat_rows_html = "\n".join(chat_rows) if chat_rows else "暂无数据" + + chat_rows_html = ( + "\n".join(chat_rows) + if chat_rows + else "暂无数据" + ) # 生成HTML return f"""
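The additions above derive ratio metrics from the aggregated statistics, guarding every division against an empty denominator. The standalone sketch below restates those formulas over plain arguments; the function and argument names are illustrative, not the module's actual constants.

def derive_cost_metrics_sketch(total_cost, total_tokens, total_messages, total_replies, online_seconds):
    """Recompute the per-100-message, per-reply and per-hour ratios shown in the report."""
    online_hours = online_seconds / 3600.0 if online_seconds > 0 else 0.0
    received = total_messages - total_replies  # messages excluding the bot's own replies
    return {
        "cost_per_100_messages": total_cost / total_messages * 100 if total_messages > 0 else 0.0,
        "cost_per_100_received": total_cost / received * 100 if received > 0 else 0.0,
        "cost_per_100_replies": total_cost / total_replies * 100 if total_replies > 0 else 0.0,
        "cost_per_hour": total_cost / online_hours if online_hours > 0 else 0.0,
        "tokens_per_hour": total_tokens / online_hours if online_hours > 0 else 0.0,
    }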
@@ -1197,7 +1222,7 @@ class StatisticOutputTask(AsyncTask): # 添加图表内容 chart_data = self._generate_chart_data(stat) tab_content_list.append(self._generate_chart_tab(chart_data)) - + # 添加指标趋势图表 metrics_data = self._generate_metrics_data(now) tab_content_list.append(self._generate_metrics_tab(metrics_data)) @@ -1772,121 +1797,125 @@ class StatisticOutputTask(AsyncTask): def _generate_metrics_data(self, now: datetime) -> dict: """生成指标趋势数据""" metrics_data = {} - + # 24小时尺度:1小时为单位 metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1) - + # 7天尺度:1天为单位 - metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24*7, interval_hours=24) - + metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24 * 7, interval_hours=24) + # 30天尺度:1天为单位 - metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24*30, interval_hours=24) - + metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24 * 30, interval_hours=24) + return metrics_data - + def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict: """收集指定时间范围内每个间隔的指标数据""" start_time = now - timedelta(hours=hours) time_points = [] current_time = start_time - + # 生成时间点 while current_time <= now: time_points.append(current_time) current_time += timedelta(hours=interval_hours) - + # 初始化数据结构 cost_per_100_messages = [0.0] * len(time_points) # 花费/消息数量(每100条) cost_per_hour = [0.0] * len(time_points) # 花费/时间(每小时) tokens_per_hour = [0.0] * len(time_points) # Token/时间(每小时) cost_per_100_replies = [0.0] * len(time_points) # 花费/回复数量(每100条) - + # 每个时间点的累计数据 total_costs = [0.0] * len(time_points) total_tokens = [0] * len(time_points) total_messages = [0] * len(time_points) total_replies = [0] * len(time_points) total_online_hours = [0.0] * len(time_points) - + # 获取bot的QQ账号 - bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else "" - + bot_qq_account = ( + str(global_config.bot.qq_account) + if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account") + else "" + ) + interval_seconds = interval_hours * 3600 - + # 查询LLM使用记录 query_start_time = start_time for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore record_time = record.timestamp - + # 找到对应的时间间隔索引 time_diff = (record_time - start_time).total_seconds() interval_index = int(time_diff // interval_seconds) - + if 0 <= interval_index < len(time_points): cost = record.cost or 0.0 prompt_tokens = record.prompt_tokens or 0 completion_tokens = record.completion_tokens or 0 total_token = prompt_tokens + completion_tokens - + total_costs[interval_index] += cost total_tokens[interval_index] += total_token - + # 查询消息记录 query_start_timestamp = start_time.timestamp() for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore message_time_ts = message.time - + time_diff = message_time_ts - query_start_timestamp interval_index = int(time_diff // interval_seconds) - + if 0 <= interval_index < len(time_points): total_messages[interval_index] += 1 # 检查是否是bot发送的消息(回复) if bot_qq_account and message.user_id == bot_qq_account: total_replies[interval_index] += 1 - + # 查询在线时间记录 for record in OnlineTime.select().where(OnlineTime.end_timestamp >= start_time): # type: ignore record_start = record.start_timestamp record_end = record.end_timestamp - + # 找到记录覆盖的所有时间间隔 for idx, time_point in enumerate(time_points): interval_start = time_point interval_end = 
time_point + timedelta(hours=interval_hours) - + # 计算重叠部分 overlap_start = max(record_start, interval_start) overlap_end = min(record_end, interval_end) - + if overlap_end > overlap_start: overlap_hours = (overlap_end - overlap_start).total_seconds() / 3600.0 total_online_hours[idx] += overlap_hours - + # 计算指标 for idx in range(len(time_points)): # 花费/消息数量(每100条) if total_messages[idx] > 0: - cost_per_100_messages[idx] = (total_costs[idx] / total_messages[idx] * 100) - + cost_per_100_messages[idx] = total_costs[idx] / total_messages[idx] * 100 + # 花费/时间(每小时) if total_online_hours[idx] > 0: - cost_per_hour[idx] = (total_costs[idx] / total_online_hours[idx]) - + cost_per_hour[idx] = total_costs[idx] / total_online_hours[idx] + # Token/时间(每小时) if total_online_hours[idx] > 0: - tokens_per_hour[idx] = (total_tokens[idx] / total_online_hours[idx]) - + tokens_per_hour[idx] = total_tokens[idx] / total_online_hours[idx] + # 花费/回复数量(每100条) if total_replies[idx] > 0: - cost_per_100_replies[idx] = (total_costs[idx] / total_replies[idx] * 100) - + cost_per_100_replies[idx] = total_costs[idx] / total_replies[idx] * 100 + # 生成时间标签 if interval_hours == 1: time_labels = [t.strftime("%H:%M") for t in time_points] else: time_labels = [t.strftime("%m-%d") for t in time_points] - + return { "time_labels": time_labels, "cost_per_100_messages": cost_per_100_messages, @@ -1894,7 +1923,7 @@ class StatisticOutputTask(AsyncTask): "tokens_per_hour": tokens_per_hour, "cost_per_100_replies": cost_per_100_replies, } - + def _generate_metrics_tab(self, metrics_data: dict) -> str: """生成指标趋势图表选项卡HTML内容""" colors = { @@ -1903,7 +1932,7 @@ class StatisticOutputTask(AsyncTask): "tokens_per_hour": "#c7bbff", "cost_per_100_replies": "#d9ceff", } - + return f"""

指标趋势图表
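_collect_metrics_interval_data above buckets LLM usage records, messages and online-time spans into fixed-width intervals before computing the per-bucket ratios charted in this tab. The sketch below isolates its two core calculations: the bucket index for a timestamp, and the hours of overlap an online-time record contributes to one bucket; the names are illustrative.

from datetime import datetime, timedelta

def bucket_index_sketch(event_time: datetime, window_start: datetime, interval_hours: int) -> int:
    """Map a record to its interval: floor of the offset divided by the bucket width."""
    offset_seconds = (event_time - window_start).total_seconds()
    return int(offset_seconds // (interval_hours * 3600))

def overlap_hours_sketch(rec_start: datetime, rec_end: datetime,
                         bucket_start: datetime, interval_hours: int) -> float:
    """Hours an online-time record overlaps a single bucket (0.0 if disjoint)."""
    bucket_end = bucket_start + timedelta(hours=interval_hours)
    start, end = max(rec_start, bucket_start), min(rec_end, bucket_end)
    return max((end - start).total_seconds(), 0.0) / 3600.0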

diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index f9f551ce..0464b734 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -4,14 +4,11 @@ import time import jieba import json import ast -import numpy as np -from collections import Counter from typing import Optional, Tuple, List, TYPE_CHECKING from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages -from src.common.message_repository import find_messages, count_messages from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager @@ -32,10 +29,10 @@ def is_english_letter(char: str) -> bool: def parse_platform_accounts(platforms: list[str]) -> dict[str, str]: """解析 platforms 列表,返回平台到账号的映射 - + Args: platforms: 格式为 ["platform:account"] 的列表,如 ["tg:123456789", "wx:wxid123"] - + Returns: 字典,键为平台名,值为账号 """ @@ -49,12 +46,12 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]: def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str: """根据当前平台获取对应的账号 - + Args: platform: 当前消息的平台 platform_accounts: 从 platforms 列表解析的平台账号映射 qq_account: QQ 账号(兼容旧配置) - + Returns: 当前平台对应的账号 """ @@ -72,12 +69,12 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float """检查消息是否提到了机器人(统一多平台实现)""" text = message.processed_plain_text or "" platform = getattr(message.message_info, "platform", "") or "" - + # 获取各平台账号 platforms_list = getattr(global_config.bot, "platforms", []) or [] platform_accounts = parse_platform_accounts(platforms_list) qq_account = str(getattr(global_config.bot, "qq_account", "") or "") - + # 获取当前平台对应的账号 current_account = get_current_platform_account(platform, platform_accounts, qq_account) @@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float elif current_account: if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text): is_mentioned = True - elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text): + elif re.search( + rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text + ): is_mentioned = True # 6) 名称/别名 提及(去除 @/回复标记后再匹配) @@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]] return embedding - def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: """将文本分割成句子,并根据概率合并 1. 识别分割点(, , 。 ; 空格),但如果分割点左右都是英文字母则不分割。 @@ -227,7 +225,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: prev_char = text[i - 1] next_char = text[i + 1] # 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则 - if char == ' ': + if char == " ": prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char) next_is_alnum = next_char.isdigit() or is_english_letter(next_char) if prev_is_alnum and next_is_alnum: @@ -340,7 +338,7 @@ def _get_random_default_reply() -> str: "不知道", "不晓得", "懒得说", - "()" + "()", ] return random.choice(default_replies) @@ -469,7 +467,6 @@ def calculate_typing_time( return total_time # 加上回车时间 - def truncate_message(message: str, max_length=20) -> str: """截断消息,使其不超过指定长度""" return f"{message[:max_length]}..." 
if len(message) > max_length else message @@ -546,7 +543,6 @@ def get_western_ratio(paragraph): return western_count / len(alnum_chars) - def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str: # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch """将时间戳转换为人类可读的时间格式 diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 4ce64eca..f6012f09 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -103,14 +103,16 @@ class ImageManager: invalid_values = ["", "None"] # 清理 Images 表 - deleted_images = Images.delete().where( - (Images.description >> None) | (Images.description << invalid_values) - ).execute() + deleted_images = ( + Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute() + ) # 清理 ImageDescriptions 表 - deleted_descriptions = ImageDescriptions.delete().where( - (ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values) - ).execute() + deleted_descriptions = ( + ImageDescriptions.delete() + .where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)) + .execute() + ) if deleted_images or deleted_descriptions: logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条") diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 8b2e94c3..b981bd33 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -220,7 +220,7 @@ class DatabaseActionRecords(BaseDataModel): chat_id: str, chat_info_stream_id: str, chat_info_platform: str, - action_reasoning:str + action_reasoning: str, ): self.action_id = action_id self.time = time @@ -235,4 +235,4 @@ class DatabaseActionRecords(BaseDataModel): self.chat_id = chat_id self.chat_info_stream_id = chat_info_stream_id self.chat_info_platform = chat_info_platform - self.action_reasoning = action_reasoning \ No newline at end of file + self.action_reasoning = action_reasoning diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 424ec125..73677962 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -317,10 +317,12 @@ class Expression(BaseModel): class Meta: table_name = "expression" + class Jargon(BaseModel): """ 用于存储俚语的模型 """ + content = TextField() raw_content = TextField(null=True) type = TextField(null=True) @@ -332,14 +334,16 @@ class Jargon(BaseModel): is_jargon = BooleanField(null=True) # None表示未判定,True表示是黑话,False表示不是黑话 last_inference_count = IntegerField(null=True) # 最后一次判定的count值,用于避免重启后重复判定 is_complete = BooleanField(default=False) # 是否已完成所有推断(count>=100后不再推断) - + class Meta: table_name = "jargon" + class ChatHistory(BaseModel): """ 用于存储聊天历史概括的模型 """ + chat_id = TextField(index=True) # 聊天ID start_time = DoubleField() # 起始时间 end_time = DoubleField() # 结束时间 @@ -350,7 +354,7 @@ class ChatHistory(BaseModel): summary = TextField() # 概括:对这段话的平文本概括 count = IntegerField(default=0) # 被检索次数 forget_times = IntegerField(default=0) # 被遗忘检查的次数 - + class Meta: table_name = "chat_history" @@ -359,6 +363,7 @@ class ThinkingBack(BaseModel): """ 用于存储记忆检索思考过程的模型 """ + chat_id = TextField(index=True) # 聊天ID question = TextField() # 提出的问题 context = TextField(null=True) # 上下文信息 @@ -367,10 +372,11 @@ class ThinkingBack(BaseModel): thinking_steps = TextField(null=True) # 思考步骤(JSON格式) create_time = DoubleField() # 创建时间 
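A side note on the cleanup query in utils_image.py above: peewee overloads `>>` as IS (so `field >> None` means IS NULL) and `<<` as IN, which is why the reformatted expressions read the way they do. The sketch below spells the same condition out with the explicit methods; the import path comes from this patch, while the variable names are illustrative.

from src.common.database.database_model import Images

invalid_values = ["", "None"]
deleted_images = (
    Images.delete()
    .where(Images.description.is_null(True) | Images.description.in_(invalid_values))
    .execute()  # returns the number of rows removed
)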
update_time = DoubleField() # 更新时间 - + class Meta: table_name = "thinking_back" + MODELS = [ ChatStreams, LLMUsage, @@ -387,6 +393,7 @@ MODELS = [ ThinkingBack, ] + def create_tables(): """ 创建所有在模型中定义的数据库表。 diff --git a/src/config/official_configs.py b/src/config/official_configs.py index e21d8f96..bc0976e8 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -27,7 +27,7 @@ class BotConfig(ConfigBase): nickname: str """昵称""" - + platforms: list[str] = field(default_factory=lambda: []) """其他平台列表""" @@ -311,16 +311,18 @@ class MessageReceiveConfig(ConfigBase): ban_msgs_regex: set[str] = field(default_factory=lambda: set()) """过滤正则表达式列表""" + @dataclass class MemoryConfig(ConfigBase): """记忆配置类""" - + max_memory_number: int = 100 """记忆最大数量""" - + memory_build_frequency: int = 1 """记忆构建频率""" + @dataclass class ExpressionConfig(ConfigBase): """表达配置类""" @@ -494,13 +496,14 @@ class MoodConfig(ConfigBase): enable_mood: bool = True """是否启用情绪系统""" - + mood_update_threshold: float = 1 """情绪更新阈值,越高,更新越慢""" - + emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大" """情感特征,影响情绪的变化情况""" + @dataclass class VoiceConfig(ConfigBase): """语音识别配置类""" @@ -644,16 +647,16 @@ class DebugConfig(ConfigBase): show_prompt: bool = False """是否显示prompt""" - + show_replyer_prompt: bool = True """是否显示回复器prompt""" - + show_replyer_reasoning: bool = True """是否显示回复器推理""" - + show_jargon_prompt: bool = False """是否显示jargon相关提示词""" - + show_planner_prompt: bool = False """是否显示planner相关提示词""" diff --git a/src/express/express_utils.py b/src/express/express_utils.py index bf065495..c27306d1 100644 --- a/src/express/express_utils.py +++ b/src/express/express_utils.py @@ -3,31 +3,30 @@ import difflib import random from datetime import datetime from typing import Optional, List, Dict -from collections import defaultdict def filter_message_content(content: Optional[str]) -> str: """ 过滤消息内容,移除回复、@、图片等格式 - + Args: content: 原始消息内容 - + Returns: str: 过滤后的内容 """ if not content: return "" - + # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 - content = re.sub(r'\[回复.*?\],说:\s*', '', content) + content = re.sub(r"\[回复.*?\],说:\s*", "", content) # 移除@<...>格式的内容 - content = re.sub(r'@<[^>]*>', '', content) + content = re.sub(r"@<[^>]*>", "", content) # 移除[picid:...]格式的图片ID - content = re.sub(r'\[picid:[^\]]*\]', '', content) + content = re.sub(r"\[picid:[^\]]*\]", "", content) # 移除[表情包:...]格式的内容 - content = re.sub(r'\[表情包:[^\]]*\]', '', content) - + content = re.sub(r"\[表情包:[^\]]*\]", "", content) + return content.strip() @@ -35,11 +34,11 @@ def calculate_similarity(text1: str, text2: str) -> float: """ 计算两个文本的相似度,返回0-1之间的值 使用SequenceMatcher计算相似度 - + Args: text1: 第一个文本 text2: 第二个文本 - + Returns: float: 相似度值,范围0-1 """ @@ -49,10 +48,10 @@ def calculate_similarity(text1: str, text2: str) -> float: def format_create_date(timestamp: float) -> str: """ 将时间戳格式化为可读的日期字符串 - + Args: timestamp: 时间戳 - + Returns: str: 格式化后的日期字符串 """ @@ -65,11 +64,11 @@ def format_create_date(timestamp: float) -> str: def weighted_sample(population: List[Dict], k: int) -> List[Dict]: """ 随机抽样函数 - + Args: population: 总体数据列表 k: 需要抽取的数量 - + Returns: List[Dict]: 抽取的数据列表 """ diff --git a/src/express/expression_learner.py b/src/express/expression_learner.py index 4d2894fb..b4c357d9 100644 --- a/src/express/expression_learner.py +++ b/src/express/expression_learner.py @@ -1,7 +1,6 @@ import time import json import os -from datetime import datetime from typing import List, Optional, Tuple import traceback from src.common.logger import get_logger @@ -158,8 +157,6 @@ class 
ExpressionLearner: traceback.print_exc() return - - async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]: """ 学习并存储表达方式 @@ -169,7 +166,7 @@ class ExpressionLearner: if learnt_expressions is None: logger.info("没有学习到表达风格") return [] - + # 展示学到的表达方式 learnt_expressions_str = "" for ( @@ -186,7 +183,7 @@ class ExpressionLearner: # 存储到数据库 Expression 表并训练 style_learner has_new_expressions = False # 记录是否有新的表达方式 learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例 - + for ( situation, style, @@ -195,9 +192,7 @@ class ExpressionLearner: ) in learnt_expressions: # 查找是否已存在相似表达方式 query = Expression.select().where( - (Expression.chat_id == self.chat_id) - & (Expression.situation == situation) - & (Expression.style == style) + (Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style) ) if query.exists(): # 表达方式完全相同,只更新时间戳 @@ -216,39 +211,37 @@ class ExpressionLearner: up_content=up_content, ) has_new_expressions = True - + # 训练 style_learner(up_content 和 style 必定存在) try: learner.add_style(style, situation) - + # 学习映射关系 - success = style_learner_manager.learn_mapping( - self.chat_id, - up_content, - style - ) + success = style_learner_manager.learn_mapping(self.chat_id, up_content, style) if success: - logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else "")) + logger.debug( + f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + + (f" (situation: {situation})" if situation else "") + ) else: logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}") except Exception as e: logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}") - - + # 保存当前聊天室的 style_learner 模型 if has_new_expressions: try: logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...") save_success = learner.save(style_learner_manager.model_save_path) - + if save_success: logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}") else: logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}") - + except Exception as e: logger.error(f"StyleLearner 模型保存异常: {e}") - + return learnt_expressions async def match_expression_context( @@ -334,7 +327,7 @@ class ExpressionLearner: matched_expressions = [] used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引 - + logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}") logger.debug(f"match_responses 内容: {match_responses}") @@ -344,12 +337,12 @@ class ExpressionLearner: if not isinstance(match_response, dict): logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}") continue - + # 获取表达方式序号 if "expression_pair" not in match_response: logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}") continue - + pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引 # 检查索引是否有效且未被使用过 @@ -367,9 +360,7 @@ class ExpressionLearner: return matched_expressions - async def learn_expression( - self, num: int = 10 - ) -> Optional[List[Tuple[str, str, str, str]]]: + async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]: """从指定聊天流学习表达方式 Args: @@ -409,7 +400,6 @@ class ExpressionLearner: expressions: List[Tuple[str, str]] = self.parse_expression_response(response) # logger.debug(f"学习{type_str}的response: {response}") - # 对表达方式溯源 matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context( expressions, random_msg_match_str @@ -426,17 +416,17 @@ class ExpressionLearner: if similarity >= 
0.85: # 85%相似度阈值 pos = i break - + if pos is None or pos == 0: # 没有匹配到目标句或没有上一句,跳过该表达 continue - + # 检查目标句是否为空 target_content = bare_lines[pos][1] if not target_content: # 目标句为空,跳过该表达 continue - + prev_original_idx = bare_lines[pos - 1][0] up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "") if not up_content: @@ -449,7 +439,6 @@ class ExpressionLearner: return filtered_with_up - def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: """ 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 @@ -483,21 +472,21 @@ class ExpressionLearner: def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]: """ 为每条消息构建精简文本列表,保留到原消息索引的映射 - + Args: messages: 消息列表 - + Returns: List[Tuple[int, str]]: (original_index, bare_content) 元组列表 """ bare_lines: List[Tuple[int, str]] = [] - + for idx, msg in enumerate(messages): content = msg.processed_plain_text or "" content = filter_message_content(content) # 即使content为空也要记录,防止错位 bare_lines.append((idx, content)) - + return bare_lines diff --git a/src/express/expression_selector.py b/src/express/expression_selector.py index 005bcf81..0650c954 100644 --- a/src/express/expression_selector.py +++ b/src/express/expression_selector.py @@ -1,8 +1,6 @@ import json import time -import random import hashlib -import re from typing import List, Dict, Optional, Any, Tuple from json_repair import repair_json @@ -115,30 +113,31 @@ class ExpressionSelector: return group_chat_ids return [chat_id] - def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]: + def get_model_predicted_expressions( + self, chat_id: str, target_message: str, total_num: int = 10 + ) -> List[Dict[str, Any]]: """ 使用 style_learner 模型预测最合适的表达方式 - + Args: chat_id: 聊天室ID target_message: 目标消息内容 total_num: 需要预测的数量 - + Returns: List[Dict[str, Any]]: 预测的表达方式列表 """ try: # 过滤目标消息内容,移除回复、表情包等特殊格式 filtered_target_message = filter_message_content(target_message) - + logger.info(f"为{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}") - + # 支持多chat_id合并预测 related_chat_ids = self.get_related_chat_ids(chat_id) - predicted_expressions = [] - + # 为每个相关的chat_id进行预测 for related_chat_id in related_chat_ids: try: @@ -146,59 +145,65 @@ class ExpressionSelector: best_style, scores = style_learner_manager.predict_style( related_chat_id, filtered_target_message, top_k=total_num ) - + if best_style and scores: # 获取预测风格的完整信息 learner = style_learner_manager.get_learner(related_chat_id) style_id, situation = learner.get_style_info(best_style) - + if style_id and situation: # 从数据库查找对应的表达记录 expr_query = Expression.select().where( - (Expression.chat_id == related_chat_id) & - (Expression.situation == situation) & - (Expression.style == best_style) + (Expression.chat_id == related_chat_id) + & (Expression.situation == situation) + & (Expression.style == best_style) ) - + if expr_query.exists(): expr = expr_query.get() - predicted_expressions.append({ - "id": expr.id, - "situation": expr.situation, - "style": expr.style, - "last_active_time": expr.last_active_time, - "source_id": expr.chat_id, - "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - "prediction_score": scores.get(best_style, 0.0), - "prediction_input": filtered_target_message - }) + predicted_expressions.append( + { + "id": expr.id, + "situation": expr.situation, + "style": expr.style, + "last_active_time": expr.last_active_time, + "source_id": expr.chat_id, + "create_date": 
expr.create_date + if expr.create_date is not None + else expr.last_active_time, + "prediction_score": scores.get(best_style, 0.0), + "prediction_input": filtered_target_message, + } + ) else: - logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式") - + logger.warning( + f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式" + ) + except Exception as e: logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}") continue - + # 按预测分数排序,取前 total_num 个 predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True) selected_expressions = predicted_expressions[:total_num] - + logger.info(f"为{chat_id} 预测到 {len(selected_expressions)} 个表达方式") return selected_expressions - + except Exception as e: logger.error(f"模型预测表达方式失败: {e}") # 如果预测失败,回退到随机选择 return self._random_expressions(chat_id, total_num) - + def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: """ 随机选择表达方式 - + Args: chat_id: 聊天室ID total_num: 需要选择的数量 - + Returns: List[Dict[str, Any]]: 随机选择的表达方式列表 """ @@ -207,9 +212,7 @@ class ExpressionSelector: related_chat_ids = self.get_related_chat_ids(chat_id) # 优化:一次性查询所有相关chat_id的表达方式 - style_query = Expression.select().where( - (Expression.chat_id.in_(related_chat_ids)) - ) + style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids))) style_exprs = [ { @@ -228,15 +231,14 @@ class ExpressionSelector: selected_style = weighted_sample(style_exprs, total_num) else: selected_style = [] - + logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式") return selected_style - + except Exception as e: logger.error(f"随机选择表达方式失败: {e}") return [] - async def select_suitable_expressions( self, chat_id: str, @@ -246,13 +248,13 @@ class ExpressionSelector: ) -> Tuple[List[Dict[str, Any]], List[int]]: """ 根据配置模式选择适合的表达方式 - + Args: chat_id: 聊天流ID chat_info: 聊天内容信息 max_num: 最大选择数量 target_message: 目标消息内容 - + Returns: Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 """ @@ -263,7 +265,7 @@ class ExpressionSelector: # 获取配置模式 expression_mode = global_config.expression.mode - + if expression_mode == "exp_model": # exp_model模式:直接使用模型预测,不经过LLM logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式") @@ -284,12 +286,12 @@ class ExpressionSelector: ) -> Tuple[List[Dict[str, Any]], List[int]]: """ exp_model模式:直接使用模型预测,不经过LLM - + Args: chat_id: 聊天流ID target_message: 目标消息内容 max_num: 最大选择数量 - + Returns: Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 """ @@ -297,14 +299,14 @@ class ExpressionSelector: # 使用模型预测最合适的表达方式 selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num) selected_ids = [expr["id"] for expr in selected_expressions] - + # 更新last_active_time if selected_expressions: self.update_expressions_last_active_time(selected_expressions) - + logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式") return selected_expressions, selected_ids - + except Exception as e: logger.error(f"exp_model模式选择表达方式失败: {e}") return [], [] @@ -318,13 +320,13 @@ class ExpressionSelector: ) -> Tuple[List[Dict[str, Any]], List[int]]: """ classic模式:随机选择+LLM选择 - + Args: chat_id: 聊天流ID chat_info: 聊天内容信息 max_num: 最大选择数量 target_message: 目标消息内容 - + Returns: Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 """ @@ -425,17 +427,13 @@ class ExpressionSelector: updates_by_key[key] = expr for chat_id, situation, style in updates_by_key: query = Expression.select().where( - (Expression.chat_id == chat_id) - & (Expression.situation == situation) - & (Expression.style == style) 
+ (Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style) ) if query.exists(): expr_obj = query.get() expr_obj.last_active_time = time.time() expr_obj.save() - logger.debug( - "表达方式激活: 更新last_active_time in db" - ) + logger.debug("表达方式激活: 更新last_active_time in db") init_prompt() diff --git a/src/express/expressor_model/model.py b/src/express/expressor_model/model.py index d47873d9..563821e2 100644 --- a/src/express/expressor_model/model.py +++ b/src/express/expressor_model/model.py @@ -6,18 +6,21 @@ import os from .tokenizer import Tokenizer from .online_nb import OnlineNaiveBayes + class ExpressorModel: """ 直接使用朴素贝叶斯精排(可在线学习) 支持存储situation字段,不参与计算,仅与style对应 """ - def __init__(self, - alpha: float = 0.5, - beta: float = 0.5, - gamma: float = 1.0, - vocab_size: int = 200000, - use_jieba: bool = True): + def __init__( + self, + alpha: float = 0.5, + beta: float = 0.5, + gamma: float = 1.0, + vocab_size: int = 200000, + use_jieba: bool = True, + ): self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba) self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) self._candidates: Dict[str, str] = {} # cid -> text (style) @@ -28,7 +31,7 @@ class ExpressorModel: self._candidates[cid] = text if situation is not None: self._situations[cid] = situation - + # 确保在nb模型中初始化该候选的计数 if cid not in self.nb.cls_counts: self.nb.cls_counts[cid] = 0.0 @@ -46,7 +49,7 @@ class ExpressorModel: toks = self.tokenizer.tokenize(text) if not toks: return None, {} - + if not self._candidates: return None, {} @@ -58,7 +61,7 @@ class ExpressorModel: # 取最高分 if not scores: return None, {} - + # 根据k参数限制返回的候选数量 if k is not None and k > 0: # 按分数降序排序,取前k个 @@ -81,40 +84,42 @@ class ExpressorModel: def decay(self, factor: float): self.nb.decay(factor=factor) - + def get_situation(self, cid: str) -> Optional[str]: """获取候选对应的situation""" return self._situations.get(cid) - + def get_style(self, cid: str) -> Optional[str]: """获取候选对应的style""" return self._candidates.get(cid) - + def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: """获取候选的style和situation信息""" return self._candidates.get(cid), self._situations.get(cid) - + def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]: """获取所有候选的style和situation信息""" - return {cid: (style, self._situations.get(cid)) - for cid, style in self._candidates.items()} + return {cid: (style, self._situations.get(cid)) for cid, style in self._candidates.items()} def save(self, path: str): """保存模型""" os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "wb") as f: - pickle.dump({ - "candidates": self._candidates, - "situations": self._situations, - "nb": { - "cls_counts": dict(self.nb.cls_counts), - "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()}, - "alpha": self.nb.alpha, - "beta": self.nb.beta, - "gamma": self.nb.gamma, - "V": self.nb.V, - } - }, f) + pickle.dump( + { + "candidates": self._candidates, + "situations": self._situations, + "nb": { + "cls_counts": dict(self.nb.cls_counts), + "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()}, + "alpha": self.nb.alpha, + "beta": self.nb.beta, + "gamma": self.nb.gamma, + "V": self.nb.V, + }, + }, + f, + ) def load(self, path: str): """加载模型""" @@ -133,9 +138,11 @@ class ExpressorModel: self.nb.V = obj["nb"]["V"] self.nb._logZ.clear() + def defaultdict_dict(d: Dict[str, Dict[str, float]]): from collections import defaultdict + outer = defaultdict(lambda: defaultdict(float)) 
for k, inner in d.items(): outer[k].update(inner) - return outer \ No newline at end of file + return outer diff --git a/src/express/expressor_model/online_nb.py b/src/express/expressor_model/online_nb.py index 9705043b..fff25c08 100644 --- a/src/express/expressor_model/online_nb.py +++ b/src/express/expressor_model/online_nb.py @@ -2,6 +2,7 @@ import math from typing import Dict, List from collections import defaultdict, Counter + class OnlineNaiveBayes: def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000): self.alpha = alpha @@ -9,9 +10,9 @@ class OnlineNaiveBayes: self.gamma = gamma self.V = vocab_size - self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count + self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count - self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) + self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) def _invalidate(self, cid: str): if cid in self._logZ: @@ -57,4 +58,4 @@ class OnlineNaiveBayes: self.cls_counts[cid] *= g for term in list(self.token_counts[cid].keys()): self.token_counts[cid][term] *= g - self._invalidate(cid) \ No newline at end of file + self._invalidate(cid) diff --git a/src/express/expressor_model/tokenizer.py b/src/express/expressor_model/tokenizer.py index 5fd915ae..61a55950 100644 --- a/src/express/expressor_model/tokenizer.py +++ b/src/express/expressor_model/tokenizer.py @@ -3,17 +3,20 @@ from typing import List, Optional, Set try: import jieba + _HAS_JIEBA = True except Exception: _HAS_JIEBA = False _WORD_RE = re.compile(r"[A-Za-z0-9_]+") # 匹配纯符号的正则表达式 -_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$') +_SYMBOL_RE = re.compile(r"^[^\w\u4e00-\u9fff]+$") + def simple_en_tokenize(text: str) -> List[str]: return _WORD_RE.findall(text.lower()) + class Tokenizer: def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True): self.stopwords = stopwords or set() @@ -28,4 +31,4 @@ class Tokenizer: else: toks = simple_en_tokenize(text) # 过滤掉纯符号和停用词 - return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)] \ No newline at end of file + return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)] diff --git a/src/express/style_learner.py b/src/express/style_learner.py index e26b121c..1a40d27b 100644 --- a/src/express/style_learner.py +++ b/src/express/style_learner.py @@ -22,42 +22,42 @@ class StyleLearner: 学习从up_content到style的映射关系 支持动态管理风格集合(无数量上限) """ - + def __init__(self, chat_id: str, model_config: Optional[Dict] = None): self.chat_id = chat_id self.model_config = model_config or { "alpha": 0.5, - "beta": 0.5, + "beta": 0.5, "gamma": 0.99, # 衰减因子,支持遗忘 "vocab_size": 200000, - "use_jieba": True + "use_jieba": True, } - + # 初始化表达模型 self.expressor = ExpressorModel(**self.model_config) - + # 动态风格管理 self.style_to_id: Dict[str, str] = {} # style文本 -> style_id self.id_to_style: Dict[str, str] = {} # style_id -> style文本 self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 self.next_style_id = 0 # 下一个可用的style_id - + # 学习统计 self.learning_stats = { "total_samples": 0, "style_counts": defaultdict(int), "last_update": None, - "style_usage_frequency": defaultdict(int) # 风格使用频率 + "style_usage_frequency": defaultdict(int), # 风格使用频率 } - + def add_style(self, style: str, situation: str = None) -> bool: """ 动态添加一个新的风格 - + Args: style: 风格文本 situation: 
对应的situation文本(可选) - + Returns: bool: 添加是否成功 """ @@ -66,35 +66,37 @@ class StyleLearner: if style in self.style_to_id: logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在") return True - + # 生成新的style_id style_id = f"style_{self.next_style_id}" self.next_style_id += 1 - + # 添加到映射 self.style_to_id[style] = style_id self.id_to_style[style_id] = style if situation: self.id_to_situation[style_id] = situation - + # 添加到expressor模型 self.expressor.add_candidate(style_id, style, situation) - - logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" + - (f", situation: '{situation}'" if situation else "")) + + logger.info( + f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" + + (f", situation: '{situation}'" if situation else "") + ) return True - + except Exception as e: logger.error(f"[{self.chat_id}] 添加风格失败: {e}") return False - + def remove_style(self, style: str) -> bool: """ 删除一个风格 - + Args: style: 要删除的风格文本 - + Returns: bool: 删除是否成功 """ @@ -102,33 +104,33 @@ class StyleLearner: if style not in self.style_to_id: logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在") return False - + style_id = self.style_to_id[style] - + # 从映射中删除 del self.style_to_id[style] del self.id_to_style[style_id] if style_id in self.id_to_situation: del self.id_to_situation[style_id] - + # 从expressor模型中删除(通过重新构建) self._rebuild_expressor() - + logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})") return True - + except Exception as e: logger.error(f"[{self.chat_id}] 删除风格失败: {e}") return False - + def update_style(self, old_style: str, new_style: str) -> bool: """ 更新一个风格 - + Args: old_style: 原风格文本 new_style: 新风格文本 - + Returns: bool: 更新是否成功 """ @@ -136,37 +138,37 @@ class StyleLearner: if old_style not in self.style_to_id: logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在") return False - + if new_style in self.style_to_id and new_style != old_style: logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在") return False - + style_id = self.style_to_id[old_style] - + # 更新映射 del self.style_to_id[old_style] self.style_to_id[new_style] = style_id self.id_to_style[style_id] = new_style - + # 更新expressor模型(保留原有的situation) situation = self.id_to_situation.get(style_id) self.expressor.add_candidate(style_id, new_style, situation) - + logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'") return True - + except Exception as e: logger.error(f"[{self.chat_id}] 更新风格失败: {e}") return False - + def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int: """ 批量添加风格 - + Args: styles: 风格文本列表 situations: 对应的situation文本列表(可选) - + Returns: int: 成功添加的数量 """ @@ -175,55 +177,55 @@ class StyleLearner: situation = situations[i] if situations and i < len(situations) else None if self.add_style(style, situation): success_count += 1 - + logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功") return success_count - + def get_all_styles(self) -> List[str]: """获取所有已注册的风格""" return list(self.style_to_id.keys()) - + def get_style_count(self) -> int: """获取当前风格数量""" return len(self.style_to_id) - + def get_situation(self, style: str) -> Optional[str]: """ 获取风格对应的situation - + Args: style: 风格文本 - + Returns: Optional[str]: 对应的situation,如果不存在则返回None """ if style not in self.style_to_id: return None - + style_id = self.style_to_id[style] return self.id_to_situation.get(style_id) - + def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: """ 获取风格的完整信息 - + Args: style: 风格文本 - + Returns: Tuple[Optional[str], Optional[str]]: (style_id, situation) """ if style 
not in self.style_to_id: return None, None - + style_id = self.style_to_id[style] situation = self.id_to_situation.get(style_id) return style_id, situation - + def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]: """ 获取所有风格的完整信息 - + Returns: Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)} """ @@ -232,32 +234,32 @@ class StyleLearner: situation = self.id_to_situation.get(style_id) result[style] = (style_id, situation) return result - + def _rebuild_expressor(self): """重新构建expressor模型(删除风格后使用)""" try: # 重新创建expressor self.expressor = ExpressorModel(**self.model_config) - + # 重新添加所有风格和situation for style_id, style_text in self.id_to_style.items(): situation = self.id_to_situation.get(style_id) self.expressor.add_candidate(style_id, style_text, situation) - + logger.debug(f"[{self.chat_id}] 已重新构建expressor模型") - + except Exception as e: logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}") - + def learn_mapping(self, up_content: str, style: str) -> bool: """ 学习一个up_content到style的映射 如果style不存在,会自动添加 - + Args: up_content: 输入内容 style: 对应的style文本 - + Returns: bool: 学习是否成功 """ @@ -267,71 +269,71 @@ class StyleLearner: if not self.add_style(style): logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败") return False - + # 获取style_id style_id = self.style_to_id[style] - + # 使用正反馈学习 self.expressor.update_positive(up_content, style_id) - + # 更新统计 self.learning_stats["total_samples"] += 1 self.learning_stats["style_counts"][style_id] += 1 self.learning_stats["style_usage_frequency"][style] += 1 self.learning_stats["last_update"] = asyncio.get_event_loop().time() - + logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'") return True - + except Exception as e: logger.error(f"[{self.chat_id}] 学习映射失败: {e}") traceback.print_exc() return False - + def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: """ 根据up_content预测最合适的style - + Args: up_content: 输入内容 top_k: 返回前k个候选 - + Returns: Tuple[最佳style文本, 所有候选的分数] """ try: best_style_id, scores = self.expressor.predict(up_content, k=top_k) - + if best_style_id is None: return None, {} - + # 将style_id转换为style文本 best_style = self.id_to_style.get(best_style_id) - + # 转换所有分数 style_scores = {} for sid, score in scores.items(): style_text = self.id_to_style.get(sid) if style_text: style_scores[style_text] = score - + return best_style, style_scores - + except Exception as e: logger.error(f"[{self.chat_id}] 预测style失败: {e}") traceback.print_exc() return None, {} - + def decay_learning(self, factor: Optional[float] = None) -> None: """ 对学习到的知识进行衰减(遗忘) - + Args: factor: 衰减因子,None则使用配置中的gamma """ self.expressor.decay(factor) logger.debug(f"[{self.chat_id}] 执行知识衰减") - + def get_stats(self) -> Dict: """获取学习统计信息""" return { @@ -341,20 +343,20 @@ class StyleLearner: "style_counts": dict(self.learning_stats["style_counts"]), "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]), "last_update": self.learning_stats["last_update"], - "all_styles": list(self.style_to_id.keys()) + "all_styles": list(self.style_to_id.keys()), } - + def save(self, base_path: str) -> bool: """ 保存模型到文件 - + Args: base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl """ try: os.makedirs(base_path, exist_ok=True) file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") - + # 保存模型和统计信息 save_data = { "model_config": self.model_config, @@ -362,43 +364,43 @@ class StyleLearner: "id_to_style": self.id_to_style, "id_to_situation": self.id_to_situation, "next_style_id": 
self.next_style_id, - "learning_stats": self.learning_stats + "learning_stats": self.learning_stats, } - + # 先保存expressor模型 expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") self.expressor.save(expressor_path) - + # 保存其他数据 with open(file_path, "wb") as f: pickle.dump(save_data, f) - + logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}") return True - + except Exception as e: logger.error(f"[{self.chat_id}] 保存模型失败: {e}") return False - + def load(self, base_path: str) -> bool: """ 从文件加载模型 - + Args: base_path: 基础路径 """ try: file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") - + if not os.path.exists(file_path) or not os.path.exists(expressor_path): logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置") return False - + # 加载其他数据 with open(file_path, "rb") as f: save_data = pickle.load(f) - + # 恢复配置和状态 self.model_config = save_data["model_config"] self.style_to_id = save_data["style_to_id"] @@ -406,14 +408,14 @@ class StyleLearner: self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本 self.next_style_id = save_data["next_style_id"] self.learning_stats = save_data["learning_stats"] - + # 重新创建expressor并加载 self.expressor = ExpressorModel(**self.model_config) self.expressor.load(expressor_path) - + logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载") return True - + except Exception as e: logger.error(f"[{self.chat_id}] 加载模型失败: {e}") return False @@ -425,156 +427,156 @@ class StyleLearnerManager: 为每个chat_id维护独立的StyleLearner实例 每个chat_id可以动态管理自己的风格集合(无数量上限) """ - + def __init__(self, model_save_path: str = "data/style_models"): self.model_save_path = model_save_path self.learners: Dict[str, StyleLearner] = {} - + # 自动保存配置 self.auto_save_interval = 300 # 5分钟 self._auto_save_task: Optional[asyncio.Task] = None - + logger.info("StyleLearnerManager 已初始化") - + def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: """ 获取或创建指定chat_id的学习器 - + Args: chat_id: 聊天室ID model_config: 模型配置,None则使用默认配置 - + Returns: StyleLearner实例 """ if chat_id not in self.learners: # 创建新的学习器 learner = StyleLearner(chat_id, model_config) - + # 尝试加载已保存的模型 learner.load(self.model_save_path) - + self.learners[chat_id] = learner logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner") - + return self.learners[chat_id] - + def add_style(self, chat_id: str, style: str) -> bool: """ 为指定chat_id添加风格 - + Args: chat_id: 聊天室ID style: 风格文本 - + Returns: bool: 添加是否成功 """ learner = self.get_learner(chat_id) return learner.add_style(style) - + def remove_style(self, chat_id: str, style: str) -> bool: """ 为指定chat_id删除风格 - + Args: chat_id: 聊天室ID style: 风格文本 - + Returns: bool: 删除是否成功 """ learner = self.get_learner(chat_id) return learner.remove_style(style) - + def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool: """ 为指定chat_id更新风格 - + Args: chat_id: 聊天室ID old_style: 原风格文本 new_style: 新风格文本 - + Returns: bool: 更新是否成功 """ learner = self.get_learner(chat_id) return learner.update_style(old_style, new_style) - + def get_chat_styles(self, chat_id: str) -> List[str]: """ 获取指定chat_id的所有风格 - + Args: chat_id: 聊天室ID - + Returns: List[str]: 风格列表 """ learner = self.get_learner(chat_id) return learner.get_all_styles() - + def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: """ 学习一个映射关系 - + Args: chat_id: 聊天室ID up_content: 输入内容 style: 对应的style - + Returns: bool: 学习是否成功 """ learner = self.get_learner(chat_id) return learner.learn_mapping(up_content, style) - + 
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: """ 预测最合适的style - + Args: chat_id: 聊天室ID up_content: 输入内容 top_k: 返回前k个候选 - + Returns: Tuple[最佳style, 所有候选分数] """ learner = self.get_learner(chat_id) return learner.predict_style(up_content, top_k) - + def decay_all_learners(self, factor: Optional[float] = None) -> None: """ 对所有学习器执行衰减 - + Args: factor: 衰减因子 """ for learner in self.learners.values(): learner.decay_learning(factor) logger.info("已对所有学习器执行衰减") - + def get_all_stats(self) -> Dict[str, Dict]: """获取所有学习器的统计信息""" return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} - + def save_all_models(self) -> bool: """保存所有模型""" success_count = 0 for learner in self.learners.values(): if learner.save(self.model_save_path): success_count += 1 - + logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型") return success_count == len(self.learners) - + def load_all_models(self) -> int: """加载所有已保存的模型""" if not os.path.exists(self.model_save_path): return 0 - + loaded_count = 0 for filename in os.listdir(self.model_save_path): if filename.endswith("_style_model.pkl"): @@ -583,16 +585,16 @@ class StyleLearnerManager: if learner.load(self.model_save_path): self.learners[chat_id] = learner loaded_count += 1 - + logger.info(f"已加载 {loaded_count} 个模型") return loaded_count - + async def start_auto_save(self) -> None: """启动自动保存任务""" if self._auto_save_task is None or self._auto_save_task.done(): self._auto_save_task = asyncio.create_task(self._auto_save_loop()) logger.info("已启动自动保存任务") - + async def stop_auto_save(self) -> None: """停止自动保存任务""" if self._auto_save_task and not self._auto_save_task.done(): @@ -602,7 +604,7 @@ class StyleLearnerManager: except asyncio.CancelledError: pass logger.info("已停止自动保存任务") - + async def _auto_save_loop(self) -> None: """自动保存循环""" while True: diff --git a/src/jargon/__init__.py b/src/jargon/__init__.py index 1a60a94a..37b61644 100644 --- a/src/jargon/__init__.py +++ b/src/jargon/__init__.py @@ -3,5 +3,3 @@ from .jargon_miner import extract_and_store_jargon __all__ = [ "extract_and_store_jargon", ] - - diff --git a/src/jargon/jargon_miner.py b/src/jargon/jargon_miner.py index 3d983521..4bd44959 100644 --- a/src/jargon/jargon_miner.py +++ b/src/jargon/jargon_miner.py @@ -120,31 +120,31 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool: # 如果已完成所有推断,不再推断 if jargon_obj.is_complete: return False - + count = jargon_obj.count or 0 last_inference = jargon_obj.last_inference_count or 0 - + # 阈值列表:3,6, 10, 20, 40, 60, 100 - thresholds = [3,6, 10, 20, 40, 60, 100] - + thresholds = [3, 6, 10, 20, 40, 60, 100] + if count < thresholds[0]: return False - + # 如果count没有超过上次判定值,不需要判定 if count <= last_inference: return False - + # 找到第一个大于last_inference的阈值 next_threshold = None for threshold in thresholds: if threshold > last_inference: next_threshold = threshold break - + # 如果没有找到下一个阈值,说明已经超过100,不应该再推断 if next_threshold is None: return False - + # 检查count是否达到或超过这个阈值 return count >= next_threshold @@ -155,13 +155,13 @@ class JargonMiner: self.last_learning_time: float = time.time() # 频率控制,可按需调整 self.min_messages_for_learning: int = 20 - self.min_learning_interval: float = 30 + self.min_learning_interval: float = 30 self.llm = LLMRequest( model_set=model_config.model_task_config.utils, request_type="jargon.extract", ) - + # 初始化stream_name作为类属性,避免重复提取 chat_manager = get_chat_manager() stream_name = chat_manager.get_stream_name(self.chat_id) @@ -186,17 +186,19 @@ class JargonMiner: try: 
content = jargon_obj.content raw_content_str = jargon_obj.raw_content or "" - + # 解析raw_content列表 raw_content_list = [] if raw_content_str: try: - raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str + raw_content_list = ( + json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str + ) if not isinstance(raw_content_list, list): raw_content_list = [raw_content_list] if raw_content_list else [] except (json.JSONDecodeError, TypeError): raw_content_list = [raw_content_str] if raw_content_str else [] - + if not raw_content_list: logger.warning(f"jargon {content} 没有raw_content,跳过推断") return @@ -208,12 +210,12 @@ class JargonMiner: content=content, raw_content_list=raw_content_text, ) - + response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3) if not response1: logger.warning(f"jargon {content} 推断1失败:无响应") return - + # 解析推断1结果 inference1 = None try: @@ -235,12 +237,12 @@ class JargonMiner: "jargon_inference_content_only_prompt", content=content, ) - + response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3) if not response2: logger.warning(f"jargon {content} 推断2失败:无响应") return - + # 解析推断2结果 inference2 = None try: @@ -256,7 +258,7 @@ class JargonMiner: except Exception as e: logger.error(f"jargon {content} 推断2解析失败: {e}") return - + if global_config.debug.show_jargon_prompt: logger.info(f"jargon {content} 推断2提示词: {prompt2}") logger.info(f"jargon {content} 推断2结果: {response2}") @@ -264,22 +266,22 @@ class JargonMiner: logger.info(f"jargon {content} 推断1提示词: {prompt1}") logger.info(f"jargon {content} 推断1结果: {response1}") # logger.info(f"jargon {content} 推断1结果: {inference1}") - + # 步骤3: 比较两个推断结果 prompt3 = await global_prompt_manager.format_prompt( "jargon_compare_inference_prompt", inference1=json.dumps(inference1, ensure_ascii=False), inference2=json.dumps(inference2, ensure_ascii=False), ) - + if global_config.debug.show_jargon_prompt: logger.info(f"jargon {content} 比较提示词: {prompt3}") - + response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3) if not response3: logger.warning(f"jargon {content} 比较失败:无响应") return - + # 解析比较结果 comparison = None try: @@ -299,7 +301,7 @@ class JargonMiner: # 判断是否为黑话 is_similar = comparison.get("is_similar", False) is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话 - + # 更新数据库记录 jargon_obj.is_jargon = is_jargon if is_jargon: @@ -308,17 +310,19 @@ class JargonMiner: else: # 不是黑话,也记录含义(使用推断2的结果,因为含义明确) jargon_obj.meaning = inference2.get("meaning", "") - + # 更新最后一次判定的count值,避免重启后重复判定 jargon_obj.last_inference_count = jargon_obj.count or 0 - + # 如果count>=100,标记为完成,不再进行推断 if (jargon_obj.count or 0) >= 100: jargon_obj.is_complete = True - + jargon_obj.save() - logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}") - + logger.info( + f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}" + ) + # 固定输出推断结果,格式化为可读形式 if is_jargon: # 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx @@ -331,10 +335,11 @@ class JargonMiner: else: # 不是黑话,输出格式:[聊天名]xxx 不是黑话 logger.info(f"[{self.stream_name}]{content} 不是黑话") - + except Exception as e: logger.error(f"jargon推断失败: {e}") import traceback + traceback.print_exc() def should_trigger(self) -> bool: @@ -362,7 +367,7 @@ class JargonMiner: # 记录本次提取的时间窗口,避免重复提取 
extraction_start_time = self.last_learning_time extraction_end_time = time.time() - + # 拉取学习窗口内的消息 messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, @@ -385,7 +390,7 @@ class JargonMiner: response, _ = await self.llm.generate_response_async(prompt, temperature=0.2) if not response: return - + if global_config.debug.show_jargon_prompt: logger.info(f"jargon提取提示词: {prompt}") logger.info(f"jargon提取结果: {response}") @@ -415,7 +420,7 @@ class JargonMiner: continue content = str(item.get("content", "")).strip() raw_content_value = item.get("raw_content", "") - + # 处理raw_content:可能是字符串或列表 raw_content_list = [] if isinstance(raw_content_value, list): @@ -426,19 +431,15 @@ class JargonMiner: raw_content_str = raw_content_value.strip() if raw_content_str: raw_content_list = [raw_content_str] - + type_str = str(item.get("type", "")).strip().lower() - + # 验证type是否为有效值 if type_str not in ["p", "c", "e"]: type_str = "p" # 默认值 - + if content and raw_content_list: - entries.append({ - "content": content, - "raw_content": raw_content_list, - "type": type_str - }) + entries.append({"content": content, "raw_content": raw_content_list, "type": type_str}) except Exception as e: logger.error(f"解析jargon JSON失败: {e}; 原始: {response}") return @@ -455,7 +456,7 @@ class JargonMiner: if content_key not in seen: seen.add(content_key) uniq_entries.append(entry) - + saved = 0 updated = 0 merged = 0 @@ -466,12 +467,8 @@ class JargonMiner: try: # 步骤1: 检查同chat_id的记录,默认纳入global项目 # 查询条件:chat_id匹配 OR (is_global为True且content匹配) - query = ( - Jargon.select() - .where( - ((Jargon.chat_id == self.chat_id) | Jargon.is_global) & - (Jargon.content == content) - ) + query = Jargon.select().where( + ((Jargon.chat_id == self.chat_id) | Jargon.is_global) & (Jargon.content == content) ) if query.exists(): obj = query.get() @@ -479,82 +476,82 @@ class JargonMiner: obj.count = (obj.count or 0) + 1 except Exception: obj.count = 1 - + # 合并raw_content列表:读取现有列表,追加新值,去重 existing_raw_content = [] if obj.raw_content: try: - existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + existing_raw_content = ( + json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + ) if not isinstance(existing_raw_content, list): existing_raw_content = [existing_raw_content] if existing_raw_content else [] except (json.JSONDecodeError, TypeError): existing_raw_content = [obj.raw_content] if obj.raw_content else [] - + # 合并并去重 merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list)) obj.raw_content = json.dumps(merged_list, ensure_ascii=False) - + # 更新type(如果为空) if type_str and not obj.type: obj.type = type_str obj.save() - + # 检查是否需要推断(达到阈值且超过上次判定值) if _should_infer_meaning(obj): # 异步触发推断,不阻塞主流程 # 重新加载对象以确保数据最新 jargon_id = obj.id asyncio.create_task(self._infer_meaning_by_id(jargon_id)) - + updated += 1 else: # 步骤2: 同chat_id没有找到,检查所有chat_id中是否有相同content的记录 # 查询所有非global的记录(global的已经在步骤1检查过了) - all_content_query = ( - Jargon.select() - .where( - (Jargon.content == content) & - (~Jargon.is_global) - ) - ) + all_content_query = Jargon.select().where((Jargon.content == content) & (~Jargon.is_global)) all_matching = list(all_content_query) - + # 如果找到3个或更多相同content的记录,合并它们 if len(all_matching) >= 3: # 找到3个或更多已有记录,合并它们(新条目也会被包含在合并中) total_count = sum((obj.count or 0) for obj in all_matching) + 1 # +1 是因为当前新条目 - + # 合并所有raw_content列表 all_raw_content = [] for obj in all_matching: if obj.raw_content: try: - obj_raw = json.loads(obj.raw_content) if 
isinstance(obj.raw_content, str) else obj.raw_content + obj_raw = ( + json.loads(obj.raw_content) + if isinstance(obj.raw_content, str) + else obj.raw_content + ) if not isinstance(obj_raw, list): obj_raw = [obj_raw] if obj_raw else [] all_raw_content.extend(obj_raw) except (json.JSONDecodeError, TypeError): if obj.raw_content: all_raw_content.append(obj.raw_content) - + # 添加当前新条目的raw_content all_raw_content.extend(raw_content_list) # 去重 merged_raw_content = list(dict.fromkeys(all_raw_content)) - + # 合并type:优先使用非空的值 merged_type = type_str for obj in all_matching: if obj.type and not merged_type: merged_type = obj.type break - + # 合并其他字段:优先使用已有值 merged_meaning = None merged_is_jargon = None merged_last_inference_count = None merged_is_complete = False - + for obj in all_matching: if obj.meaning and not merged_meaning: merged_meaning = obj.meaning @@ -564,11 +561,11 @@ class JargonMiner: merged_last_inference_count = obj.last_inference_count if obj.is_complete: merged_is_complete = True - + # 删除旧的记录 for obj in all_matching: obj.delete_instance() - + # 创建新的global记录 Jargon.create( content=content, @@ -580,10 +577,12 @@ class JargonMiner: meaning=merged_meaning, is_jargon=merged_is_jargon, last_inference_count=merged_last_inference_count, - is_complete=merged_is_complete + is_complete=merged_is_complete, ) merged += 1 - logger.info(f"合并jargon为global: content={content}, 合并了{len(all_matching)}条已有记录+1条新记录(共{len(all_matching)+1}条),总count={total_count}") + logger.info( + f"合并jargon为global: content={content}, 合并了{len(all_matching)}条已有记录+1条新记录(共{len(all_matching) + 1}条),总count={total_count}" + ) else: # 找到少于3个已有记录,正常创建新记录 Jargon.create( @@ -592,7 +591,7 @@ class JargonMiner: type=type_str, chat_id=self.chat_id, is_global=False, - count=1 + count=1, ) saved += 1 except Exception as e: @@ -604,15 +603,17 @@ class JargonMiner: # 收集所有提取的jargon内容 jargon_list = [entry["content"] for entry in uniq_entries] jargon_str = ",".join(jargon_list) - + # 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色) logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}") - + # 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口 self.last_learning_time = extraction_end_time - + if saved or updated or merged: - logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,合并为global {merged} 条,chat_id={self.chat_id}") + logger.info( + f"jargon写入: 新增 {saved} 条,更新 {updated} 条,合并为global {merged} 条,chat_id={self.chat_id}" + ) except Exception as e: logger.error(f"JargonMiner 运行失败: {e}") @@ -636,36 +637,29 @@ async def extract_and_store_jargon(chat_id: str) -> None: def search_jargon( - keyword: str, - chat_id: Optional[str] = None, - limit: int = 10, - case_sensitive: bool = False, - fuzzy: bool = True + keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True ) -> List[Dict[str, str]]: """ 搜索jargon,支持大小写不敏感和模糊搜索 - + Args: keyword: 搜索关键词 chat_id: 可选的聊天ID,如果提供则优先搜索该聊天或global的jargon limit: 返回结果数量限制,默认10 case_sensitive: 是否大小写敏感,默认False(不敏感) fuzzy: 是否模糊搜索,默认True(使用LIKE匹配) - + Returns: List[Dict[str, str]]: 包含content, meaning的字典列表 """ if not keyword or not keyword.strip(): return [] - + keyword = keyword.strip() - + # 构建查询 - query = Jargon.select( - Jargon.content, - Jargon.meaning - ) - + query = Jargon.select(Jargon.content, Jargon.meaning) + # 构建搜索条件 if case_sensitive: # 大小写敏感 @@ -674,7 +668,7 @@ def search_jargon( search_condition = Jargon.content.contains(keyword) else: # 精确匹配 - search_condition = (Jargon.content == keyword) + search_condition = Jargon.content == keyword else: # 大小写不敏感 if fuzzy: @@ -682,35 +676,26 @@ def 
search_jargon( search_condition = fn.LOWER(Jargon.content).contains(keyword.lower()) else: # 精确匹配(使用LOWER函数) - search_condition = (fn.LOWER(Jargon.content) == keyword.lower()) - + search_condition = fn.LOWER(Jargon.content) == keyword.lower() + query = query.where(search_condition) - + # 如果提供了chat_id,优先搜索该聊天或global的jargon if chat_id: - query = query.where( - (Jargon.chat_id == chat_id) | Jargon.is_global - ) - + query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global) + # 只返回有meaning的记录 - query = query.where( - (Jargon.meaning.is_null(False)) & (Jargon.meaning != "") - ) - + query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != "")) + # 按count降序排序,优先返回出现频率高的 query = query.order_by(Jargon.count.desc()) - + # 限制结果数量 query = query.limit(limit) - + # 执行查询并返回结果 results = [] for jargon in query: - results.append({ - "content": jargon.content or "", - "meaning": jargon.meaning or "" - }) - + results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""}) + return results - - diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index b83c3b8f..7d74386f 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -250,7 +250,7 @@ def _build_stream_api_resp( if fr: reason = str(fr) break - + if str(reason).endswith("MAX_TOKENS"): has_visible_output = bool(resp.content and resp.content.strip()) if has_visible_output: @@ -281,8 +281,8 @@ async def _default_stream_response_handler( _tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 _usage_record = None # 使用情况记录 last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk - resp = APIResponse() - + resp = APIResponse() + def _insure_buffer_closed(): if _fc_delta_buffer and not _fc_delta_buffer.closed: _fc_delta_buffer.close() @@ -298,7 +298,7 @@ async def _default_stream_response_handler( chunk, _fc_delta_buffer, _tool_calls_buffer, - resp=resp, + resp=resp, ) if chunk.usage_metadata: @@ -314,7 +314,7 @@ async def _default_stream_response_handler( _fc_delta_buffer, _tool_calls_buffer, last_resp=last_resp, - resp=resp, + resp=resp, ), _usage_record except Exception: # 确保缓冲区被关闭 diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 7b350169..c4b206c8 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -239,7 +239,7 @@ def _build_stream_api_resp( # 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出) # 保留 finish_reason 仅用于上层判断 - + if not resp.content and not resp.tool_calls: raise EmptyResponseException() @@ -293,7 +293,7 @@ async def _default_stream_response_handler( if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason: finish_reason = event.choices[0].finish_reason - + if hasattr(event, "model") and event.model and not _model_name: _model_name = event.model # 记录模型名 @@ -341,10 +341,7 @@ async def _default_stream_response_handler( model_dbg = None # 统一日志格式 - logger.info( - "模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" - % (model_dbg or "") - ) + logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (model_dbg or "")) return resp, _usage_record except Exception: @@ -387,9 +384,7 @@ def _default_normal_response_parser( raw_snippet = str(resp)[:300] except Exception: raw_snippet = "" - logger.debug( - f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}" - ) + logger.debug(f"empty choices: model={model_dbg} id={id_dbg} 
usage={usage_dbg} raw≈{raw_snippet}") except Exception: # 日志采集失败不应影响控制流 pass @@ -447,14 +442,11 @@ def _default_normal_response_parser( # print(resp) _model_name = resp.model # 统一日志格式 - logger.info( - "模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" - % (_model_name or "") - ) + logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (_model_name or "")) return api_response, _usage_record except Exception as e: logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") - + if not api_response.content and not api_response.tool_calls: raise EmptyResponseException() diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index f161db95..48441dad 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -277,9 +277,7 @@ class LLMRequest: logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。") raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e - logger.warning( - f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}" - ) + logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}") await asyncio.sleep(api_provider.retry_interval) except NetworkConnectionError as e: @@ -289,9 +287,7 @@ class LLMRequest: logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。") raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e - logger.warning( - f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}" - ) + logger.warning(f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}") await asyncio.sleep(api_provider.retry_interval) except RespNotOkException as e: diff --git a/src/main.py b/src/main.py index 0515e5f1..28e9f137 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from maim_message import MessageServer from src.common.remote import TelemetryHeartBeatTask from src.manager.async_task_manager import async_task_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask + # from src.chat.utils.token_statistics import TokenStatisticsTask from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.message_receive.chat_stream import get_chat_manager @@ -70,9 +71,10 @@ class MainSystem: # 添加遥测心跳任务 await async_task_manager.add_task(TelemetryHeartBeatTask()) - + # 添加记忆遗忘任务 from src.chat.utils.memory_forget_task import MemoryForgetTask + await async_task_manager.add_task(MemoryForgetTask()) # 启动API服务器 @@ -106,7 +108,6 @@ class MainSystem: self.app.register_message_handler(chat_bot.message_process) self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process) - # 触发 ON_START 事件 from src.plugin_system.core.events_manager import events_manager from src.plugin_system.base.component_types import EventType diff --git a/src/memory_system/curious.py b/src/memory_system/curious.py index 044c9cdb..8e5802d8 100644 --- a/src/memory_system/curious.py +++ b/src/memory_system/curious.py @@ -16,7 +16,7 @@ class CuriousDetector: """ 好奇心检测器 - 检测聊天记录中的矛盾、冲突或需要提问的内容 """ - + def __init__(self, chat_id: str): self.chat_id = chat_id self.llm_request = LLMRequest( @@ -27,7 +27,7 @@ class CuriousDetector: self.last_detection_time: float = time.time() self.min_interval_seconds: float = 60.0 self.min_messages: int = 20 - + def should_trigger(self) -> bool: if time.time() - self.last_detection_time < self.min_interval_seconds: return False @@ -41,17 +41,17 @@ class CuriousDetector: async def detect_questions(self, recent_messages: List) -> Optional[str]: """ 检测最近消息中是否有需要提问的内容 - + Args: recent_messages: 最近的消息列表 - + Returns: 
Optional[str]: 如果检测到需要提问的内容,返回问题文本;否则返回None """ try: if not recent_messages or len(recent_messages) < 2: return None - + # 构建聊天内容 chat_content_block, _ = build_readable_messages_with_id( messages=recent_messages, @@ -60,9 +60,9 @@ class CuriousDetector: truncate=True, show_actions=True, ) - + # 问题跟踪功能已移除,不再检查已有问题 - + # 构建检测提示词 prompt = f"""你是一个严谨的聊天内容分析器。请分析以下聊天记录,检测是否存在需要提问的内容。 @@ -98,20 +98,20 @@ class CuriousDetector: logger.debug("已发送好奇心检测提示词") result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.3) - + logger.info(f"好奇心检测提示词: {prompt}") logger.info(f"好奇心检测结果: {result_text}") - + if not result_text: return None - + result_text = result_text.strip() - + # 检查是否输出NO if result_text.upper() == "NO": logger.debug("未检测到需要提问的内容") return None - + # 尝试解析JSON try: questions, reasoning = parse_md_json(result_text) @@ -119,7 +119,7 @@ class CuriousDetector: question_data = questions[0] question = question_data.get("question", "") reason = question_data.get("reason", "") - + if question and question.strip(): logger.info(f"检测到需要提问的内容: {question}") logger.info(f"提问理由: {reason}") @@ -127,32 +127,32 @@ class CuriousDetector: except Exception as e: logger.warning(f"解析问题JSON失败: {e}") logger.debug(f"原始响应: {result_text}") - + return None - + except Exception as e: logger.error(f"好奇心检测失败: {e}") return None - + async def make_question_from_detection(self, question: str, context: str = "") -> bool: """ 将检测到的问题记录(已移除冲突追踪器功能) - + Args: question: 检测到的问题 context: 问题上下文 - + Returns: bool: 是否成功记录 """ try: if not question or not question.strip(): return False - + # 冲突追踪器功能已移除 logger.info(f"检测到问题(冲突追踪器已移除): {question}") return True - + except Exception as e: logger.error(f"记录问题失败: {e}") return False @@ -174,11 +174,11 @@ curious_manager = CuriousManager() async def check_and_make_question(chat_id: str) -> bool: """ 检查聊天记录并生成问题(如果检测到需要提问的内容) - + Args: chat_id: 聊天ID recent_messages: 最近的消息列表 - + Returns: bool: 是否检测到并记录了问题 """ @@ -199,7 +199,7 @@ async def check_and_make_question(chat_id: str) -> bool: # 检测是否需要提问 question = await detector.detect_questions(recent_messages) - + if question: # 记录问题 success = await detector.make_question_from_detection(question) @@ -207,9 +207,9 @@ async def check_and_make_question(chat_id: str) -> bool: logger.info(f"成功检测并记录问题: {question}") detector.last_detection_time = time.time() return True - + return False - + except Exception as e: logger.error(f"检查并生成问题失败: {e}") return False diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 7e9c8ee0..cd9a262c 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -19,7 +19,7 @@ def init_memory_retrieval_prompt(): """初始化记忆检索相关的 prompt 模板和工具""" # 首先注册所有工具 init_all_tools() - + # 第一步:问题生成prompt Prompt( """ @@ -63,7 +63,7 @@ def init_memory_retrieval_prompt(): """, name="memory_retrieval_question_prompt", ) - + # 第二步:ReAct Agent prompt(工具描述会在运行时动态生成) Prompt( """ @@ -105,10 +105,10 @@ def init_memory_retrieval_prompt(): def _parse_react_response(response: str) -> Optional[Dict[str, Any]]: """解析ReAct Agent的响应 - + Args: response: LLM返回的响应 - + Returns: Dict[str, Any]: 解析后的动作信息,如果解析失败返回None 格式: {"thought": str, "actions": List[Dict[str, Any]]} @@ -118,58 +118,55 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]: # 尝试提取JSON(可能包含在```json代码块中) json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) - + if matches: json_str = matches[0] else: # 尝试直接解析整个响应 json_str = response.strip() 
- + # 修复可能的JSON错误 repaired_json = repair_json(json_str) - + # 解析JSON action_info = json.loads(repaired_json) - + if not isinstance(action_info, dict): logger.warning(f"解析的JSON不是对象格式: {action_info}") return None - + # 确保actions字段存在且为列表 if "actions" not in action_info: logger.warning(f"响应中缺少actions字段: {action_info}") return None - + if not isinstance(action_info["actions"], list): logger.warning(f"actions字段不是数组格式: {action_info['actions']}") return None - + # 确保actions不为空 if len(action_info["actions"]) == 0: logger.warning("actions数组为空") return None - + return action_info - + except Exception as e: logger.error(f"解析ReAct响应失败: {e}, 响应内容: {response[:200]}...") return None async def _react_agent_solve_question( - question: str, - chat_id: str, - max_iterations: int = 5, - timeout: float = 30.0 + question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0 ) -> Tuple[bool, str, List[Dict[str, Any]], bool]: """使用ReAct架构的Agent来解决问题 - + Args: question: 要回答的问题 chat_id: 聊天ID max_iterations: 最大迭代次数 timeout: 超时时间(秒) - + Returns: Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时) """ @@ -177,26 +174,26 @@ async def _react_agent_solve_question( collected_info = "" thinking_steps = [] is_timeout = False - + for iteration in range(max_iterations): # 检查超时 if time.time() - start_time > timeout: logger.warning(f"ReAct Agent超时,已迭代{iteration}次") is_timeout = True break - + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}") logger.info(f"ReAct Agent 已收集信息: {collected_info if collected_info else '暂无信息'}") - + # 获取工具注册器 tool_registry = get_tool_registry() - + # 获取bot_name bot_name = global_config.bot.nickname - + # 获取当前时间 time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - + # 构建prompt(动态生成工具描述) prompt = await global_prompt_manager.format_prompt( "memory_retrieval_react_prompt", @@ -207,44 +204,39 @@ async def _react_agent_solve_question( tools_description=tool_registry.get_tools_description(), action_types_list=tool_registry.get_action_types_list(), ) - + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 Prompt: {prompt}") - + # 调用LLM success, response, reasoning_content, model_name = await llm_api.generate_with_model( prompt, model_config=model_config.model_task_config.tool_use, request_type="memory.react", ) - + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM响应: {response}") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM推理: {reasoning_content}") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM模型: {model_name}") - + if not success: logger.error(f"ReAct Agent LLM调用失败: {response}") break - + # 解析响应 action_info = _parse_react_response(response) if not action_info: logger.warning(f"无法解析ReAct响应,迭代{iteration + 1}") break - + thought = action_info.get("thought", "") actions = action_info.get("actions", []) - + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考: {thought}") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作数量: {len(actions)}") - + # 记录思考步骤(包含所有actions) - step = { - "iteration": iteration + 1, - "thought": thought, - "actions": actions, - "observations": [] - } - + step = {"iteration": iteration + 1, "thought": thought, "actions": actions, "observations": []} + # 检查是否有final_answer或no_answer for action in actions: action_type = action.get("action_type", "") @@ -265,29 +257,32 @@ async def _react_agent_solve_question( thinking_steps.append(step) logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 确认无法找到答案: {answer}") return False, answer, thinking_steps, False - + # 并行执行所有工具 tool_registry = get_tool_registry() tool_tasks = [] - + for 
i, action in enumerate(actions): action_type = action.get("action_type", "") action_params = action.get("action_params", {}) - - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1}/{len(actions)}: {action_type}({action_params})") - + + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1}/{len(actions)}: {action_type}({action_params})" + ) + tool = tool_registry.get_tool(action_type) - + if tool: # 准备工具参数(需要添加chat_id如果工具需要) tool_params = action_params.copy() - + # 如果工具函数签名需要chat_id,添加它 import inspect + sig = inspect.signature(tool.execute_func) if "chat_id" in sig.parameters: tool_params["chat_id"] = chat_id - + # 创建异步任务 async def execute_single_tool(tool_instance, params, act_type, act_params, iter_num): try: @@ -298,34 +293,36 @@ async def _react_agent_solve_question( error_msg = f"工具执行失败: {str(e)}" logger.error(f"ReAct Agent 第 {iter_num + 1} 次迭代 动作 {act_type} {error_msg}") return f"查询{act_type}失败: {error_msg}" - + tool_tasks.append(execute_single_tool(tool, tool_params, action_type, action_params, iteration)) else: error_msg = f"未知的工具类型: {action_type}" - logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1}/{len(actions)} {error_msg}") + logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1}/{len(actions)} {error_msg}") tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{action_type}失败: {error_msg}"))) - + # 并行执行所有工具 if tool_tasks: observations = await asyncio.gather(*tool_tasks, return_exceptions=True) - + # 处理执行结果 for i, observation in enumerate(observations): if isinstance(observation, Exception): observation = f"工具执行异常: {str(observation)}" - logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1} 执行异常: {observation}") - + logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1} 执行异常: {observation}") + step["observations"].append(observation) collected_info += f"\n{observation}\n" - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1} 执行结果: {observation}") - + logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1} 执行结果: {observation}") + thinking_steps.append(step) - + # 达到最大迭代次数或超时,但Agent没有明确返回final_answer # 迭代超时应该直接视为no_answer,而不是使用已有信息 # 只有Agent明确返回final_answer时,才认为找到了答案 if collected_info: - logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}...") + logger.warning( + f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}..." + ) if is_timeout: logger.warning("ReAct Agent超时,直接视为no_answer") else: @@ -335,35 +332,32 @@ async def _react_agent_solve_question( def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str: """获取最近一段时间内的查询历史 - + Args: chat_id: 聊天ID time_window_seconds: 时间窗口(秒),默认10分钟 - + Returns: str: 格式化的查询历史字符串 """ try: current_time = time.time() start_time = current_time - time_window_seconds - + # 查询最近时间窗口内的记录,按更新时间倒序 records = ( ThinkingBack.select() - .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.update_time >= start_time) - ) + .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time)) .order_by(ThinkingBack.update_time.desc()) .limit(5) # 最多返回5条最近的记录 ) - + if not records.exists(): return "" - + history_lines = [] history_lines.append("最近已查询的问题和结果:") - + for record in records: status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" answer_preview = "" @@ -373,15 +367,15 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) answer_preview = record.answer[:100] if len(record.answer) > 100: answer_preview += "..." 
- + history_lines.append(f"- 问题:{record.question}") history_lines.append(f" 状态:{status}") if answer_preview: history_lines.append(f" 答案:{answer_preview}") history_lines.append("") # 空行分隔 - + return "\n".join(history_lines) - + except Exception as e: logger.error(f"获取查询历史失败: {e}") return "" @@ -389,40 +383,40 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> List[str]: """获取最近一段时间内缓存的记忆(只返回找到答案的记录) - + Args: chat_id: 聊天ID time_window_seconds: 时间窗口(秒),默认300秒(5分钟) - + Returns: List[str]: 格式化的记忆列表,每个元素格式为 "问题:xxx\n答案:xxx" """ try: current_time = time.time() start_time = current_time - time_window_seconds - + # 查询最近时间窗口内找到答案的记录,按更新时间倒序 records = ( ThinkingBack.select() .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.update_time >= start_time) & - (ThinkingBack.found_answer == 1) + (ThinkingBack.chat_id == chat_id) + & (ThinkingBack.update_time >= start_time) + & (ThinkingBack.found_answer == 1) ) .order_by(ThinkingBack.update_time.desc()) .limit(5) # 最多返回5条最近的记录 ) - + if not records.exists(): return [] - + cached_memories = [] for record in records: if record.answer: cached_memories.append(f"问题:{record.question}\n答案:{record.answer}") - + return cached_memories - + except Exception as e: logger.error(f"获取缓存记忆失败: {e}") return [] @@ -430,11 +424,11 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, str]]: """从thinking_back数据库中查询是否有现成的答案 - + Args: chat_id: 聊天ID question: 问题 - + Returns: Optional[Tuple[bool, str]]: 如果找到记录,返回(found_answer, answer),否则返回None found_answer: 是否找到答案(True表示found_answer=1,False表示found_answer=0) @@ -445,38 +439,30 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st # 按更新时间倒序,获取最新的记录 records = ( ThinkingBack.select() - .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.question == question) - ) + .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question)) .order_by(ThinkingBack.update_time.desc()) .limit(1) ) - + if records.exists(): record = records.get() found_answer = bool(record.found_answer) answer = record.answer or "" logger.info(f"在thinking_back中找到记录,问题: {question[:50]}...,found_answer: {found_answer}") return found_answer, answer - + return None - + except Exception as e: logger.error(f"查询thinking_back失败: {e}") return None def _store_thinking_back( - chat_id: str, - question: str, - context: str, - found_answer: bool, - answer: str, - thinking_steps: List[Dict[str, Any]] + chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]] ) -> None: """存储或更新思考过程到数据库(如果已存在则更新,否则创建) - + Args: chat_id: 聊天ID question: 问题 @@ -487,18 +473,15 @@ def _store_thinking_back( """ try: now = time.time() - + # 先查询是否已存在相同chat_id和问题的记录 existing = ( ThinkingBack.select() - .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.question == question) - ) + .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question)) .order_by(ThinkingBack.update_time.desc()) .limit(1) ) - + if existing.exists(): # 更新现有记录 record = existing.get() @@ -519,37 +502,33 @@ def _store_thinking_back( answer=answer, thinking_steps=json.dumps(thinking_steps, ensure_ascii=False), create_time=now, - update_time=now + update_time=now, ) logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...") except Exception as e: logger.error(f"存储思考过程失败: {e}") -async def 
_process_single_question( - question: str, - chat_id: str, - context: str -) -> Optional[str]: +async def _process_single_question(question: str, chat_id: str, context: str) -> Optional[str]: """处理单个问题的查询(包含缓存检查逻辑) - + Args: question: 要查询的问题 chat_id: 聊天ID context: 上下文信息 - + Returns: Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None """ logger.info(f"开始处理问题: {question}") - + # 先检查thinking_back数据库中是否有现成答案 cached_result = _query_thinking_back(chat_id, question) should_requery = False - + if cached_result: cached_found_answer, cached_answer = cached_result - + # 根据found_answer的值决定是否重新查询 if cached_found_answer: # found_answer == 1 (True) # found_answer == 1:20%概率重新查询 @@ -561,7 +540,7 @@ async def _process_single_question( if random.random() < 0.4: should_requery = True logger.info(f"found_answer=0,触发40%概率重新查询,问题: {question[:50]}...") - + # 如果不需要重新查询,使用缓存答案 if not should_requery: if cached_answer: @@ -570,21 +549,18 @@ async def _process_single_question( else: # 缓存中没有答案,需要查询 should_requery = True - + # 如果没有缓存答案或需要重新查询,使用ReAct Agent查询 if not cached_result or should_requery: if should_requery: logger.info(f"概率触发重新查询,使用ReAct Agent查询,问题: {question[:50]}...") else: logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...") - + found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( - question=question, - chat_id=chat_id, - max_iterations=5, - timeout=120.0 + question=question, chat_id=chat_id, max_iterations=5, timeout=120.0 ) - + # 存储到数据库(超时时不存储) if not is_timeout: _store_thinking_back( @@ -593,14 +569,14 @@ async def _process_single_question( context=context, found_answer=found_answer, answer=answer, - thinking_steps=thinking_steps + thinking_steps=thinking_steps, ) else: logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...") - + if found_answer and answer: return f"问题:{question}\n答案:{answer}" - + return None @@ -613,30 +589,30 @@ async def build_memory_retrieval_prompt( ) -> str: """构建记忆检索提示 使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案 - + Args: message: 聊天历史记录 sender: 发送者名称 target: 目标消息内容 chat_stream: 聊天流对象 tool_executor: 工具执行器(保留参数以兼容接口) - + Returns: str: 记忆检索结果字符串 """ start_time = time.time() - + logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}") try: time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) bot_name = global_config.bot.nickname chat_id = chat_stream.stream_id - + # 获取最近查询历史(最近1小时内的查询) recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0) if not recent_query_history: recent_query_history = "最近没有查询记录。" - + # 第一步:生成问题 question_prompt = await global_prompt_manager.format_prompt( "memory_retrieval_question_prompt", @@ -647,55 +623,52 @@ async def build_memory_retrieval_prompt( sender=sender, target_message=target, ) - + success, response, reasoning_content, model_name = await llm_api.generate_with_model( question_prompt, model_config=model_config.model_task_config.tool_use, request_type="memory.question", ) - + logger.info(f"记忆检索问题生成提示词: {question_prompt}") logger.info(f"记忆检索问题生成响应: {response}") - + if not success: logger.error(f"LLM生成问题失败: {response}") return "" - + # 解析问题列表 questions = _parse_questions_json(response) - + # 获取缓存的记忆(与question时使用相同的时间窗口和数量限制) cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0) - + if not questions: logger.debug("模型认为不需要检索记忆或解析失败") # 即使没有当次查询,也返回缓存的记忆 if cached_memories: retrieved_memory = "\n\n".join(cached_memories) end_time = time.time() - logger.info(f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆") + 
logger.info( + f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆" + ) return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" else: return "" - + logger.info(f"解析到 {len(questions)} 个问题: {questions}") - + # 第二步:并行处理所有问题(固定使用5次迭代/120秒超时) logger.info(f"问题数量: {len(questions)},固定设置最大迭代次数: 5,超时时间: 120秒") - + # 并行处理所有问题 question_tasks = [ - _process_single_question( - question=question, - chat_id=chat_id, - context=message - ) - for question in questions + _process_single_question(question=question, chat_id=chat_id, context=message) for question in questions ] - + # 并行执行所有查询任务 results = await asyncio.gather(*question_tasks, return_exceptions=True) - + # 收集所有有效结果 all_results = [] current_questions = set() # 用于去重,避免缓存和当次查询重复 @@ -708,7 +681,7 @@ async def build_memory_retrieval_prompt( if result.startswith("问题:"): question = result.split("\n")[0].replace("问题:", "").strip() current_questions.add(question) - + # 将缓存的记忆添加到结果中(排除当次查询已包含的问题,避免重复) for cached_memory in cached_memories: if cached_memory.startswith("问题:"): @@ -717,17 +690,19 @@ async def build_memory_retrieval_prompt( if question not in current_questions: all_results.append(cached_memory) logger.debug(f"添加缓存记忆: {question[:50]}...") - + end_time = time.time() - + if all_results: retrieved_memory = "\n\n".join(all_results) - logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)") + logger.info( + f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)" + ) return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" else: logger.debug("所有问题均未找到答案,且无缓存记忆") return "" - + except Exception as e: logger.error(f"记忆检索时发生异常: {str(e)}") return "" @@ -735,10 +710,10 @@ async def build_memory_retrieval_prompt( def _parse_questions_json(response: str) -> List[str]: """解析问题JSON - + Args: response: LLM返回的响应 - + Returns: List[str]: 问题列表 """ @@ -746,28 +721,28 @@ def _parse_questions_json(response: str) -> List[str]: # 尝试提取JSON(可能包含在```json代码块中) json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) - + if matches: json_str = matches[0] else: # 尝试直接解析整个响应 json_str = response.strip() - + # 修复可能的JSON错误 repaired_json = repair_json(json_str) - + # 解析JSON questions = json.loads(repaired_json) - + if not isinstance(questions, list): logger.warning(f"解析的JSON不是数组格式: {questions}") return [] - + # 确保所有元素都是字符串 questions = [q for q in questions if isinstance(q, str) and q.strip()] - + return questions - + except Exception as e: logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") return [] diff --git a/src/memory_system/memory_utils.py b/src/memory_system/memory_utils.py index 59c4a143..4f9e2f98 100644 --- a/src/memory_system/memory_utils.py +++ b/src/memory_system/memory_utils.py @@ -3,6 +3,7 @@ 记忆系统工具函数 包含模糊查找、相似度计算等工具函数 """ + import json import re from difflib import SequenceMatcher @@ -12,6 +13,7 @@ from src.common.logger import get_logger logger = get_logger("memory_utils") + def parse_md_json(json_text: str) -> list[str]: """从Markdown格式的内容中提取JSON对象和推理内容""" json_objects = [] @@ -50,14 +52,15 @@ def parse_md_json(json_text: str) -> list[str]: return json_objects, reasoning_content + def calculate_similarity(text1: str, text2: str) -> float: """ 计算两个文本的相似度 - + Args: text1: 第一个文本 text2: 第二个文本 - + Returns: float: 相似度分数 (0-1) """ @@ -65,16 +68,16 @@ def calculate_similarity(text1: str, text2: str) -> float: # 预处理文本 text1 = preprocess_text(text1) text2 = preprocess_text(text2) - + # 使用SequenceMatcher计算相似度 similarity = 
SequenceMatcher(None, text1, text2).ratio() - + # 如果其中一个文本包含另一个,提高相似度 if text1 in text2 or text2 in text1: similarity = max(similarity, 0.8) - + return similarity - + except Exception as e: logger.error(f"计算相似度时出错: {e}") return 0.0 @@ -83,26 +86,25 @@ def calculate_similarity(text1: str, text2: str) -> float: def preprocess_text(text: str) -> str: """ 预处理文本,提高匹配准确性 - + Args: text: 原始文本 - + Returns: str: 预处理后的文本 """ try: # 转换为小写 text = text.lower() - + # 移除标点符号和特殊字符 - text = re.sub(r'[^\w\s]', '', text) - + text = re.sub(r"[^\w\s]", "", text) + # 移除多余空格 - text = re.sub(r'\s+', ' ', text).strip() - + text = re.sub(r"\s+", " ", text).strip() + return text - + except Exception as e: logger.error(f"预处理文本时出错: {e}") return text - diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index f95ee266..fd97e95e 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -14,20 +14,16 @@ from .tool_utils import parse_datetime_to_timestamp, parse_time_range logger = get_logger("memory_retrieval_tools") -async def query_chat_history( - chat_id: str, - keyword: Optional[str] = None, - time_range: Optional[str] = None -) -> str: +async def query_chat_history(chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None) -> str: """根据时间或关键词在chat_history表中查询聊天记录概述 - + Args: chat_id: 聊天ID keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔) time_range: 时间范围或时间点,格式: - 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS" - 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录) - + Returns: str: 查询结果 """ @@ -35,10 +31,10 @@ async def query_chat_history( # 检查参数 if not keyword and not time_range: return "未指定查询参数(需要提供keyword或time_range之一)" - + # 构建查询条件 query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) - + # 时间过滤条件 if time_range: # 判断是时间点还是时间范围 @@ -46,73 +42,71 @@ async def query_chat_history( # 时间范围:查询与时间范围有交集的记录 start_timestamp, end_timestamp = parse_time_range(time_range) # 交集条件:start_time < end_timestamp AND end_time > start_timestamp - time_filter = ( - (ChatHistory.start_time < end_timestamp) & - (ChatHistory.end_time > start_timestamp) - ) + time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp) else: # 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time) target_timestamp = parse_datetime_to_timestamp(time_range) - time_filter = ( - (ChatHistory.start_time <= target_timestamp) & - (ChatHistory.end_time >= target_timestamp) - ) + time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp) query = query.where(time_filter) - + # 执行查询 records = list(query.order_by(ChatHistory.start_time.desc()).limit(50)) - + if not records: return "未找到相关聊天记录概述" - + # 如果有关键词,进一步过滤 if keyword: # 解析多个关键词(支持空格、逗号等分隔符) keywords_list = parse_keywords_string(keyword) if not keywords_list: keywords_list = [keyword.strip()] if keyword.strip() else [] - + # 转换为小写以便匹配 keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()] - + if not keywords_lower: return "关键词为空" - + filtered_records = [] - + for record in records: # 在theme、keywords、summary、original_text中搜索 theme = (record.theme or "").lower() summary = (record.summary or "").lower() original_text = (record.original_text or "").lower() - + # 解析record中的keywords JSON record_keywords_list = [] if record.keywords: try: - keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords + keywords_data = ( + 
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords + ) if isinstance(keywords_data, list): record_keywords_list = [str(k).lower() for k in keywords_data] except (json.JSONDecodeError, TypeError, ValueError): pass - + # 检查是否包含任意一个关键词(OR关系) matched = False for kw in keywords_lower: - if (kw in theme or - kw in summary or - kw in original_text or - any(kw in k for k in record_keywords_list)): + if ( + kw in theme + or kw in summary + or kw in original_text + or any(kw in k for k in record_keywords_list) + ): matched = True break - + if matched: filtered_records.append(record) - + if not filtered_records: keywords_str = "、".join(keywords_list) return f"未找到包含关键词'{keywords_str}'的聊天记录概述" - + records = filtered_records # 对即将返回的记录增加使用计数 @@ -123,22 +117,23 @@ async def query_chat_history( record.count = (record.count or 0) + 1 except Exception as update_error: logger.error(f"更新聊天记录概述计数失败: {update_error}") - + # 构建结果文本 results = [] for record in records_to_use: # 最多返回3条记录 result_parts = [] - + # 添加主题 if record.theme: result_parts.append(f"主题:{record.theme}") - + # 添加时间范围 from datetime import datetime + start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S") end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S") result_parts.append(f"时间:{start_str} - {end_str}") - + # 添加概括(优先使用summary,如果没有则使用original_text的前200字符) if record.summary: result_parts.append(f"概括:{record.summary}") @@ -147,18 +142,18 @@ async def query_chat_history( if len(record.original_text) > 200: text_preview += "..." result_parts.append(f"内容:{text_preview}") - + results.append("\n".join(result_parts)) - + if not results: return "未找到相关聊天记录概述" - + response_text = "\n\n---\n\n".join(results) if len(records) > len(records_to_use): omitted_count = len(records) - len(records_to_use) response_text += f"\n\n(还有{omitted_count}条历史记录已省略)" return response_text - + except Exception as e: logger.error(f"查询聊天历史概述失败: {e}") return f"查询失败: {str(e)}" @@ -174,14 +169,14 @@ def register_tool(): "name": "keyword", "type": "string", "description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)", - "required": False + "required": False, }, { "name": "time_range", "type": "string", "description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)", - "required": False - } + "required": False, + }, ], - execute_func=query_chat_history + execute_func=query_chat_history, ) diff --git a/src/memory_system/retrieval_tools/query_jargon.py b/src/memory_system/retrieval_tools/query_jargon.py index 0f007f8c..65050b85 100644 --- a/src/memory_system/retrieval_tools/query_jargon.py +++ b/src/memory_system/retrieval_tools/query_jargon.py @@ -9,16 +9,13 @@ from .tool_registry import register_memory_retrieval_tool logger = get_logger("memory_retrieval_tools") -async def query_jargon( - keyword: str, - chat_id: str -) -> str: +async def query_jargon(keyword: str, chat_id: str) -> str: """根据关键词在jargon库中查询 - + Args: keyword: 关键词(黑话/俚语/缩写) chat_id: 聊天ID - + Returns: str: 查询结果 """ @@ -26,29 +23,17 @@ async def query_jargon( content = str(keyword).strip() if not content: return "关键词为空" - + # 先尝试精确匹配 - results = search_jargon( - keyword=content, - chat_id=chat_id, - limit=10, - case_sensitive=False, - fuzzy=False - ) - + results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False) + is_fuzzy_match = False - + # 如果精确匹配未找到,尝试模糊搜索 if not results: - 
results = search_jargon( - keyword=content, - chat_id=chat_id, - limit=10, - case_sensitive=False, - fuzzy=True - ) + results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True) is_fuzzy_match = True - + if results: # 如果是模糊匹配,显示找到的实际jargon内容 if is_fuzzy_match: @@ -71,11 +56,11 @@ async def query_jargon( output = ";".join(output_parts) if len(output_parts) > 1 else output_parts[0] logger.info(f"在jargon库中找到匹配(当前会话或全局,精确匹配): {content},找到{len(results)}条结果") return output - + # 未命中 logger.info(f"在jargon库中未找到匹配(当前会话或全局,精确匹配和模糊搜索都未找到): {content}") return f"未在jargon库中找到'{content}'的解释" - + except Exception as e: logger.error(f"查询jargon失败: {e}") return f"查询失败: {str(e)}" @@ -86,14 +71,6 @@ def register_tool(): register_memory_retrieval_tool( name="query_jargon", description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。", - parameters=[ - { - "name": "keyword", - "type": "string", - "description": "关键词(黑话/俚语/缩写)", - "required": True - } - ], - execute_func=query_jargon + parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}], + execute_func=query_jargon, ) - diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py index 920a1bb6..8e5503a5 100644 --- a/src/memory_system/retrieval_tools/tool_registry.py +++ b/src/memory_system/retrieval_tools/tool_registry.py @@ -11,17 +11,13 @@ logger = get_logger("memory_retrieval_tools") class MemoryRetrievalTool: """记忆检索工具基类""" - + def __init__( - self, - name: str, - description: str, - parameters: List[Dict[str, Any]], - execute_func: Callable[..., Awaitable[str]] + self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]] ): """ 初始化工具 - + Args: name: 工具名称 description: 工具描述 @@ -32,7 +28,7 @@ class MemoryRetrievalTool: self.description = description self.parameters = parameters self.execute_func = execute_func - + def get_tool_description(self) -> str: """获取工具的文本描述,用于prompt""" param_descriptions = [] @@ -43,10 +39,10 @@ class MemoryRetrievalTool: required = param.get("required", True) required_str = "必填" if required else "可选" param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}") - + params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数" return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}" - + async def execute(self, **kwargs) -> str: """执行工具""" return await self.execute_func(**kwargs) @@ -54,30 +50,30 @@ class MemoryRetrievalTool: class MemoryRetrievalToolRegistry: """工具注册器""" - + def __init__(self): self.tools: Dict[str, MemoryRetrievalTool] = {} - + def register_tool(self, tool: MemoryRetrievalTool) -> None: """注册工具""" self.tools[tool.name] = tool logger.info(f"注册记忆检索工具: {tool.name}") - + def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]: """获取工具""" return self.tools.get(name) - + def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]: """获取所有工具""" return self.tools.copy() - + def get_tools_description(self) -> str: """获取所有工具的描述,用于prompt""" descriptions = [] for i, tool in enumerate(self.tools.values(), 1): descriptions.append(f"{i}. 
{tool.get_tool_description()}") return "\n".join(descriptions) - + def get_action_types_list(self) -> str: """获取所有动作类型的列表,用于prompt""" action_types = [tool.name for tool in self.tools.values()] @@ -91,13 +87,10 @@ _tool_registry = MemoryRetrievalToolRegistry() def register_memory_retrieval_tool( - name: str, - description: str, - parameters: List[Dict[str, Any]], - execute_func: Callable[..., Awaitable[str]] + name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]] ) -> None: """注册记忆检索工具的便捷函数 - + Args: name: 工具名称 description: 工具描述 @@ -111,4 +104,3 @@ def register_memory_retrieval_tool( def get_tool_registry() -> MemoryRetrievalToolRegistry: """获取工具注册器实例""" return _tool_registry - diff --git a/src/memory_system/retrieval_tools/tool_utils.py b/src/memory_system/retrieval_tools/tool_utils.py index d0ca334f..be98c72d 100644 --- a/src/memory_system/retrieval_tools/tool_utils.py +++ b/src/memory_system/retrieval_tools/tool_utils.py @@ -40,25 +40,24 @@ def parse_datetime_to_timestamp(value: str) -> float: def parse_time_range(time_range: str) -> Tuple[float, float]: """ 解析时间范围字符串,返回开始和结束时间戳 - + Args: time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS" - + Returns: Tuple[float, float]: (开始时间戳, 结束时间戳) """ if " - " not in time_range: raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}") - + parts = time_range.split(" - ", 1) if len(parts) != 2: raise ValueError(f"时间范围格式错误: {time_range}") - + start_str = parts[0].strip() end_str = parts[1].strip() - + start_timestamp = parse_datetime_to_timestamp(start_str) end_timestamp = parse_datetime_to_timestamp(end_str) - - return start_timestamp, end_timestamp + return start_timestamp, end_timestamp diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index d64c2d9d..ad6a1ce9 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -1,10 +1,7 @@ -import math -import random import time from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive diff --git a/src/plugin_system/apis/frequency_api.py b/src/plugin_system/apis/frequency_api.py index 47b3a95f..bc906186 100644 --- a/src/plugin_system/apis/frequency_api.py +++ b/src/plugin_system/apis/frequency_api.py @@ -6,7 +6,9 @@ logger = get_logger("frequency_api") def get_current_talk_value(chat_id: str) -> float: - return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id) + return frequency_control_manager.get_or_create_frequency_control( + chat_id + ).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id) def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None: diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 8a2f8389..cfacd558 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -109,7 +109,7 @@ def get_messages_by_time_in_chat( limit=limit, limit_mode=limit_mode, filter_bot=filter_mai, - filter_command=filter_command + filter_command=filter_command, ) diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index 
03a563f6..bc0b32f0 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -12,11 +12,11 @@ logger = get_logger("tool_api") def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]: """获取公开工具实例 - + Args: tool_name: 工具名称 chat_stream: 聊天流对象,用于传递聊天上下文信息 - + Returns: Optional[BaseTool]: 工具实例,如果未找到则返回None """ diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 4e55a945..769bce9d 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -77,7 +77,7 @@ class BaseAction(ABC): self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy() """NORMAL模式下的激活类型""" - self.activation_type = getattr(self.__class__, "activation_type") + self.activation_type = self.__class__.activation_type """激活类型""" self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0) """当激活类型为RANDOM时的概率""" @@ -108,21 +108,16 @@ class BaseAction(ABC): self.is_group = False self.target_id = None - self.group_id = ( - str(self.action_message.chat_info.group_info.group_id) - if self.action_message.chat_info.group_info - else None + str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None ) self.group_name = ( - self.action_message.chat_info.group_info.group_name - if self.action_message.chat_info.group_info - else None + self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None ) self.user_id = str(self.action_message.user_info.user_id) self.user_nickname = self.action_message.user_info.user_nickname - + if self.group_id: self.is_group = True self.target_id = self.group_id @@ -132,7 +127,6 @@ class BaseAction(ABC): self.target_id = self.user_id self.log_prefix = f"[{self.user_nickname} 的 私聊]" - logger.debug( f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" ) @@ -448,7 +442,6 @@ class BaseAction(ABC): wait_start_time = asyncio.get_event_loop().time() while True: - # 检查新消息 current_time = time.time() new_message_count = message_api.count_new_messages( @@ -497,7 +490,7 @@ class BaseAction(ABC): raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代") # 获取focus_activation_type和normal_activation_type focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS) - normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS) + _normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS) # 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type activation_type = getattr(cls, "activation_type", focus_activation_type) diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 072d68b1..71d55101 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -34,17 +34,17 @@ class BaseTool(ABC): def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None): """初始化工具基类 - + Args: plugin_config: 插件配置字典 chat_stream: 聊天流对象,用于获取聊天上下文信息 """ self.plugin_config = plugin_config or {} # 直接存储插件配置字典 - + # ============================================================================= # 便捷属性 - 直接在初始化时获取常用聊天信息(与BaseAction保持一致) # ============================================================================= - + # 获取聊天流对象 self.chat_stream = chat_stream self.chat_id = 
self.chat_stream.stream_id if self.chat_stream else None diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index 576f830c..3fe62937 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -346,9 +346,7 @@ class EventsManager: if not isinstance(result, tuple) or len(result) != 5: if isinstance(result, tuple): - annotated = ", ".join( - f"{name}={val!r}" for name, val in zip(expected_fields, result) - ) + annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False)) actual_desc = f"{len(result)} 个元素 ({annotated})" else: actual_desc = f"非 tuple 类型: {type(result)}" @@ -380,7 +378,6 @@ class EventsManager: logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True) return True, None # 发生异常时默认不中断其他处理 - def _task_done_callback( self, task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]], diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index aad7cad6..ed6dd070 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -180,9 +180,8 @@ class ToolExecutor: tool_info["content"] = str(content) # 空内容直接跳过(空字符串、全空白字符串、空列表/空元组) content_check = tool_info["content"] - if ( - (isinstance(content_check, str) and not content_check.strip()) - or (isinstance(content_check, (list, tuple)) and len(content_check) == 0) + if (isinstance(content_check, str) and not content_check.strip()) or ( + isinstance(content_check, (list, tuple)) and len(content_check) == 0 ): logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示") continue diff --git a/view_pkl.py b/view_pkl.py index 0897e174..2d50681b 100644 --- a/view_pkl.py +++ b/view_pkl.py @@ -8,29 +8,30 @@ import sys import os from pprint import pprint + def view_pkl_file(file_path): """查看 pkl 文件内容""" if not os.path.exists(file_path): print(f"❌ 文件不存在: {file_path}") return - + try: - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: data = pickle.load(f) - + print(f"📁 文件: {file_path}") print(f"📊 数据类型: {type(data)}") print("=" * 50) - + if isinstance(data, dict): print("🔑 字典键:") for key in data.keys(): print(f" - {key}: {type(data[key])}") print() - + print("📋 详细内容:") pprint(data, width=120, depth=10) - + elif isinstance(data, list): print(f"📝 列表长度: {len(data)}") if data: @@ -38,16 +39,16 @@ def view_pkl_file(file_path): print("📋 前几个元素:") for i, item in enumerate(data[:3]): print(f" [{i}]: {item}") - + else: print("📋 内容:") pprint(data, width=120, depth=10) - + # 如果是 expressor 模型,特别显示 token_counts 的详细信息 - if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']: - print("\n" + "="*50) + if isinstance(data, dict) and "nb" in data and "token_counts" in data["nb"]: + print("\n" + "=" * 50) print("🔍 详细词汇统计 (token_counts):") - token_counts = data['nb']['token_counts'] + token_counts = data["nb"]["token_counts"] for style_id, tokens in token_counts.items(): print(f"\n📝 {style_id}:") if tokens: @@ -59,18 +60,20 @@ def view_pkl_file(file_path): print(f" ... 
还有 {len(sorted_tokens) - 10} 个词") else: print(" (无词汇数据)") - + except Exception as e: print(f"❌ 读取文件失败: {e}") + def main(): if len(sys.argv) != 2: print("用法: python view_pkl.py ") print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl") return - + file_path = sys.argv[1] view_pkl_file(file_path) + if __name__ == "__main__": main() diff --git a/view_tokens.py b/view_tokens.py index 03fe8992..02ca1ea0 100644 --- a/view_tokens.py +++ b/view_tokens.py @@ -7,57 +7,60 @@ import pickle import sys import os + def view_token_counts(file_path): """查看 expressor.pkl 文件中的词汇统计""" if not os.path.exists(file_path): print(f"❌ 文件不存在: {file_path}") return - + try: - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: data = pickle.load(f) - + print(f"📁 文件: {file_path}") print("=" * 60) - - if 'nb' not in data or 'token_counts' not in data['nb']: + + if "nb" not in data or "token_counts" not in data["nb"]: print("❌ 这不是一个 expressor 模型文件") return - - token_counts = data['nb']['token_counts'] - candidates = data.get('candidates', {}) - + + token_counts = data["nb"]["token_counts"] + candidates = data.get("candidates", {}) + print(f"🎯 找到 {len(token_counts)} 个风格") print("=" * 60) - + for style_id, tokens in token_counts.items(): style_text = candidates.get(style_id, "未知风格") print(f"\n📝 {style_id}: {style_text}") print(f"📊 词汇数量: {len(tokens)}") - + if tokens: # 按词频排序 sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True) - + print("🔤 词汇统计 (按频率排序):") for i, (word, count) in enumerate(sorted_tokens): - print(f" {i+1:2d}. '{word}': {count}") + print(f" {i + 1:2d}. '{word}': {count}") else: print(" (无词汇数据)") - + print("-" * 40) - + except Exception as e: print(f"❌ 读取文件失败: {e}") + def main(): if len(sys.argv) != 2: print("用法: python view_tokens.py ") print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl") return - + file_path = sys.argv[1] view_token_counts(file_path) + if __name__ == "__main__": main()
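
For orientation, a minimal sketch of how the registry reformatted above in src/memory_system/retrieval_tools/tool_registry.py is consumed, mirroring the register_tool() pattern used by query_jargon.py; the query_notes tool, its lookup body, and the sample keyword are hypothetical, and the sketch assumes the project package (the src. namespace) is importable:

import asyncio

from src.memory_system.retrieval_tools.tool_registry import (
    get_tool_registry,
    register_memory_retrieval_tool,
)


async def query_notes(keyword: str) -> str:
    # Hypothetical lookup body; a real tool would query a data store here.
    return f"no notes found for '{keyword}'"


register_memory_retrieval_tool(
    name="query_notes",  # hypothetical tool name, not part of this patch
    description="Look up saved notes by keyword.",
    parameters=[
        {"name": "keyword", "type": "string", "description": "keyword to search for", "required": True},
    ],
    execute_func=query_notes,
)

registry = get_tool_registry()
# Numbered tool descriptions, as assembled for the retrieval prompt.
print(registry.get_tools_description())

tool = registry.get_tool("query_notes")
if tool is not None:
    # execute() awaits the registered coroutine with the given parameters.
    print(asyncio.run(tool.execute(keyword="ruff")))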